Re: [PATCH v8 12/12] iommu: Use refcount for fault data access

From: Baolu Lu
Date: Mon Dec 11 2023 - 22:49:13 EST


On 12/11/23 11:12 PM, Jason Gunthorpe wrote:
On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote:
+/*
+ * Return the fault parameter of a device if it exists. Otherwise, return NULL.
+ * On a successful return, the caller takes a reference of this parameter and
+ * should put it after use by calling iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+ struct dev_iommu *param = dev->iommu;
+ struct iommu_fault_param *fault_param;
+
+ if (!param)
+ return NULL;

Is it actually possible to call this function on a device that does
not have an iommu driver probed? I'd be surprised by that; maybe this
should be a WARN_ONCE.

The above check seems to be unnecessary. This helper should only be used
during the iommu probe and release. We can just remove it, as any driver
that abuses this will trigger a NULL pointer dereference.


+
+ rcu_read_lock();
+ fault_param = param->fault_param;

The RCU stuff is not right, like this:

diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 2ace32c6d13bf3..0258f79c8ddf98 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -40,7 +40,7 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
return NULL;
rcu_read_lock();
- fault_param = param->fault_param;
+ fault_param = rcu_dereference(param->fault_param);
if (fault_param && !refcount_inc_not_zero(&fault_param->users))
fault_param = NULL;
rcu_read_unlock();
@@ -51,17 +51,8 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
/* Caller must hold a reference of the fault parameter. */
static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
{
- struct dev_iommu *param = fault_param->dev->iommu;
-
- rcu_read_lock();
- if (!refcount_dec_and_test(&fault_param->users)) {
- rcu_read_unlock();
- return;
- }
- rcu_read_unlock();
-
- param->fault_param = NULL;
- kfree_rcu(fault_param, rcu);
+ if (refcount_dec_and_test(&fault_param->users))
+ kfree_rcu(fault_param, rcu);
}
/**
@@ -174,7 +165,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
}
mutex_unlock(&iopf_param->lock);
- ret = domain->iopf_handler(group);
+ ret = domain->iopf_handler(iopf_param, group);
mutex_lock(&iopf_param->lock);
if (ret)
iopf_free_group(group);
@@ -398,7 +389,8 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
mutex_lock(&queue->lock);
mutex_lock(&param->lock);
- if (param->fault_param) {
+ if (rcu_dereference_check(param->fault_param,
+ lockdep_is_held(&param->lock))) {
ret = -EBUSY;
goto done_unlock;
}
@@ -418,7 +410,7 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
list_add(&fault_param->queue_list, &queue->devices);
fault_param->queue = queue;
- param->fault_param = fault_param;
+ rcu_assign_pointer(param->fault_param, fault_param);
done_unlock:
mutex_unlock(&param->lock);
@@ -442,10 +434,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
int ret = 0;
struct iopf_fault *iopf, *next;
struct dev_iommu *param = dev->iommu;
- struct iommu_fault_param *fault_param = param->fault_param;
+ struct iommu_fault_param *fault_param;
mutex_lock(&queue->lock);
mutex_lock(&param->lock);
+ fault_param = rcu_dereference_check(param->fault_param,
+ lockdep_is_held(&param->lock));
if (!fault_param) {
ret = -ENODEV;
goto unlock;
@@ -467,7 +461,10 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
kfree(iopf);
- iopf_put_dev_fault_param(fault_param);
+ /* dec the ref owned by iopf_queue_add_device() */
+ rcu_assign_pointer(param->fault_param, NULL);
+ if (refcount_dec_and_test(&fault_param->users))
+ kfree_rcu(fault_param, rcu);
unlock:
mutex_unlock(&param->lock);
mutex_unlock(&queue->lock);
diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c
index 325d1810e133a1..63c1a233a7e91f 100644
--- a/drivers/iommu/iommu-sva.c
+++ b/drivers/iommu/iommu-sva.c
@@ -232,10 +232,9 @@ static void iommu_sva_handle_iopf(struct work_struct *work)
iopf_free_group(group);
}
-static int iommu_sva_iopf_handler(struct iopf_group *group)
+static int iommu_sva_iopf_handler(struct iommu_fault_param *fault_param,
+ struct iopf_group *group)
{
- struct iommu_fault_param *fault_param = group->dev->iommu->fault_param;
-
INIT_WORK(&group->work, iommu_sva_handle_iopf);
if (!queue_work(fault_param->queue->wq, &group->work))
return -EBUSY;
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 8020bb44a64ab1..e16fa9811d5023 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -41,6 +41,7 @@ struct iommu_dirty_ops;
struct notifier_block;
struct iommu_sva;
struct iommu_dma_cookie;
+struct iommu_fault_param;
#define IOMMU_FAULT_PERM_READ (1 << 0) /* read */
#define IOMMU_FAULT_PERM_WRITE (1 << 1) /* write */
@@ -210,7 +211,8 @@ struct iommu_domain {
unsigned long pgsize_bitmap; /* Bitmap of page sizes in use */
struct iommu_domain_geometry geometry;
struct iommu_dma_cookie *iova_cookie;
- int (*iopf_handler)(struct iopf_group *group);
+ int (*iopf_handler)(struct iommu_fault_param *fault_param,
+ struct iopf_group *group);

How about folding fault_param into iopf_group?

iopf_group is the central data structure for handling an I/O page fault.
The iopf_group holds a reference to the device's fault parameter structure
throughout its entire lifecycle.

void *fault_data;
union {
struct {
@@ -637,7 +639,7 @@ struct iommu_fault_param {
*/
struct dev_iommu {
struct mutex lock;
- struct iommu_fault_param *fault_param;
+ struct iommu_fault_param __rcu *fault_param;
struct iommu_fwspec *fwspec;
struct iommu_device *iommu_dev;
void *priv;

iommu_page_response() needs to change accordingly, as pointed out in the
next email.

Others look good to me. Thank you so much!

Best regards,
baolu