Re: [PATCH] mm/userfaultfd: fix memory corruption due to writeprotect

From: Nadav Amit
Date: Tue Dec 22 2020 - 07:42:55 EST


> On Dec 21, 2020, at 1:24 PM, Yu Zhao <yuzhao@xxxxxxxxxx> wrote:
>
> On Mon, Dec 21, 2020 at 12:26:22PM -0800, Linus Torvalds wrote:
>> On Mon, Dec 21, 2020 at 12:23 PM Nadav Amit <nadav.amit@xxxxxxxxx> wrote:
>>> Using mmap_write_lock() was my initial fix and there was a strong pushback
>>> on this approach due to its potential impact on performance.
>>
>> From whom?
>>
>> Somebody who doesn't understand that correctness is more important
>> than performance? And that userfaultfd is not the most important part
>> of the system?
>>
>> The fact is, userfaultfd is CLEARLY BUGGY.
>>
>> Linus
>
> Fair enough.
>
> Nadav, for your patch (you might want to update the commit message).
>
> Reviewed-by: Yu Zhao <yuzhao@xxxxxxxxxx>
>
> While we are all here, there is also clear_soft_dirty() that could
> use a similar fix…

Just an update as to why I have still not sent v2: I fixed
clear_soft_dirty(), created a reproducer, and the reproducer kept failing.

So after some debugging, it appears that clear_refs_write() does not flush
the TLB. It indeed calls tlb_finish_mmu(), but since 0758cd830494
("asm-generic/tlb: avoid potential double flush"), tlb_finish_mmu() does not
flush the TLB, because clear_refs_write() never calls
__tlb_adjust_range() (unless nested TLB flushes are pending).

So I have a patch for this issue too: arguably the tlb_gather interface is
not the right one for clear_refs_write() that does not clear PTEs but
changes them.

Yet, sadly, my reproducer keeps failing (less frequently, but still). So I
will keep debugging to see what goes wrong. I will send v2 once I figure out
what the heck is wrong in the code or my reproducer.

For reference, here is my reproducer:

-- >8 --

#define _GNU_SOURCE
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <fcntl.h>
#include <string.h>
#include <threads.h>
#include <stdatomic.h>

/* Page size assumed by the test; hard-coded rather than queried at runtime. */
#define PAGE_SIZE (4096)
/*
 * Upper bound (exclusive) on the page index that the non-first writer threads
 * touch per round, in order to push the shared page's entry out of the TLB.
 */
#define TLB_SIZE (2000)
/* Size of the anonymous mapping, in pages. */
#define N_PAGES (300000)
/* Number of writes issued to /proc/self/clear_refs. */
#define ITERATIONS (100)
/* Number of writer threads, and also the number of mlock/munlock threads. */
#define N_THREADS (2)

/*
 * Set to non-zero by main() to ask all worker threads to exit. Declared
 * atomic_int because it is only accessed through atomic_load()/atomic_store(),
 * which C11 defines for atomic types only (a plain int is rejected by some
 * compilers).
 */
static atomic_int stop;
/* Base of the shared anonymous mapping; page 0 holds the shared counter. */
static char *m;

/*
 * Worker thread body: atomically increments the counter stored at the start
 * of the mapping `m` until `stop` becomes non-zero.
 *
 * argp carries the thread index as an integer cast to a pointer.
 * Returns the number of increments this thread performed.
 */
static int writer(void *argp)
{
	const unsigned long thread_id = (unsigned long)argp;
	int rounds = 0;

	for (;;) {
		if (atomic_load(&stop))
			break;

		rounds++;
		atomic_fetch_add((atomic_int *)m, 1);

		/*
		 * Thread 0 does nothing but hit the counter page, which keeps
		 * that page's translation cached in the TLB.
		 */
		if (thread_id != 0) {
			/*
			 * Every other thread walks over enough additional
			 * pages to evict the counter page from the TLB, so
			 * that the next access takes a #PF (before the TLB
			 * flush of clear_ref actually happens). These pages
			 * are never written, so they must all read as zero.
			 */
			for (int pg = 1; pg < TLB_SIZE; pg++) {
				atomic_int *probe = (atomic_int *)(m + PAGE_SIZE * pg);

				if (atomic_load(probe) != 0) {
					fprintf(stderr, "unexpected error\n");
					exit(1);
				}
			}
		}
	}

	return rounds;
}

/*
 * Background thread body: repeatedly mlock()s and munlock()s the first page
 * of `m` until `stop` becomes non-zero. The intent is to raise the page-count
 * of that page so that a write fault copies it instead of reusing it.
 *
 * argp is unused. Returns 0; exits the process if either call fails.
 */
static int do_mlock(void *argp)
{
	for (;;) {
		int rc;

		if (atomic_load(&stop))
			return 0;

		/* munlock() is attempted only when mlock() succeeded. */
		rc = mlock(m, PAGE_SIZE);
		if (rc == 0)
			rc = munlock(m, PAGE_SIZE);

		if (rc != 0) {
			perror("mlock/munlock");
			exit(1);
		}
	}
}

/*
 * Reproducer driver.
 *
 * Maps a large private anonymous region, starts N_THREADS writer threads that
 * atomically increment a counter in its first page and N_THREADS threads that
 * mlock/munlock that page, and meanwhile writes "4" to /proc/self/clear_refs
 * ITERATIONS times. After stopping and joining all threads, the counter must
 * equal the sum of the per-thread increment counts; a mismatch means an
 * increment was lost.
 *
 * Returns 0 and prints "ok" on success; otherwise prints a diagnostic and
 * exits with a non-zero status.
 */
int main(void)
{
	int r, cnt, fd, total = 0;
	long i;
	thrd_t thr[N_THREADS];
	thrd_t mlock_thr[N_THREADS];

	/* No mode argument: it is only meaningful together with O_CREAT. */
	fd = open("/proc/self/clear_refs", O_WRONLY);
	if (fd < 0) {
		perror("open");
		exit(1);
	}

	/*
	 * Have large memory for clear_ref, so there would be some time between
	 * the unmap and the actual deferred flush. The length is computed in
	 * size_t so it cannot overflow int if the constants are enlarged.
	 */
	m = mmap(NULL, (size_t)PAGE_SIZE * N_PAGES, PROT_READ|PROT_WRITE,
		 MAP_PRIVATE|MAP_ANONYMOUS|MAP_POPULATE, -1, 0);
	if (m == MAP_FAILED) {
		perror("mmap");
		exit(1);
	}

	/* The thread index is passed by value inside the pointer argument. */
	for (i = 0; i < N_THREADS; i++) {
		r = thrd_create(&thr[i], writer, (void *)i);
		assert(r == thrd_success);
	}

	for (i = 0; i < N_THREADS; i++) {
		r = thrd_create(&mlock_thr[i], do_mlock, (void *)i);
		assert(r == thrd_success);
	}

	/*
	 * Issue ITERATIONS clear_refs requests. The original had two nested
	 * loops sharing the counter `i`, so the outer one ran exactly once;
	 * a single loop performs the same number of writes.
	 */
	for (i = 0; i < ITERATIONS; i++) {
		if (pwrite(fd, "4", 1, 0) < 0) {
			perror("pwrite");
			exit(1);
		}
	}

	atomic_store(&stop, 1);

	for (i = 0; i < N_THREADS; i++) {
		r = thrd_join(mlock_thr[i], NULL);
		assert(r == thrd_success);
	}

	/* Each writer returns how many increments it performed. */
	for (i = 0; i < N_THREADS; i++) {
		r = thrd_join(thr[i], &cnt);
		assert(r == thrd_success);
		total += cnt;
	}

	r = atomic_load((atomic_int *)(m));
	if (r != total) {
		fprintf(stderr, "failed - expected=%d actual=%d\n", total, r);
		exit(-1);
	}

	fprintf(stderr, "ok\n");
	return 0;
}