[PATCH RFC v2 4/4] mm/ptshare: Add page fault handling for page table shared regions

From: Khalid Aziz
Date: Wed Apr 26 2023 - 12:51:10 EST


Add support for creating a new set of shared page tables in a new
mm_struct upon mmap of a region that can potentially share page
tables. Add page fault handling for this shared region. Modify the
free_pgtables() path to make sure page tables in the shared regions
are kept intact when a process using the shared page table region
exits while there are other mappers of the region. Clean up the
mm_struct holding the shared page tables when the last process
sharing the region exits.

Signed-off-by: Khalid Aziz <khalid.aziz@xxxxxxxxxx>
Signed-off-by: Matthew Wilcox (Oracle) <willy@xxxxxxxxxxxxx>
---
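Notes, not for the commit log:

A rough user-space sketch of how this fault path is expected to be
exercised. It assumes the MAP_SHARED_PT mmap flag introduced earlier in
this series; the placeholder flag value, file name, and mapping size
below are made up purely for illustration.

    #include <fcntl.h>
    #include <stdio.h>
    #include <string.h>
    #include <sys/mman.h>
    #include <unistd.h>

    #ifndef MAP_SHARED_PT
    #define MAP_SHARED_PT 0x0	/* placeholder only, not the real value */
    #endif

    int main(void)
    {
    	size_t len = 512UL * 1024 * 1024;	/* large enough to span pmd entries */
    	int fd = open("/dev/shm/ptshare-test", O_RDWR | O_CREAT, 0600);
    	char *p;

    	if (fd < 0 || ftruncate(fd, len) < 0) {
    		perror("setup");
    		return 1;
    	}

    	p = mmap(NULL, len, PROT_READ | PROT_WRITE,
    		 MAP_SHARED | MAP_SHARED_PT, fd, 0);
    	if (p == MAP_FAILED) {
    		perror("mmap");
    		return 1;
    	}

    	/*
    	 * Touch a page. When a second process maps the same file the
    	 * same way, its faults are expected to be served through the
    	 * shared page tables set up by this patch.
    	 */
    	memset(p, 0xaa, 4096);
    	pause();
    	return 0;
    }
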
mm/internal.h | 2 +
mm/memory.c | 105 ++++++++++++++++++++++++++++++------
mm/ptshare.c | 143 ++++++++++++++++++++++++++++++++++++++++++++++++--
3 files changed, 232 insertions(+), 18 deletions(-)
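
Also for reviewers: find_shared_vma() translates the faulting address in
the guest process into the corresponding address in the mm that hosts
the shared page tables, using the offset from the start of the shared
VMA. A stand-alone sketch of that arithmetic (example addresses are made
up):

    #include <stdio.h>

    /* Mirrors: host_addr = *addrp - guest_vma->vm_start + host_mm->mmap_base */
    static unsigned long guest_to_host(unsigned long fault_addr,
    				   unsigned long guest_vm_start,
    				   unsigned long host_mmap_base)
    {
    	return fault_addr - guest_vm_start + host_mmap_base;
    }

    int main(void)
    {
    	unsigned long guest_vm_start = 0x7f1000000000UL;	/* made up */
    	unsigned long host_mmap_base = 0x7e0000000000UL;	/* made up */
    	unsigned long fault_addr     = guest_vm_start + 0x3004;

    	printf("host addr: 0x%lx\n",
    	       guest_to_host(fault_addr, guest_vm_start, host_mmap_base));
    	return 0;
    }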

diff --git a/mm/internal.h b/mm/internal.h
index 3efb8738e26f..924065f721fe 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -1061,4 +1061,6 @@ struct ptshare_data {
int ptshare_new_mm(struct file *file, struct vm_area_struct *vma);
void ptshare_del_mm(struct vm_area_struct *vm);
int ptshare_insert_vma(struct mm_struct *mm, struct vm_area_struct *vma);
+extern vm_fault_t find_shared_vma(struct vm_area_struct **vmap,
+ unsigned long *addrp, unsigned int flags);
#endif /* __MM_INTERNAL_H */
diff --git a/mm/memory.c b/mm/memory.c
index 01a23ad48a04..c67318ffd001 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -172,17 +172,28 @@ void mm_trace_rss_stat(struct mm_struct *mm, int member)
* has been handled earlier when unmapping all the memory regions.
*/
static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
- unsigned long addr)
+ unsigned long addr, bool shared_pte)
{
pgtable_t token = pmd_pgtable(*pmd);
pmd_clear(pmd);
+ /*
+ * If this address range shares page tables with other processes,
+ * do not release the pte pages. Those pages will be released when
+ * the host mm that holds them is released.
+ */
+ if (shared_pte) {
+ tlb_flush_pmd_range(tlb, addr, PAGE_SIZE);
+ tlb->freed_tables = 1;
+ return;
+ }
pte_free_tlb(tlb, token, addr);
mm_dec_nr_ptes(tlb->mm);
}

static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
unsigned long addr, unsigned long end,
- unsigned long floor, unsigned long ceiling)
+ unsigned long floor, unsigned long ceiling,
+ bool shared_pte)
{
pmd_t *pmd;
unsigned long next;
@@ -194,7 +205,7 @@ static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
next = pmd_addr_end(addr, end);
if (pmd_none_or_clear_bad(pmd))
continue;
- free_pte_range(tlb, pmd, addr);
+ free_pte_range(tlb, pmd, addr, shared_pte);
} while (pmd++, addr = next, addr != end);

start &= PUD_MASK;
@@ -210,13 +221,19 @@ static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,

pmd = pmd_offset(pud, start);
pud_clear(pud);
- pmd_free_tlb(tlb, pmd, start);
- mm_dec_nr_pmds(tlb->mm);
+ if (shared_pte) {
+ tlb_flush_pud_range(tlb, start, PAGE_SIZE);
+ tlb->freed_tables = 1;
+ } else {
+ pmd_free_tlb(tlb, pmd, start);
+ mm_dec_nr_pmds(tlb->mm);
+ }
}

static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,
unsigned long addr, unsigned long end,
- unsigned long floor, unsigned long ceiling)
+ unsigned long floor, unsigned long ceiling,
+ bool shared_pte)
{
pud_t *pud;
unsigned long next;
@@ -228,7 +245,8 @@ static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,
next = pud_addr_end(addr, end);
if (pud_none_or_clear_bad(pud))
continue;
- free_pmd_range(tlb, pud, addr, next, floor, ceiling);
+ free_pmd_range(tlb, pud, addr, next, floor, ceiling,
+ shared_pte);
} while (pud++, addr = next, addr != end);

start &= P4D_MASK;
@@ -250,7 +268,8 @@ static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,

static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
unsigned long addr, unsigned long end,
- unsigned long floor, unsigned long ceiling)
+ unsigned long floor, unsigned long ceiling,
+ bool shared_pte)
{
p4d_t *p4d;
unsigned long next;
@@ -262,7 +281,8 @@ static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
next = p4d_addr_end(addr, end);
if (p4d_none_or_clear_bad(p4d))
continue;
- free_pud_range(tlb, p4d, addr, next, floor, ceiling);
+ free_pud_range(tlb, p4d, addr, next, floor, ceiling,
+ shared_pte);
} while (p4d++, addr = next, addr != end);

start &= PGDIR_MASK;
@@ -284,9 +304,10 @@ static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
/*
* This function frees user-level page tables of a process.
*/
-void free_pgd_range(struct mmu_gather *tlb,
+static void _free_pgd_range(struct mmu_gather *tlb,
unsigned long addr, unsigned long end,
- unsigned long floor, unsigned long ceiling)
+ unsigned long floor, unsigned long ceiling,
+ bool shared_pte)
{
pgd_t *pgd;
unsigned long next;
@@ -342,10 +363,18 @@ void free_pgd_range(struct mmu_gather *tlb,
next = pgd_addr_end(addr, end);
if (pgd_none_or_clear_bad(pgd))
continue;
- free_p4d_range(tlb, pgd, addr, next, floor, ceiling);
+ free_p4d_range(tlb, pgd, addr, next, floor, ceiling,
+ shared_pte);
} while (pgd++, addr = next, addr != end);
}

+void free_pgd_range(struct mmu_gather *tlb,
+ unsigned long addr, unsigned long end,
+ unsigned long floor, unsigned long ceiling)
+{
+ _free_pgd_range(tlb, addr, end, floor, ceiling, false);
+}
+
void free_pgtables(struct mmu_gather *tlb, struct maple_tree *mt,
struct vm_area_struct *vma, unsigned long floor,
unsigned long ceiling)
@@ -375,16 +404,20 @@ void free_pgtables(struct mmu_gather *tlb, struct maple_tree *mt,
} else {
/*
* Optimization: gather nearby vmas into one call down
+ * but make sure vmas that do not share page tables are
+ * not combined with vmas that do share them
*/
while (next && next->vm_start <= vma->vm_end + PMD_SIZE
- && !is_vm_hugetlb_page(next)) {
+ && !is_vm_hugetlb_page(next)
+ && (vma_is_shared(next) == vma_is_shared(vma))) {
vma = next;
next = mas_find(&mas, ceiling - 1);
unlink_anon_vmas(vma);
unlink_file_vma(vma);
}
- free_pgd_range(tlb, addr, vma->vm_end,
- floor, next ? next->vm_start : ceiling);
+ _free_pgd_range(tlb, addr, vma->vm_end,
+ floor, next ? next->vm_start : ceiling,
+ vma_is_shared(vma));
}
vma = next;
} while (vma);
@@ -5181,6 +5214,8 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
unsigned int flags, struct pt_regs *regs)
{
vm_fault_t ret;
+ bool shared = false;
+ struct mm_struct *orig_mm;

__set_current_state(TASK_RUNNING);

@@ -5191,6 +5226,16 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
if (ret)
return ret;

+ orig_mm = vma->vm_mm;
+ if (unlikely(vma_is_shared(vma))) {
+ ret = find_shared_vma(&vma, &address, flags);
+ if (ret)
+ return ret;
+ if (!vma)
+ return VM_FAULT_SIGSEGV;
+ shared = true;
+ }
+
if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
flags & FAULT_FLAG_INSTRUCTION,
flags & FAULT_FLAG_REMOTE))
@@ -5212,6 +5257,36 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,

lru_gen_exit_fault();

+ /*
+ * Release the read lock on the shared VMA's host mm unless
+ * __handle_mm_fault released the lock already.
+ * __handle_mm_fault sets VM_FAULT_RETRY in its return value if
+ * it released the mmap lock. If the lock was released, it would
+ * also have been released on the task's original mm if this were
+ * not a shared PTE vma. To keep the lock state consistent, make
+ * sure the lock on the task's original mm gets released too.
+ */
+ if (shared) {
+ int release_mmlock = 1;
+
+ if (!(ret & VM_FAULT_RETRY)) {
+ mmap_read_unlock(vma->vm_mm);
+ release_mmlock = 0;
+ } else if ((flags & FAULT_FLAG_ALLOW_RETRY) &&
+ (flags & FAULT_FLAG_RETRY_NOWAIT)) {
+ mmap_read_unlock(vma->vm_mm);
+ release_mmlock = 0;
+ }
+
+ /*
+ * Restore the guest vma's vm_mm pointer that was switched in
+ * find_shared_vma() to process this fault.
+ */
+ vma->vm_mm = orig_mm;
+ if (release_mmlock)
+ mmap_read_unlock(orig_mm);
+ }
+
if (flags & FAULT_FLAG_USER) {
mem_cgroup_exit_user_fault();
/*
diff --git a/mm/ptshare.c b/mm/ptshare.c
index f6784268958c..e0991a877355 100644
--- a/mm/ptshare.c
+++ b/mm/ptshare.c
@@ -13,6 +13,136 @@
#include <asm/pgalloc.h>
#include "internal.h"

+/*
+ * Walk the page tables down to the pmd entry for addr, allocating
+ * intermediate levels as needed. Returns NULL on failure.
+ */
+static pmd_t
+*get_pmd(struct mm_struct *mm, unsigned long addr)
+{
+ pgd_t *pgd;
+ p4d_t *p4d;
+ pud_t *pud;
+ pmd_t *pmd;
+
+ pgd = pgd_offset(mm, addr);
+ if (pgd_none(*pgd))
+ return NULL;
+
+ p4d = p4d_offset(pgd, addr);
+ if (p4d_none(*p4d)) {
+ p4d = p4d_alloc(mm, pgd, addr);
+ if (!p4d)
+ return NULL;
+ }
+
+ pud = pud_offset(p4d, addr);
+ if (pud_none(*pud)) {
+ pud = pud_alloc(mm, p4d, addr);
+ if (!pud)
+ return NULL;
+ }
+
+ pmd = pmd_offset(pud, addr);
+ if (pmd_none(*pmd)) {
+ pmd = pmd_alloc(mm, pud, addr);
+ if (!pmd)
+ return NULL;
+ }
+
+ return pmd;
+}
+
+/*
+ * Find the shared pmd entries in the host mm and install them into
+ * the guest page tables.
+ */
+static int
+ptshare_copy_pmd(struct mm_struct *host_mm, struct mm_struct *guest_mm,
+ struct vm_area_struct *vma, unsigned long addr)
+{
+ pgd_t *guest_pgd;
+ p4d_t *guest_p4d;
+ pud_t *guest_pud;
+ pmd_t *host_pmd;
+ spinlock_t *host_ptl, *guest_ptl;
+
+ guest_pgd = pgd_offset(guest_mm, addr);
+ guest_p4d = p4d_offset(guest_pgd, addr);
+ if (p4d_none(*guest_p4d)) {
+ guest_p4d = p4d_alloc(guest_mm, guest_pgd, addr);
+ if (!guest_p4d)
+ return 1;
+ }
+
+ guest_pud = pud_offset(guest_p4d, addr);
+ if (pud_none(*guest_pud)) {
+ host_pmd = get_pmd(host_mm, addr);
+ if (!host_pmd)
+ return 1;
+
+ get_page(virt_to_page(host_pmd));
+ host_ptl = pmd_lockptr(host_mm, host_pmd);
+ guest_ptl = pud_lockptr(guest_mm, guest_pud);
+ spin_lock(host_ptl);
+ spin_lock(guest_ptl);
+ pud_populate(guest_mm, guest_pud,
+ (pmd_t *)((unsigned long)host_pmd & PAGE_MASK));
+ put_page(virt_to_page(host_pmd));
+ spin_unlock(guest_ptl);
+ spin_unlock(host_ptl);
+ }
+
+ return 0;
+}
+
+/*
+ * Find the shared page tables in the host mm and install them in
+ * the guest mm.
+ */
+vm_fault_t
+find_shared_vma(struct vm_area_struct **vmap, unsigned long *addrp,
+ unsigned int flags)
+{
+ struct ptshare_data *info;
+ struct mm_struct *host_mm;
+ struct vm_area_struct *host_vma, *guest_vma = *vmap;
+ unsigned long host_addr;
+ pmd_t *guest_pmd, *host_pmd;
+
+ if ((!guest_vma->vm_file) || (!guest_vma->vm_file->f_mapping))
+ return 0;
+ info = guest_vma->vm_file->f_mapping->ptshare_data;
+ if (!info) {
+ pr_warn("VM_SHARED_PT vma with NULL ptshare_data\n");
+ dump_stack_print_info(KERN_WARNING);
+ return 0;
+ }
+ host_mm = info->mm;
+
+ mmap_read_lock(host_mm);
+ host_addr = *addrp - guest_vma->vm_start + host_mm->mmap_base;
+ host_pmd = get_pmd(host_mm, host_addr);
+ guest_pmd = get_pmd(guest_vma->vm_mm, *addrp);
+ if (!host_pmd || !guest_pmd) {
+ mmap_read_unlock(host_mm);
+ return VM_FAULT_SIGSEGV;
+ }
+ if (!pmd_same(*guest_pmd, *host_pmd)) {
+ set_pmd(guest_pmd, *host_pmd);
+ mmap_read_unlock(host_mm);
+ return VM_FAULT_NOPAGE;
+ }
+
+ *addrp = host_addr;
+ host_vma = find_vma(host_mm, host_addr);
+ if (!host_vma) {
+ mmap_read_unlock(host_mm);
+ return VM_FAULT_SIGSEGV;
+ }
+
+ /*
+ * Point the faulting vma's vm_mm at the mm struct holding the
+ * shared page tables so that fault handling happens in the
+ * shared context.
+ */
+ guest_vma->vm_mm = host_mm;
+
+ return 0;
+}
+
/*
* Create a new mm struct that will hold the shared PTEs. Pointer to
* this new mm is stored in the data structure ptshare_data which also
@@ -38,6 +168,7 @@ ptshare_new_mm(struct file *file, struct vm_area_struct *vma)
new_mm->task_size = len;
if (!new_mm->task_size)
new_mm->task_size--;
+ new_mm->owner = NULL;

info = kzalloc(sizeof(*info), GFP_KERNEL);
if (!info) {
@@ -63,7 +194,7 @@ ptshare_new_mm(struct file *file, struct vm_area_struct *vma)
* insert vma into mm holding shared page tables
*/
int
-ptshare_insert_vma(struct mm_struct *mm, struct vm_area_struct *vma)
+ptshare_insert_vma(struct mm_struct *host_mm, struct vm_area_struct *vma)
{
struct vm_area_struct *new_vma;
int err = 0;
@@ -80,12 +211,18 @@ ptshare_insert_vma(struct mm_struct *mm, struct vm_area_struct *vma)
*/
vm_flags_clear(new_vma, (VM_SHARED | VM_SHARED_PT));
vm_flags_set(new_vma, VM_NOHUGEPAGE);
- new_vma->vm_mm = mm;
+ new_vma->vm_mm = host_mm;

- err = insert_vm_struct(mm, new_vma);
+ err = insert_vm_struct(host_mm, new_vma);
if (err)
return -ENOMEM;

+ /*
+ * Copy the pmd entries from the host mm into the guest mm so
+ * that both use the same PTEs.
+ */
+ err = ptshare_copy_pmd(host_mm, vma->vm_mm, vma, vma->vm_start);
+
return err;
}

--
2.37.2