[PATCH rfc v2 01/10] mm: add a generic VMA lock-based page fault handler

From: Kefeng Wang
Date: Mon Aug 21 2023 - 08:31:20 EST


ARCH_SUPPORTS_PER_VMA_LOCK is enabled by more and more architectures,
e.g. x86, arm64, powerpc, s390 and riscv. Their implementations are very
similar, which results in duplicated code. Add a generic VMA lock-based
page fault handler, try_vma_locked_page_fault(), to eliminate the
duplication; it also makes it easier to support this on new architectures.

Since different architectures check whether a VMA is accessible in
different ways, add struct pt_regs, the page fault error code and the
required VMA flags to struct vm_fault. Each architecture's page fault code
can then use struct vm_fault to record this information and check VMA
accessibility with its own arch_vma_access_error(), or rely on the generic
__weak default, which checks vm_flags.

Signed-off-by: Kefeng Wang <wangkefeng.wang@xxxxxxxxxx>
---
include/linux/mm.h | 17 +++++++++++++++++
include/linux/mm_types.h | 2 ++
mm/memory.c | 39 +++++++++++++++++++++++++++++++++++++++
3 files changed, 58 insertions(+)

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 3f764e84e567..22a6f4c56ff3 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -512,9 +512,12 @@ struct vm_fault {
pgoff_t pgoff; /* Logical page offset based on vma */
unsigned long address; /* Faulting virtual address - masked */
unsigned long real_address; /* Faulting virtual address - unmasked */
+ unsigned long fault_code; /* Faulting error code during page fault */
+ struct pt_regs *regs; /* The registers stored during page fault */
};
enum fault_flag flags; /* FAULT_FLAG_xxx flags
* XXX: should really be 'const' */
+ vm_flags_t vm_flags; /* VMA flags to be used for access checking */
pmd_t *pmd; /* Pointer to pmd entry matching
* the 'address' */
pud_t *pud; /* Pointer to pud entry matching
@@ -774,6 +777,11 @@ static inline void assert_fault_locked(struct vm_fault *vmf)
struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
unsigned long address);

+/* VMA lock-based page fault handler and its arch hook, see mm/memory.c. */
+bool arch_vma_access_error(struct vm_area_struct *vma,
+			   struct vm_fault *vmf);
+vm_fault_t try_vma_locked_page_fault(struct vm_fault *vmf);
+
#else /* CONFIG_PER_VMA_LOCK */

static inline bool vma_start_read(struct vm_area_struct *vma)
@@ -801,6 +809,18 @@ static inline void assert_fault_locked(struct vm_fault *vmf)
mmap_assert_locked(vmf->vma->vm_mm);
}

+/* Without per-VMA locks no VMA is ever locked and no fault is attempted. */
+static inline struct vm_area_struct *
+lock_vma_under_rcu(struct mm_struct *mm, unsigned long address)
+{
+	return NULL;
+}
+
+static inline vm_fault_t try_vma_locked_page_fault(struct vm_fault *vmf)
+{
+	return VM_FAULT_NONE;
+}
+
#endif /* CONFIG_PER_VMA_LOCK */

extern const struct vm_operations_struct vma_dummy_vm_ops;
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index f5ba5b0bc836..702820cea3f9 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -1119,6 +1119,7 @@ typedef __bitwise unsigned int vm_fault_t;
* fault. Used to decide whether a process gets delivered SIGBUS or
* just gets major/minor fault counters bumped up.
*
+ * @VM_FAULT_NONE: Special value: handling of the fault was not attempted
* @VM_FAULT_OOM: Out Of Memory
* @VM_FAULT_SIGBUS: Bad access
* @VM_FAULT_MAJOR: Page read from storage
@@ -1139,6 +1140,7 @@ typedef __bitwise unsigned int vm_fault_t;
*
*/
enum vm_fault_reason {
+ VM_FAULT_NONE = (__force vm_fault_t)0x000000,
VM_FAULT_OOM = (__force vm_fault_t)0x000001,
VM_FAULT_SIGBUS = (__force vm_fault_t)0x000002,
VM_FAULT_MAJOR = (__force vm_fault_t)0x000004,
diff --git a/mm/memory.c b/mm/memory.c
index 3b4aaa0d2fff..60fe35db5134 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -5510,6 +5510,80 @@ struct vm_area_struct *lock_vma_under_rcu(struct mm_struct *mm,
count_vm_vma_lock_event(VMA_LOCK_ABORT);
return NULL;
}
+
+/**
+ * arch_vma_access_error - default VMA permission check for VMA-locked faults
+ * @vma: VMA covering the faulting address, read-locked by the caller
+ * @vmf: fault description; @vmf->vm_flags holds the VM_* flags the access
+ *       needs, of which @vma must have at least one
+ *
+ * Architectures that check permissions differently (e.g. from
+ * @vmf->fault_code) can override this weak default.
+ *
+ * Return: true if @vma does not permit the access.
+ */
+bool __weak arch_vma_access_error(struct vm_area_struct *vma, struct vm_fault *vmf)
+{
+	return (vma->vm_flags & vmf->vm_flags) == 0;
+}
+
+/**
+ * try_vma_locked_page_fault - try to handle a user fault under the VMA lock
+ * @vmf: fault description prepared by the architecture; @real_address,
+ *       @flags and @regs are used here, plus whatever
+ *       arch_vma_access_error() inspects
+ *
+ * For user faults only, read-locks the VMA covering @vmf->real_address with
+ * lock_vma_under_rcu() and, if arch_vma_access_error() allows the access,
+ * passes the fault to handle_mm_fault() with FAULT_FLAG_VMA_LOCK added.
+ * VMA_LOCK_RETRY or VMA_LOCK_SUCCESS is counted for every attempted fault.
+ *
+ * Return: VM_FAULT_NONE if the fault was not attempted (not a user fault,
+ * no VMA could be locked, or arch_vma_access_error() refused it), in which
+ * case the caller must handle it some other way; otherwise the value
+ * returned by handle_mm_fault().
+ *
+ * NOTE(review): VM_FAULT_NONE is 0, and handle_mm_fault() appears to return
+ * 0 for an ordinary successfully handled fault too, so "not attempted" and
+ * "handled" look the same to callers - confirm, and consider giving
+ * VM_FAULT_NONE a dedicated bit.
+ */
+vm_fault_t try_vma_locked_page_fault(struct vm_fault *vmf)
+{
+	struct vm_area_struct *vma;
+	vm_fault_t fault;
+
+	/* Only user faults are tried under the VMA lock. */
+	if (!(vmf->flags & FAULT_FLAG_USER))
+		return VM_FAULT_NONE;
+
+	vma = lock_vma_under_rcu(current->mm, vmf->real_address);
+	if (!vma)
+		return VM_FAULT_NONE;
+
+	if (arch_vma_access_error(vma, vmf)) {
+		vma_end_read(vma);
+		return VM_FAULT_NONE;
+	}
+
+	fault = handle_mm_fault(vma, vmf->real_address,
+				vmf->flags | FAULT_FLAG_VMA_LOCK, vmf->regs);
+
+	/*
+	 * With VM_FAULT_RETRY or VM_FAULT_COMPLETED the VMA lock is treated as
+	 * already released by handle_mm_fault(); otherwise release it here.
+	 */
+	if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED)))
+		vma_end_read(vma);
+
+	if (fault & VM_FAULT_RETRY)
+		count_vm_vma_lock_event(VMA_LOCK_RETRY);
+	else
+		count_vm_vma_lock_event(VMA_LOCK_SUCCESS);
+
+	return fault;
+}
+
#endif /* CONFIG_PER_VMA_LOCK */

#ifndef __PAGETABLE_P4D_FOLDED
--
2.27.0