[RFC PATCH 5/6] iommu: Support mm PASID 1:1 with sva domain

From: Tina Zhang
Date: Thu Jul 06 2023 - 21:35:30 EST


Each mm bound to devices gets a PASID and a corresponding sva domain,
allocated in iommu_sva_bind_device() and referenced through the iommu_mm
field of the mm. The sva domain is released in iommu_sva_unbind_device()
once no device is bound to that mm any more, and the PASID is released in
mm_pasid_drop(). As a result, during its lifetime an sva domain maps 1:1
to the mm's PASID.

Since the PASID and sva domain of an mm are now both kept in its struct
iommu_mm_data, use the mm->iommu_mm field instead of the old pasid field
in struct mm_struct.

Signed-off-by: Tina Zhang <tina.zhang@xxxxxxxxx>
---
drivers/iommu/iommu-sva.c | 54 +++++++++++++++++++++++++++------------
drivers/iommu/iommu.c | 1 +
include/linux/iommu.h | 8 +++---
3 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c
index 7a41b6510e385..342d8ba9ab479 100644
--- a/drivers/iommu/iommu-sva.c
+++ b/drivers/iommu/iommu-sva.c
@@ -15,6 +15,7 @@ static DEFINE_IDA(iommu_global_pasid_ida);
/* Allocate a PASID for the mm within range (inclusive) */
static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
{
+ struct iommu_mm_data *iommu_mm = NULL;
int ret = 0;

if (min == IOMMU_PASID_INVALID ||
@@ -33,9 +34,18 @@ static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t ma
goto out;
}

+ iommu_mm = kzalloc(sizeof(struct iommu_mm_data), GFP_KERNEL);
+ if (!iommu_mm) {
+ ret = -ENOMEM;
+ goto out;
+ }
+ mm->iommu_mm = iommu_mm;
+
ret = ida_alloc_range(&iommu_global_pasid_ida, min, max, GFP_KERNEL);
- if (ret < min)
+ if (ret < min) {
+ kfree(iommu_mm);
goto out;
+ }
mm_set_pasid(mm, ret);
ret = 0;
out:
@@ -61,7 +71,7 @@ static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t ma
*/
struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm)
{
- struct iommu_domain *domain;
+ struct iommu_domain *domain, *sva_domain = mm->iommu_mm->sva_domain;
struct iommu_sva *handle;
ioasid_t max_pasids;
int ret;
@@ -88,31 +98,41 @@ struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm
goto out_unlock;
}

- if (domain) {
- domain->users++;
- goto out;
+ if (unlikely(domain)) {
+ /* Re-attach the device to the same domain? */
+ if (domain == sva_domain) {
+ goto out;
+ } else {
+ /* Didn't get detached from the previous domain? */
+ ret = -EBUSY;
+ goto out_unlock;
+ }
}

- /* Allocate a new domain and set it on device pasid. */
- domain = iommu_sva_domain_alloc(dev, mm);
- if (!domain) {
- ret = -ENOMEM;
- goto out_unlock;
+ if (sva_domain) {
+ sva_domain->users++;
+ } else {
+ /* Allocate a new domain and set it on device pasid. */
+ sva_domain = iommu_sva_domain_alloc(dev, mm);
+ if (!sva_domain) {
+ ret = -ENOMEM;
+ goto out_unlock;
+ }
+ sva_domain->users = 1;
}

- ret = iommu_attach_device_pasid(domain, dev, mm_get_pasid(mm));
+ ret = iommu_attach_device_pasid(sva_domain, dev, mm_get_pasid(mm));
if (ret)
goto out_free_domain;
- domain->users = 1;
out:
mutex_unlock(&iommu_sva_lock);
handle->dev = dev;
- handle->domain = domain;
+ handle->domain = sva_domain;

return handle;

out_free_domain:
- iommu_domain_free(domain);
+ iommu_domain_free(sva_domain);
out_unlock:
mutex_unlock(&iommu_sva_lock);
kfree(handle);
@@ -136,10 +156,9 @@ void iommu_sva_unbind_device(struct iommu_sva *handle)
struct device *dev = handle->dev;

mutex_lock(&iommu_sva_lock);
- if (--domain->users == 0) {
- iommu_detach_device_pasid(domain, dev, pasid);
+ iommu_detach_device_pasid(domain, dev, pasid);
+ if (--domain->users == 0)
iommu_domain_free(domain);
- }
mutex_unlock(&iommu_sva_lock);
kfree(handle);
}
@@ -217,4 +236,5 @@ void mm_pasid_drop(struct mm_struct *mm)
return;

ida_free(&iommu_global_pasid_ida, mm_get_pasid(mm));
+ kfree(mm->iommu_mm);
}
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 35fa1c1b12826..2f55a157b1f15 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -3393,5 +3393,6 @@ struct iommu_domain *iommu_sva_domain_alloc(struct device *dev,
domain->iopf_handler = iommu_sva_handle_iopf;
domain->fault_data = mm;

+ mm->iommu_mm->sva_domain = domain;
return domain;
}
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 20135912584ba..1511ded7bc910 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -1175,20 +1175,20 @@ static inline bool tegra_dev_iommu_get_stream_id(struct device *dev, u32 *stream
#ifdef CONFIG_IOMMU_SVA
static inline void mm_pasid_init(struct mm_struct *mm)
{
- mm->pasid = IOMMU_PASID_INVALID;
+ mm->iommu_mm = &default_iommu_mm;
}
static inline bool mm_valid_pasid(struct mm_struct *mm)
{
- return mm->pasid != IOMMU_PASID_INVALID;
+ return mm->iommu_mm->pasid != IOMMU_PASID_INVALID;
}
static inline u32 mm_get_pasid(struct mm_struct *mm)
{
- return mm->pasid;
+ return mm->iommu_mm->pasid;
}

static inline void mm_set_pasid(struct mm_struct *mm, u32 pasid)
{
- mm->pasid = pasid;
+ mm->iommu_mm->pasid = pasid;
}
void mm_pasid_drop(struct mm_struct *mm);
struct iommu_sva *iommu_sva_bind_device(struct device *dev,
--
2.34.1