Re: [PATCH v9 14/14] iommu: Track iopf group instead of last fault

From: Jason Gunthorpe
Date: Fri Jan 05 2024 - 12:54:06 EST


On Wed, Dec 20, 2023 at 09:23:32AM +0800, Lu Baolu wrote:
> /**
> - * iommu_handle_iopf - IO Page Fault handler
> - * @fault: fault event
> - * @iopf_param: the fault parameter of the device.
> + * iommu_report_device_fault() - Report fault event to device driver
> + * @dev: the device
> + * @evt: fault event data
> *
> - * Add a fault to the device workqueue, to be handled by mm.
> + * Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
> + * handler. When this function fails and the fault is recoverable, it is the
> + * caller's responsibility to complete the fault.

This patch seems OK for what it does so:

Reviewed-by: Jason Gunthorpe <jgg@xxxxxxxxxx>

However, this seems like a strange design; surely this function should
just call ops->page_response() when it can't enqueue the fault?

It is much cleaner that way, so maybe you can take this into a
follow-up patch (along with the driver fixes to accommodate it), and
perhaps iommu_report_device_fault() should return void too.

Also, iopf_group_response() should return void (another patch!), since
nothing can do anything with the failure. This implies that
ops->page_response() must also return void - which is consistent with
what the drivers do: the failure paths are all integrity validations
of the fault and should be WARN_ON'd, not returned as error codes.

diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 7d11b74e4048e2..2715e24fd64234 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -39,7 +39,7 @@ static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
kfree_rcu(fault_param, rcu);
}

-void iopf_free_group(struct iopf_group *group)
+static void __iopf_free_group(struct iopf_group *group)
{
struct iopf_fault *iopf, *next;

@@ -50,6 +50,11 @@ void iopf_free_group(struct iopf_group *group)

/* Pair with iommu_report_device_fault(). */
iopf_put_dev_fault_param(group->fault_param);
+}
+
+void iopf_free_group(struct iopf_group *group)
+{
+ __iopf_free_group(group); /* group contents, incl. the fault_param reference */
+ kfree(group); /* heap groups only: the on-stack abort_group skips this step */
}
EXPORT_SYMBOL_GPL(iopf_free_group);
@@ -97,14 +102,49 @@ static int report_partial_fault(struct iommu_fault_param *fault_param,
return 0;
}

+static struct iopf_group *iopf_group_alloc(struct iommu_fault_param *iopf_param,
+ struct iopf_fault *evt,
+ struct iopf_group *abort_group)
+{
+ struct iopf_fault *iopf, *next;
+ struct iopf_group *group;
+
+ group = kzalloc(sizeof(*group), GFP_KERNEL);
+ if (!group) {
+ /*
+ * On allocation failure fall back to the caller's abort_group: a
+ * group is always needed to abort the request at the driver.
+ */
+ group = abort_group;
+ }
+
+ group->fault_param = iopf_param;
+ group->last_fault.fault = evt->fault;
+ INIT_LIST_HEAD(&group->faults);
+ INIT_LIST_HEAD(&group->pending_node);
+ list_add(&group->last_fault.list, &group->faults);
+
+ /* Move queued partial faults with the same grpid into this group */
+ mutex_lock(&iopf_param->lock);
+ list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
+ if (iopf->fault.prm.grpid == evt->fault.prm.grpid)
+ /* Insert *before* the last fault */
+ list_move(&iopf->list, &group->faults);
+ }
+ list_add(&group->pending_node, &iopf_param->faults);
+ mutex_unlock(&iopf_param->lock);
+
+ return group;
+}
+
/**
* iommu_report_device_fault() - Report fault event to device driver
* @dev: the device
* @evt: fault event data
*
* Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
- * handler. When this function fails and the fault is recoverable, it is the
- * caller's responsibility to complete the fault.
+ * handler. If this function fails then ops->page_response() was called to
+ * complete evt if required.
*
* This module doesn't handle PCI PASID Stop Marker; IOMMU drivers must discard
* them before reporting faults. A PASID Stop Marker (LRW = 0b100) doesn't
@@ -143,22 +183,24 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
{
struct iommu_fault *fault = &evt->fault;
struct iommu_fault_param *iopf_param;
- struct iopf_fault *iopf, *next;
- struct iommu_domain *domain;
+ struct iopf_group abort_group;
struct iopf_group *group;
int ret;

+/*
+ remove this too, it is pointless. The driver should only invoke this function on page_req faults.
if (fault->type != IOMMU_FAULT_PAGE_REQ)
return -EOPNOTSUPP;
+*/

iopf_param = iopf_get_dev_fault_param(dev);
- if (!iopf_param)
+ if (WARN_ON(!iopf_param))
return -ENODEV;

if (!(fault->prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE)) {
ret = report_partial_fault(iopf_param, fault);
iopf_put_dev_fault_param(iopf_param);
-
+ /* A request that is not the last does not need to be ack'd */
return ret;
}

@@ -170,56 +212,34 @@ int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
* will send a response to the hardware. We need to clean up before
* leaving, otherwise partial faults will be stuck.
*/
- domain = get_domain_for_iopf(dev, fault);
- if (!domain) {
- ret = -EINVAL;
- goto cleanup_partial;
- }
-
- group = kzalloc(sizeof(*group), GFP_KERNEL);
- if (!group) {
+ group = iopf_group_alloc(iopf_param, evt, &abort_group);
+ if (group == &abort_group) {
ret = -ENOMEM;
- goto cleanup_partial;
+ goto err_abort;
}

- group->fault_param = iopf_param;
- group->last_fault.fault = *fault;
- INIT_LIST_HEAD(&group->faults);
- INIT_LIST_HEAD(&group->pending_node);
- group->domain = domain;
- list_add(&group->last_fault.list, &group->faults);
-
- /* See if we have partial faults for this group */
- mutex_lock(&iopf_param->lock);
- list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
- if (iopf->fault.prm.grpid == fault->prm.grpid)
- /* Insert *before* the last fault */
- list_move(&iopf->list, &group->faults);
+ group->domain = get_domain_for_iopf(dev, fault);
+ if (!group->domain) {
+ ret = -EINVAL;
+ goto err_abort;
}
- list_add(&group->pending_node, &iopf_param->faults);
- mutex_unlock(&iopf_param->lock);

- ret = domain->iopf_handler(group);
- if (ret) {
- mutex_lock(&iopf_param->lock);
- list_del_init(&group->pending_node);
- mutex_unlock(&iopf_param->lock);
+ /*
+ * On success iopf_handler must call iopf_group_response() and
+ * iopf_free_group()
+ */
+ ret = group->domain->iopf_handler(group);
+ if (ret)
+ goto err_abort;
+ return 0;
+
+err_abort:
+ iopf_group_response(group,
+ IOMMU_PAGE_RESP_FAILURE); //?? right code?
+ if (group == &abort_group)
+ __iopf_free_group(group);
+ else
iopf_free_group(group);
- }
-
- return ret;
-
-cleanup_partial:
- mutex_lock(&iopf_param->lock);
- list_for_each_entry_safe(iopf, next, &iopf_param->partial, list) {
- if (iopf->fault.prm.grpid == fault->prm.grpid) {
- list_del(&iopf->list);
- kfree(iopf);
- }
- }
- mutex_unlock(&iopf_param->lock);
- iopf_put_dev_fault_param(iopf_param);
-
return ret;
}
EXPORT_SYMBOL_GPL(iommu_report_device_fault);
@@ -262,7 +282,7 @@ EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
*
* Return 0 on success and <0 on error.
*/
-int iopf_group_response(struct iopf_group *group,
+void iopf_group_response(struct iopf_group *group,
enum iommu_page_response_code status)
{
struct iommu_fault_param *fault_param = group->fault_param;
@@ -400,9 +420,9 @@ EXPORT_SYMBOL_GPL(iopf_queue_add_device);
*/
void iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
{
- struct iopf_fault *iopf, *next;
+ struct iopf_fault *partial_iopf;
+ struct iopf_fault *next;
struct iopf_group *group, *temp;
- struct iommu_page_response resp;
struct dev_iommu *param = dev->iommu;
struct iommu_fault_param *fault_param;
const struct iommu_ops *ops = dev_iommu_ops(dev);
@@ -416,15 +436,16 @@ void iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
goto unlock;

mutex_lock(&fault_param->lock);
- list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
- kfree(iopf);
+ list_for_each_entry_safe(partial_iopf, next, &fault_param->partial, list)
+ kfree(partial_iopf);

list_for_each_entry_safe(group, temp, &fault_param->faults, pending_node) {
- memset(&resp, 0, sizeof(struct iommu_page_response));
- iopf = &group->last_fault;
- resp.pasid = iopf->fault.prm.pasid;
- resp.grpid = iopf->fault.prm.grpid;
- resp.code = IOMMU_PAGE_RESP_INVALID;
+ struct iopf_fault *iopf = &group->last_fault;
+ struct iommu_page_response resp = {
+ .pasid = iopf->fault.prm.pasid,
+ .grpid = iopf->fault.prm.grpid,
+ .code = IOMMU_PAGE_RESP_INVALID
+ };

if (iopf->fault.prm.flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID)
resp.flags = IOMMU_PAGE_RESP_PASID_VALID;