[PATCH] mm/mempolicy: Correctly update prev when policy is equal on mbind

From: Lorenzo Stoakes
Date: Sun Apr 30 2023 - 09:00:28 EST


The refactoring in commit f4e9e0e69468 ("mm/mempolicy: fix use-after-free
of VMA iterator") introduces a subtle bug which arises when attempting to
apply a new NUMA policy across a range of VMAs in mbind_range().

The refactoring passes a **prev pointer to keep track of the previous VMA
in order to reduce duplication, and in all but one case it keeps this
correctly updated.

The bug arises when a VMA within the specified range has an equivalent
policy as determined by mpol_equal() - which unlike other cases, does not
update prev.

This can result in a situation where, later in the iteration, a VMA is
found whose policy does need to change. At this point, vma_merge() is
invoked with prev pointing to a VMA which is before the previous VMA.

Since vma_merge() discovers the curr VMA by looking for the one immediately
after prev, it will now be in a situation where this VMA is incorrect and
the merge will not proceed correctly.

This is checked in the VM_WARN_ON() invariant case with end > curr->vm_end,
which, if a merge is possible, results in a warning (if CONFIG_DEBUG_VM is
specified).

I note that vma_merge() performs these invariant checks only after
merge_prev/merge_next are checked, which is debatable as it hides this
issue if no merge is possible even though a buggy situation has arisen.

The solution is simply to update the prev pointer even when policies are
equal.

This caused a bug to arise in the 6.2.y stable tree, and this patch
resolves this bug.

Reported-by: kernel test robot <oliver.sang@xxxxxxxxx>
Link: https://lore.kernel.org/oe-lkp/202304292203.44ddeff6-oliver.sang@xxxxxxxxx
Fixes: f4e9e0e69468 ("mm/mempolicy: fix use-after-free of VMA iterator")
Signed-off-by: Lorenzo Stoakes <lstoakes@xxxxxxxxx>
---
mm/mempolicy.c | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 2068b594dc88..1756389a0609 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -808,8 +808,10 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
vmstart = vma->vm_start;
}

- if (mpol_equal(vma_policy(vma), new_pol))
+ if (mpol_equal(vma_policy(vma), new_pol)) {
+ *prev = vma;
return 0;
+ }

pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
--
2.40.1