Re: vfio-pci: protect remap_pfn_range() from simultaneous calls

From: Ankur Arora
Date: Thu Feb 25 2021 - 19:58:17 EST


Hi Bharat,

Can you test the patch below to see if it works for you?

Also, could you add some more detail to your earlier description of
the bug?
In particular, AFAICS you are using ODP (-DPDK?) with multiple
threads touching this region. From your stack, it looks like the
fault was generated from user space, and I'm guessing you were not
using VFIO_IOMMU_MAP_DMA.
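
In case it helps with reasoning about the patch below, here is a rough
user-space model of the scheme: the first fault on a VMA records it on
a list and does the remap, later faults find it on the list and return
early. This is only a sketch and not kernel code. Everything named
model_* is hypothetical, a pthread mutex stands in for vdev->vma_lock,
and a counter stands in for the io_remap_pfn_range() call.

```c
#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stddef.h>

#define MODEL_MAX_VMAS    8
#define MODEL_MAX_THREADS 64

/* Hypothetical stand-in for the vdev fields the patch touches. */
struct model_dev {
	pthread_mutex_t vma_lock;           /* models vdev->vma_lock */
	const void *vmas[MODEL_MAX_VMAS];   /* models vdev->vma_list */
	int nr_vmas;
	atomic_int remap_calls;             /* counts modelled remaps */
};

/*
 * Caller holds vma_lock. Returns 0 if the vma was added, 1 if it was
 * already tracked, -1 if there is no room (models -ENOMEM).
 */
static int model_add_vma(struct model_dev *dev, const void *vma)
{
	for (int i = 0; i < dev->nr_vmas; i++)
		if (dev->vmas[i] == vma)
			return 1;
	if (dev->nr_vmas == MODEL_MAX_VMAS)
		return -1;
	dev->vmas[dev->nr_vmas++] = vma;
	return 0;
}

/* Models the fault handler: only the first fault per vma remaps. */
static int model_fault(struct model_dev *dev, const void *vma)
{
	int vma_present;

	pthread_mutex_lock(&dev->vma_lock);
	vma_present = model_add_vma(dev, vma);
	pthread_mutex_unlock(&dev->vma_lock);

	if (vma_present < 0)
		return -1;
	if (vma_present)
		return 0;	/* optimistic: the remap may still be running */

	atomic_fetch_add(&dev->remap_calls, 1);	/* the "remap" */
	return 0;
}

struct model_arg {
	struct model_dev *dev;
	const void *vma;
};

static void *model_fault_thread(void *p)
{
	struct model_arg *arg = p;

	model_fault(arg->dev, arg->vma);
	return NULL;
}

/* Fault on one vma from nthreads threads; returns the remap count. */
int model_run_concurrent(int nthreads)
{
	static int fake_vma;	/* only its address is used */
	struct model_dev dev;
	struct model_arg arg;
	pthread_t tids[MODEL_MAX_THREADS];
	int i, rc;

	if (nthreads < 1 || nthreads > MODEL_MAX_THREADS)
		return -1;

	rc = pthread_mutex_init(&dev.vma_lock, NULL);
	assert(rc == 0);
	dev.nr_vmas = 0;
	atomic_init(&dev.remap_calls, 0);
	arg.dev = &dev;
	arg.vma = &fake_vma;

	for (i = 0; i < nthreads; i++) {
		rc = pthread_create(&tids[i], NULL, model_fault_thread, &arg);
		assert(rc == 0);
	}
	for (i = 0; i < nthreads; i++)
		pthread_join(tids[i], NULL);

	rc = atomic_load(&dev.remap_calls);
	pthread_mutex_destroy(&dev.vma_lock);
	return rc;
}
```

Calling model_run_concurrent() with any valid thread count should
report exactly one remap. Note the model shares the patch's caveat:
the threads that return early do so without waiting for the remap
to finish.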

Ankur

-- >8 --

Subject: [PATCH] vfio-pci: protect io_remap_pfn_range() from simultaneous calls

vfio_pci_mmap_fault() maps the complete VMA on fault. With concurrent
faults, this results in multiple calls to io_remap_pfn_range(), and a
later call trips over the PTEs already populated by an earlier one,
hitting the BUG_ON(!pte_none(*pte)) in remap_pte_range().
(It would also link the same VMA multiple times in vdev->vma_list,
but given the BUG_ON that is less serious.)

Normally, however, this won't happen, at least with vfio_iommu_type1,
because the VFIO_IOMMU_MAP_DMA path is protected by iommu->lock.

If, however, we are using some kind of parallelization mechanism like
the ktask work under discussion [1], we would hit this.
Even if we were doing this serially, vfio-pci remaps a larger extent
than strictly necessary, so it should internally enforce coherence of
its data structures.

Handle this by using the VMA's presence in the vdev->vma_list as
indicative of a fully mapped VMA and returning success early for
all but the first fault on that VMA. Note that this is clearly
optimistic given that the mapping may still be ongoing, and might mean
that the caller sees more faults until the remap is done.

[1] https://lore.kernel.org/linux-mm/20181105145141.6f9937f6@xxxxxxxxx/

Signed-off-by: Ankur Arora <ankur.a.arora@xxxxxxxxxx>
---
drivers/vfio/pci/vfio_pci.c | 25 ++++++++++++++++++++++++-
1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 65e7e6b44578..b9f509863db1 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -1573,6 +1573,11 @@ static int __vfio_pci_add_vma(struct vfio_pci_device *vdev,
{
struct vfio_pci_mmap_vma *mmap_vma;

+ list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
+ if (mmap_vma->vma == vma)
+ return 1;
+ }
+
mmap_vma = kmalloc(sizeof(*mmap_vma), GFP_KERNEL);
if (!mmap_vma)
return -ENOMEM;
@@ -1613,6 +1618,7 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
struct vm_area_struct *vma = vmf->vma;
struct vfio_pci_device *vdev = vma->vm_private_data;
vm_fault_t ret = VM_FAULT_NOPAGE;
+ int vma_present;

mutex_lock(&vdev->vma_lock);
down_read(&vdev->memory_lock);
@@ -1623,7 +1629,21 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
goto up_out;
}

- if (__vfio_pci_add_vma(vdev, vma)) {
+ /*
+ * __vfio_pci_add_vma() either adds the vma to the vdev->vma_list
+ * (vma_present == 0), or indicates that the vma is already present
+ * on the list (vma_present == 1).
+ *
+ * Overload the meaning of this flag to also imply that the vma is
+ * fully mapped. This allows us to serialize the mapping -- ensuring
+ * that simultaneous faults will not both try to call
+ * io_remap_pfn_range().
+ *
+ * However, this might mean that callers to which we returned success
+ * optimistically will see more faults until the remap is complete.
+ */
+ vma_present = __vfio_pci_add_vma(vdev, vma);
+ if (vma_present < 0) {
ret = VM_FAULT_OOM;
mutex_unlock(&vdev->vma_lock);
goto up_out;
@@ -1631,6 +1651,9 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)

mutex_unlock(&vdev->vma_lock);

+ if (vma_present)
+ goto up_out;
+
if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
vma->vm_end - vma->vm_start, vma->vm_page_prot))
ret = VM_FAULT_SIGBUS;
--
2.29.2