Re: [PATCH v3 9/9] fork: Use __mt_dup() to duplicate maple tree in dup_mmap()

From: Peng Zhang
Date: Sat Oct 07 2023 - 00:26:21 EST




On 2023/10/7 09:32, Liam R. Howlett wrote:
...


[1] https://github.com/kdlucas/byte-unixbench/tree/master

Signed-off-by: Peng Zhang <zhangpeng.00@xxxxxxxxxxxxx>
---
include/linux/mm.h | 1 +
kernel/fork.c | 34 ++++++++++++++++++++----------
mm/internal.h | 3 ++-
mm/memory.c | 7 ++++---
mm/mmap.c | 52 ++++++++++++++++++++++++++++++++++++++++++++--
5 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 1f1d0d6b8f20..10c59dc7ffaa 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -3242,6 +3242,7 @@ extern void unlink_file_vma(struct vm_area_struct *);
extern struct vm_area_struct *copy_vma(struct vm_area_struct **,
unsigned long addr, unsigned long len, pgoff_t pgoff,
bool *need_rmap_locks);
+extern void undo_dup_mmap(struct mm_struct *mm, struct vm_area_struct *vma_end);
extern void exit_mmap(struct mm_struct *);
static inline int check_data_rlimit(unsigned long rlim,
diff --git a/kernel/fork.c b/kernel/fork.c
index 7ae36c2e7290..2f3d83e89fe6 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -650,7 +650,6 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
int retval;
unsigned long charge = 0;
LIST_HEAD(uf);
- VMA_ITERATOR(old_vmi, oldmm, 0);
VMA_ITERATOR(vmi, mm, 0);
uprobe_start_dup_mmap();
@@ -678,16 +677,25 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
goto out;
khugepaged_fork(mm, oldmm);
- retval = vma_iter_bulk_alloc(&vmi, oldmm->map_count);
- if (retval)
+ /* Use __mt_dup() to efficiently build an identical maple tree. */
+ retval = __mt_dup(&oldmm->mm_mt, &mm->mm_mt, GFP_KERNEL);
+ if (unlikely(retval))
goto out;
mt_clear_in_rcu(vmi.mas.tree);
- for_each_vma(old_vmi, mpnt) {
+ for_each_vma(vmi, mpnt) {
struct file *file;
vma_start_write(mpnt);
if (mpnt->vm_flags & VM_DONTCOPY) {
+ mas_store_gfp(&vmi.mas, NULL, GFP_KERNEL);
+
+ /* If failed, undo all completed duplications. */
+ if (unlikely(mas_is_err(&vmi.mas))) {
+ retval = xa_err(vmi.mas.node);
+ goto loop_out;
+ }
+
vm_stat_account(mm, mpnt->vm_flags, -vma_pages(mpnt));
continue;
}
@@ -749,9 +757,11 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
if (is_vm_hugetlb_page(tmp))
hugetlb_dup_vma_private(tmp);
- /* Link the vma into the MT */
- if (vma_iter_bulk_store(&vmi, tmp))
- goto fail_nomem_vmi_store;
+ /*
+ * Link the vma into the MT. After using __mt_dup(), memory
+ * allocation is not necessary here, so it cannot fail.
+ */
+ mas_store(&vmi.mas, tmp);
mm->map_count++;
if (!(tmp->vm_flags & VM_WIPEONFORK))
@@ -760,15 +770,19 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
if (tmp->vm_ops && tmp->vm_ops->open)
tmp->vm_ops->open(tmp);
- if (retval)
+ if (retval) {
+ mpnt = vma_next(&vmi);
goto loop_out;
+ }
}
/* a new mm has just been created */
retval = arch_dup_mmap(oldmm, mm);
loop_out:
vma_iter_free(&vmi);
- if (!retval)
+ if (likely(!retval))
mt_set_in_rcu(vmi.mas.tree);
+ else
+ undo_dup_mmap(mm, mpnt);
out:
mmap_write_unlock(mm);
flush_tlb_mm(oldmm);
@@ -778,8 +792,6 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
uprobe_end_dup_mmap();
return retval;
-fail_nomem_vmi_store:
- unlink_anon_vmas(tmp);
fail_nomem_anon_vma_fork:
mpol_put(vma_policy(tmp));
fail_nomem_policy:
diff --git a/mm/internal.h b/mm/internal.h
index 7a961d12b088..288ec81770cb 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -111,7 +111,8 @@ void folio_activate(struct folio *folio);
void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
struct vm_area_struct *start_vma, unsigned long floor,
- unsigned long ceiling, bool mm_wr_locked);
+ unsigned long ceiling, unsigned long tree_end,
+ bool mm_wr_locked);
void pmd_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte);
struct zap_details;
diff --git a/mm/memory.c b/mm/memory.c
index 983a40f8ee62..1fd66a0d5838 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -362,7 +362,8 @@ void free_pgd_range(struct mmu_gather *tlb,
void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
struct vm_area_struct *vma, unsigned long floor,
- unsigned long ceiling, bool mm_wr_locked)
+ unsigned long ceiling, unsigned long tree_end,
+ bool mm_wr_locked)
{
do {
unsigned long addr = vma->vm_start;
@@ -372,7 +373,7 @@ void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
* Note: USER_PGTABLES_CEILING may be passed as ceiling and may
* be 0. This will underflow and is okay.
*/
- next = mas_find(mas, ceiling - 1);
+ next = mas_find(mas, tree_end - 1);
/*
* Hide vma from rmap and truncate_pagecache before freeing
@@ -393,7 +394,7 @@ void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
while (next && next->vm_start <= vma->vm_end + PMD_SIZE
&& !is_vm_hugetlb_page(next)) {
vma = next;
- next = mas_find(mas, ceiling - 1);
+ next = mas_find(mas, tree_end - 1);
if (mm_wr_locked)
vma_start_write(vma);
unlink_anon_vmas(vma);
diff --git a/mm/mmap.c b/mm/mmap.c
index 2ad950f773e4..daed3b423124 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2312,7 +2312,7 @@ static void unmap_region(struct mm_struct *mm, struct ma_state *mas,
mas_set(mas, mt_start);
free_pgtables(&tlb, mas, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
next ? next->vm_start : USER_PGTABLES_CEILING,
- mm_wr_locked);
+ tree_end, mm_wr_locked);
tlb_finish_mmu(&tlb);
}
@@ -3178,6 +3178,54 @@ int vm_brk(unsigned long addr, unsigned long len)
}
EXPORT_SYMBOL(vm_brk);
+void undo_dup_mmap(struct mm_struct *mm, struct vm_area_struct *vma_end)
+{
+ unsigned long tree_end;
+ VMA_ITERATOR(vmi, mm, 0);
+ struct vm_area_struct *vma;
+ unsigned long nr_accounted = 0;
+ int count = 0;
+
+ /*
+ * vma_end points to the first VMA that has not been duplicated. We need
+ * to unmap all VMAs before it.
+ * If vma_end is NULL, it means that all VMAs in the maple tree have
+ * been duplicated, so setting tree_end to 0 will overflow to ULONG_MAX
+ * when using it.
+ */
+ if (vma_end) {
+ tree_end = vma_end->vm_start;
+ if (tree_end == 0)
+ goto destroy;
+ } else
+ tree_end = 0;

You need to enclose this statement in braces to meet the coding style. You
could just set tree_end = 0 at the start of the function instead. Actually,
I think tree_end should be USER_PGTABLES_CEILING unless there is a vma_end.

+
+ vma = mas_find(&vmi.mas, tree_end - 1);

vma = vma_find(&vmi, tree_end);

+
+ if (vma) {

Probably would be cleaner to jump to destroy here too:
if (!vma)
goto destroy;

+ arch_unmap(mm, vma->vm_start, tree_end);

One more thing: it seems the maple state that is passed into
unmap_region() needs to point to the _next_ element, or the reset
between the unmap_vmas() and free_pgtables() calls doesn't work correctly:

vma_iter_set(&vmi, vma->vm_end);


+ unmap_region(mm, &vmi.mas, vma, NULL, NULL, 0, tree_end,
+ tree_end, true);

next is vma_end, as per your comment above. Using next = vma_end allows
you to avoid adding another argument to free_pgtables().
Unfortunately, it cannot be done this way. I fell into this trap before,
and it caused incomplete page table cleanup. The only solution I can
think of right now is to add an additional parameter.

unmap_region() calls free_pgtables() to free the page tables, like
this:

free_pgtables(&tlb, mas, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
next ? next->vm_start : USER_PGTABLES_CEILING,
mm_wr_locked);

The problem is with 'next'. Our 'vma_end' does not exist in the actual
mmap because it has not been duplicated and cannot be used as 'next'.
If there is a real 'next', we can use 'next->vm_start' as the ceiling,
which is not a problem. If there is no 'next' (next is 'vma_end'), we
can only use 'USER_PGTABLES_CEILING' as the ceiling. Using
'vma_end->vm_start' as the ceiling leaves the page tables only partially
freed, which may be related to alignment in 'free_pgd_range()'. To
solve this, we have to introduce 'tree_end', so that the iteration limit
('tree_end') and the page-table ceiling ('ceiling') can be set separately.

Can you just use ceiling? That is, just not pass in next and keep the
code as-is? This is how exit_mmap() does it and should avoid any
alignment issues. I assume you tried that and something went wrong as
well?
I tried that, but it didn't work either. In free_pgtables(), the
following line of code is used to iterate over VMAs:
mas_find(mas, ceiling - 1);
If next is passed as NULL, ceiling will be 0, resulting in iterating
over all the VMAs in the maple tree, including the last portion that was
not duplicated.

If vma_end is NULL, it means that all VMAs in the maple tree have been
duplicated, so shouldn't the correct action in this case be freeing up
to ceiling?
Yes, that's correct.

If it isn't null, then vma_end->vm_start should work as the end of the
area to free.
But there's an issue here. I initially thought the same way, but the
behavior of free_pgtables() is very strange. For the last VMA, it seems
that the ceiling passed to free_pgd_range() must be
USER_PGTABLES_CEILING.

vma_end->vm_start cannot be used as the ceiling, possibly due to the
peculiar alignment behavior in free_pgd_range().

The code is from free_pgd_range():
if (ceiling) {
ceiling &= PMD_MASK;
if (!ceiling)
return;
}
I suspect it is related to this part: the behavior differs depending on
whether the ceiling is zero or non-zero. However, I do not fully
understand all the details here.


With your mas_find(mas, tree_end - 1), vma_end will be skipped,
but free_pgd_range() will still use ceiling:

free_pgd_range(tlb, addr, vma->vm_end, floor, next ? next->vm_start : ceiling);

Passing vma_end as next to unmap_region() works in my testing
without adding arguments to free_pgtables().

How are you producing the accounting issue you mention above? Maybe I
missed something?
You can apply the patch below and then run the test program in
Attachment 1 to reproduce the issue.

In dmesg, the kernel will report the following error:
[ 14.829561] BUG: non-zero pgtables_bytes on freeing mm: 12288
[ 14.832445] BUG: non-zero pgtables_bytes on freeing mm: 12288

diff --git a/kernel/fork.c b/kernel/fork.c
index 5f24f6d68ea4..fcc66acac480 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -688,7 +688,11 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,

vma_start_write(mpnt);
if (mpnt->vm_flags & VM_DONTCOPY) {
- mas_store_gfp(&vmi.mas, NULL, GFP_KERNEL);
+ if (!strcmp(current->comm, "fork_test") && ktime_get_ns() % 2) {
+ vmi.mas.node = MA_ERROR(-ENOMEM);
+ } else {
+ mas_store_gfp(&vmi.mas, NULL, GFP_KERNEL);
+ }

/* If failed, undo all completed duplications. */
if (unlikely(mas_is_err(&vmi.mas))) {






+
+ mas_set(&vmi.mas, vma->vm_end);
vma_iter_set(&vmi, vma->vm_end);
+ do {
+ if (vma->vm_flags & VM_ACCOUNT)
+ nr_accounted += vma_pages(vma);
+ remove_vma(vma, true);
+ count++;
+ cond_resched();
+ vma = mas_find(&vmi.mas, tree_end - 1);
+ } while (vma != NULL);

You can write this as:
do { ... } for_each_vma_range(vmi, vma, tree_end);

+
+ BUG_ON(count != mm->map_count);
+
+ vm_unacct_memory(nr_accounted);
+ }
+
+destroy:
+ __mt_destroy(&mm->mm_mt);
+}
+
/* Release all mmaps. */
void exit_mmap(struct mm_struct *mm)
{
@@ -3217,7 +3265,7 @@ void exit_mmap(struct mm_struct *mm)
mt_clear_in_rcu(&mm->mm_mt);
mas_set(&mas, vma->vm_end);
free_pgtables(&tlb, &mas, vma, FIRST_USER_ADDRESS,
- USER_PGTABLES_CEILING, true);
+ USER_PGTABLES_CEILING, USER_PGTABLES_CEILING, true);
tlb_finish_mmu(&tlb);
/*
--
2.20.1




#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/wait.h>

/*
 * Reproducer for the dup_mmap() error path.
 *
 * Creates a one-page anonymous mapping, marks it MADV_DONTFORK so that
 * dup_mmap() takes its VM_DONTCOPY branch for this VMA, then forks
 * NR_FORKS times. Each child exits immediately and the parent reaps it.
 * When the kernel carries the fault-injection diff from this thread
 * (which only fires for a task named "fork_test"), some forks fail. The
 * kernel may then log "BUG: non-zero pgtables_bytes on freeing mm" if the
 * undo path leaks page tables.
 *
 * Exit status: 0 after all iterations, 1 if setting up the mapping fails,
 * and 2 if a child cannot be reaped or exits with a non-zero wait status.
 */
int main(void)
{
	enum { MAP_LEN = 4096, NR_FORKS = 10000 };
	int cnt_success = 0, cnt_failure = 0;

	void *addr = mmap(NULL, MAP_LEN, PROT_READ | PROT_WRITE,
			  MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
	if (addr == MAP_FAILED) {
		perror("mmap failed");
		exit(1);
	}
	/*
	 * NOTE(review): presumably this gives the VMA flags that differ from
	 * its neighbours so it stays a separate VMA - confirm.
	 */
	if (mprotect(addr, MAP_LEN, PROT_READ | PROT_WRITE | PROT_EXEC) == -1) {
		perror("mprotect failed");
		exit(1);
	}
	/* Exclude this VMA from fork so dup_mmap() hits the DONTCOPY path. */
	if (madvise(addr, MAP_LEN, MADV_DONTFORK) == -1) {
		perror("madvise failed");
		exit(1);
	}
	printf("VMA created at address %p\n", addr);
	/*
	 * Flush before forking. Otherwise, when stdout is a pipe or a file,
	 * every child inherits the pending buffer and could emit it again.
	 */
	fflush(stdout);

	for (int i = 0; i < NR_FORKS; i++) {
		pid_t pid = fork();

		if (pid == -1) {
			/* E.g. the injected dup_mmap() failure fired. */
			cnt_failure++;
		} else if (pid == 0) {
			/* _exit(): do not flush stdio state copied from the parent. */
			_exit(EXIT_SUCCESS);
		} else {
			int status;

			cnt_success++;
			/* Reap this specific child; status is only valid on success. */
			if (waitpid(pid, &status, 0) == -1) {
				perror("waitpid failed");
				exit(2);
			}
			if (status != 0) {
				fprintf(stderr, "Bad wait status: 0x%x\n",
					status);
				exit(2);
			}
		}
	}

	printf("success:%d failure:%d\n", cnt_success, cnt_failure);
	return 0;
}