Re: [PATCH v4 6/6] mm: handle userfaults under VMA lock

From: Peter Xu
Date: Wed Jun 28 2023 - 09:58:56 EST


On Wed, Jun 28, 2023 at 12:18:00AM -0700, Suren Baghdasaryan wrote:
> Enable handle_userfault to operate under VMA lock by releasing VMA lock
> instead of mmap_lock and retrying. Note that FAULT_FLAG_RETRY_NOWAIT
> should never be used when handling faults under per-VMA lock protection
> because that would break the assumption that lock is dropped on retry.
>
> Signed-off-by: Suren Baghdasaryan <surenb@xxxxxxxxxx>

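Besides the NOWAIT typo (the BUG_ON below re-tests FAULT_FLAG_RETRY_NOWAIT inside
a branch already taken on that flag, so it would always fire; FAULT_FLAG_VMA_LOCK
seems to be what was meant), everything looks sane. Since there seems to be a
need for at least one more version, I'll still comment on a few things.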

> ---
> fs/userfaultfd.c | 39 ++++++++++++++++++---------------------
> include/linux/mm.h | 39 +++++++++++++++++++++++++++++++++++++++
> mm/filemap.c | 8 --------
> mm/memory.c | 9 ---------
> 4 files changed, 57 insertions(+), 38 deletions(-)
>
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index 4e800bb7d2ab..d019e7df6f15 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -277,17 +277,16 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
> * hugepmd ranges.
> */
> static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
> - struct vm_area_struct *vma,
> - unsigned long address,
> - unsigned long flags,
> - unsigned long reason)
> + struct vm_fault *vmf,
> + unsigned long reason)
> {
> + struct vm_area_struct *vma = vmf->vma;
> pte_t *ptep, pte;
> bool ret = true;
>
> - mmap_assert_locked(ctx->mm);
> + assert_fault_locked(ctx->mm, vmf);

AFAIU ctx->mm must be the same as vma->vm_mm here, so maybe we can drop the
*ctx parameter altogether, since we're already dropping plenty of others.

>
> - ptep = hugetlb_walk(vma, address, vma_mmu_pagesize(vma));
> + ptep = hugetlb_walk(vma, vmf->address, vma_mmu_pagesize(vma));
> if (!ptep)
> goto out;
>
> @@ -308,10 +307,8 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
> }
> #else
> static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
> - struct vm_area_struct *vma,
> - unsigned long address,
> - unsigned long flags,
> - unsigned long reason)
> + struct vm_fault *vmf,
> + unsigned long reason)
> {
> return false; /* should never get here */
> }
> @@ -325,11 +322,11 @@ static inline bool userfaultfd_huge_must_wait(struct userfaultfd_ctx *ctx,
> * threads.
> */
> static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
> - unsigned long address,
> - unsigned long flags,
> + struct vm_fault *vmf,
> unsigned long reason)
> {
> struct mm_struct *mm = ctx->mm;
> + unsigned long address = vmf->address;
> pgd_t *pgd;
> p4d_t *p4d;
> pud_t *pud;
> @@ -337,7 +334,7 @@ static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
> pte_t *pte;
> bool ret = true;
>
> - mmap_assert_locked(mm);
> + assert_fault_locked(mm, vmf);
>
> pgd = pgd_offset(mm, address);
> if (!pgd_present(*pgd))
> @@ -445,7 +442,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
> * Coredumping runs without mmap_lock so we can only check that
> * the mmap_lock is held, if PF_DUMPCORE was not set.
> */
> - mmap_assert_locked(mm);
> + assert_fault_locked(mm, vmf);
>
> ctx = vma->vm_userfaultfd_ctx.ctx;
> if (!ctx)
> @@ -522,8 +519,11 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
> * and wait.
> */
> ret = VM_FAULT_RETRY;
> - if (vmf->flags & FAULT_FLAG_RETRY_NOWAIT)
> + if (vmf->flags & FAULT_FLAG_RETRY_NOWAIT) {
> + /* Per-VMA lock is expected to be dropped on VM_FAULT_RETRY */
> + BUG_ON(vmf->flags & FAULT_FLAG_RETRY_NOWAIT);

This is not the only place where FAULT_FLAG_RETRY_NOWAIT can show up;
e.g. folio_lock_or_retry() can also see it, so a check here may or may not
help much.

The other thing: please avoid BUG_ON if possible. IMHO WARN_ON_ONCE() is
always preferable when the kernel can still keep running after the check
triggers.

I'd rather drop this change and leave room for a future where the VMA lock
may be supported in GUP paths with NOWAIT; then this code will work
naturally, AFAIU. If we really want a sanity check, the best place is
probably on entry to handle_mm_fault(), specifically in sanitize_fault_flags().

> goto out;
> + }
>
> /* take the reference before dropping the mmap_lock */
> userfaultfd_ctx_get(ctx);
> @@ -561,15 +561,12 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
> spin_unlock_irq(&ctx->fault_pending_wqh.lock);
>
> if (!is_vm_hugetlb_page(vma))
> - must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
> - reason);
> + must_wait = userfaultfd_must_wait(ctx, vmf, reason);
> else
> - must_wait = userfaultfd_huge_must_wait(ctx, vma,
> - vmf->address,
> - vmf->flags, reason);
> + must_wait = userfaultfd_huge_must_wait(ctx, vmf, reason);
> if (is_vm_hugetlb_page(vma))
> hugetlb_vma_unlock_read(vma);
> - mmap_read_unlock(mm);
> + release_fault_lock(vmf);
>
> if (likely(must_wait && !READ_ONCE(ctx->released))) {
> wake_up_poll(&ctx->fd_wqh, EPOLLIN);
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index fec149585985..70bb2f923e33 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -705,6 +705,17 @@ static inline bool vma_try_start_write(struct vm_area_struct *vma)
> return true;
> }
>
> +static inline void vma_assert_locked(struct vm_area_struct *vma)
> +{
> + int mm_lock_seq;
> +
> + if (__is_vma_write_locked(vma, &mm_lock_seq))
> + return;
> +
> + lockdep_assert_held(&vma->vm_lock->lock);
> + VM_BUG_ON_VMA(!rwsem_is_locked(&vma->vm_lock->lock), vma);
> +}
> +
> static inline void vma_assert_write_locked(struct vm_area_struct *vma)
> {
> int mm_lock_seq;
> @@ -723,6 +734,23 @@ static inline void vma_mark_detached(struct vm_area_struct *vma, bool detached)
> struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
> unsigned long address);
>
> +static inline
> +void assert_fault_locked(struct mm_struct *mm, struct vm_fault *vmf)
> +{
> + if (vmf->flags & FAULT_FLAG_VMA_LOCK)
> + vma_assert_locked(vmf->vma);
> + else
> + mmap_assert_locked(mm);
> +}
> +
> +static inline void release_fault_lock(struct vm_fault *vmf)
> +{
> + if (vmf->flags & FAULT_FLAG_VMA_LOCK)
> + vma_end_read(vmf->vma);
> + else
> + mmap_read_unlock(vmf->vma->vm_mm);
> +}
> +
> #else /* CONFIG_PER_VMA_LOCK */
>
> static inline void vma_init_lock(struct vm_area_struct *vma) {}
> @@ -736,6 +764,17 @@ static inline void vma_assert_write_locked(struct vm_area_struct *vma) {}
> static inline void vma_mark_detached(struct vm_area_struct *vma,
> bool detached) {}
>
> +static inline
> +void assert_fault_locked(struct mm_struct *mm, struct vm_fault *vmf)
> +{
> + mmap_assert_locked(mm);
> +}
> +
> +static inline void release_fault_lock(struct vm_fault *vmf)
> +{
> + mmap_read_unlock(vmf->vma->vm_mm);
> +}
> +
> #endif /* CONFIG_PER_VMA_LOCK */
>
> /*
> diff --git a/mm/filemap.c b/mm/filemap.c
> index 7ee078e1a0d2..d4d8f474e0c5 100644
> --- a/mm/filemap.c
> +++ b/mm/filemap.c
> @@ -1699,14 +1699,6 @@ static int __folio_lock_async(struct folio *folio, struct wait_page_queue *wait)
> return ret;
> }
>
> -static void release_fault_lock(struct vm_fault *vmf)
> -{
> - if (vmf->flags & FAULT_FLAG_VMA_LOCK)
> - vma_end_read(vmf->vma);
> - else
> - mmap_read_unlock(vmf->vma->vm_mm);
> -}

Moving it is fine, but it may not be the cleanest approach. I'd prefer that
it be put in the right place when it is first introduced, since it's all in the same series anyway.

Thanks,

> -
> /*
> * Return values:
> * 0 - folio is locked.
> diff --git a/mm/memory.c b/mm/memory.c
> index 76c7907e7286..c6c759922f39 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -5294,15 +5294,6 @@ struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
> if (!vma_start_read(vma))
> goto inval;
>
> - /*
> - * Due to the possibility of userfault handler dropping mmap_lock, avoid
> - * it for now and fall back to page fault handling under mmap_lock.
> - */
> - if (userfaultfd_armed(vma)) {
> - vma_end_read(vma);
> - goto inval;
> - }
> -
> /* Check since vm_start/vm_end might change before we lock the VMA */
> if (unlikely(address < vma->vm_start || address >= vma->vm_end)) {
> vma_end_read(vma);
> --
> 2.41.0.162.gfafddb0af9-goog
>

--
Peter Xu