Re: [RFC PATCH 4/4] Fix: sched/membarrier: p->mm->membarrier_state racy load

From: Peter Zijlstra
Date: Fri Sep 06 2019 - 04:24:11 EST


On Thu, Sep 05, 2019 at 11:13:00PM -0400, Mathieu Desnoyers wrote:

> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index 6a7a1083b6fb..7020572eb605 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -382,6 +382,9 @@ struct mm_struct {
>                  unsigned long task_size;        /* size of task vm space */
>                  unsigned long highest_vm_end;   /* highest vma end address */
>                  pgd_t * pgd;
> +#ifdef CONFIG_MEMBARRIER

Stick in a comment on why it lives here: to be close to data already
used by switch_mm().

> +                atomic_t membarrier_state;
> +#endif
>
>                  /**
>                   * @mm_users: The number of users including userspace.
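
Something like this, say (the comment wording here is just my
suggestion, not text from the patch):

#ifdef CONFIG_MEMBARRIER
        /*
         * @membarrier_state: flags controlling membarrier behaviour.
         *
         * Placed next to @pgd so it sits in a cache line that
         * switch_mm() already touches.
         */
        atomic_t membarrier_state;
#endif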

> diff --git a/kernel/sched/core.c b/kernel/sched/core.c
> index 010d578118d6..1cffc1aa403c 100644
> --- a/kernel/sched/core.c
> +++ b/kernel/sched/core.c
> @@ -3038,6 +3038,7 @@ prepare_task_switch(struct rq *rq, struct task_struct *prev,
>          perf_event_task_sched_out(prev, next);
>          rseq_preempt(prev);
>          fire_sched_out_preempt_notifiers(prev, next);
> +        membarrier_prepare_task_switch(rq, prev, next);

This had me confused for a while, because I initially thought we'd only
do this for switch_mm(), but you've made it aggressive and track kernel
threads too.

I think we can do that slightly different. See below...

>          prepare_task(next);
>          prepare_arch_switch(next);
>  }

> diff --git a/kernel/sched/membarrier.c b/kernel/sched/membarrier.c
> index 7e0a0d6535f3..5744c300d29e 100644
> --- a/kernel/sched/membarrier.c
> +++ b/kernel/sched/membarrier.c
> @@ -30,6 +30,28 @@ static void ipi_mb(void *info)

> +void membarrier_execve(struct task_struct *t)
> +{
> +        atomic_set(&t->mm->membarrier_state, 0);
> +        WRITE_ONCE(this_rq()->membarrier_state, 0);

It is the callsite of this one that had me puzzled and confused. I
think it works by accident more than anything else.

You see, I thought the rule was that we'd change it near/before
switch_mm(), and this is quite a way _after_.

I think it might be best to place the call in exec_mmap(), right before
activate_mm().
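
Something along these lines, keeping your membarrier_execve() helper
as-is (a from-memory sketch of fs/exec.c; the surrounding lines are
illustrative, not an exact quote):

static int exec_mmap(struct mm_struct *mm)
{
        struct task_struct *tsk = current;
        struct mm_struct *active_mm;

        /* ... mm synchronization and error handling elided ... */

        task_lock(tsk);
        active_mm = tsk->active_mm;
        tsk->mm = mm;
        tsk->active_mm = mm;
        /*
         * Reset the membarrier state for the new mm (and this rq)
         * right before the new mm goes live via activate_mm().
         */
        membarrier_execve(tsk);
        activate_mm(active_mm, mm);
        task_unlock(tsk);

        /* ... old mm teardown elided ... */
        return 0;
}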

But that then had me wonder about the membarrier_prepare_task_switch()
thing...

> +/*
> + * The scheduler provides memory barriers required by membarrier between:
> + * - prior user-space memory accesses and store to rq->membarrier_state,
> + * - store to rq->membarrier_state and following user-space memory accesses.
> + * In the same way it provides those guarantees around store to rq->curr.
> + */
> +static inline void membarrier_prepare_task_switch(struct rq *rq,
> +                                                  struct task_struct *prev,
> +                                                  struct task_struct *next)
> +{
> +        int membarrier_state = 0;
> +        struct mm_struct *next_mm = next->mm;
> +
> +        if (prev->mm == next_mm)
> +                return;
> +        if (next_mm)
> +                membarrier_state = atomic_read(&next_mm->membarrier_state);
> +        if (READ_ONCE(rq->membarrier_state) != membarrier_state)
> +                WRITE_ONCE(rq->membarrier_state, membarrier_state);
> +}

So if you make the above something like:

static inline void
membarrier_switch_mm(struct rq *rq, struct mm_struct *prev_mm, struct mm_struct *next_mm)
{
        int membarrier_state;

        if (prev_mm == next_mm)
                return;

        membarrier_state = atomic_read(&next_mm->membarrier_state);
        if (READ_ONCE(rq->membarrier_state) == membarrier_state)
                return;

        WRITE_ONCE(rq->membarrier_state, membarrier_state);
}
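
And put it right in front of switch_mm() in context_switch(). For
illustration, the placement would be about here (the surrounding
context_switch() lines are a from-memory sketch, not an exact quote):

        if (!next->mm) {                                /* to kernel */
                enter_lazy_tlb(prev->active_mm, next);

                next->active_mm = prev->active_mm;
                if (prev->mm)                           /* from user */
                        mmgrab(prev->active_mm);
                else
                        prev->active_mm = NULL;
        } else {                                        /* to user */
                /*
                 * Update rq->membarrier_state before next->mm becomes
                 * visible; note prev->active_mm, so that a lazy kthread
                 * already running on next->mm doesn't force a rewrite.
                 */
                membarrier_switch_mm(rq, prev->active_mm, next->mm);
                switch_mm_irqs_off(prev->active_mm, next->mm, next);

                if (!prev->mm) {                        /* from kernel */
                        /* will mmdrop() in finish_task_switch() */
                        rq->prev_mm = prev->active_mm;
                        prev->active_mm = NULL;
                }
        }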

Then we'll deal with kernel threads on the other side, like so:

> @@ -70,16 +90,13 @@ static int membarrier_global_expedited(void)
>                  if (cpu == raw_smp_processor_id())
>                          continue;
>
> -                rcu_read_lock();
> -                p = task_rcu_dereference(&cpu_rq(cpu)->curr);
> -                if (p && p->mm && (atomic_read(&p->mm->membarrier_state) &
> -                    MEMBARRIER_STATE_GLOBAL_EXPEDITED)) {
> +                if (READ_ONCE(cpu_rq(cpu)->membarrier_state) &
> +                    MEMBARRIER_STATE_GLOBAL_EXPEDITED) {

        p = rcu_dereference(rq->curr);
        if ((READ_ONCE(cpu_rq(cpu)->membarrier_state) & MEMBARRIER_STATE_GLOBAL_EXPEDITED) &&
            !(p->flags & PF_KTHREAD))

>                          if (!fallback)
>                                  __cpumask_set_cpu(cpu, tmpmask);
>                          else
>                                  smp_call_function_single(cpu, ipi_mb, NULL, 1);
>                  }
> -                rcu_read_unlock();
>          }
>          if (!fallback) {
>                  preempt_disable();
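
Putting those together, the loop would read roughly like this (a sketch
of the combined result, not a tested patch; cpu, fallback and tmpmask
are as declared in the existing function, and the rcu_dereference()
wants rcu_read_lock() held around the loop):

        rcu_read_lock();
        for_each_online_cpu(cpu) {
                struct task_struct *p;

                if (cpu == raw_smp_processor_id())
                        continue;

                if (!(READ_ONCE(cpu_rq(cpu)->membarrier_state) &
                      MEMBARRIER_STATE_GLOBAL_EXPEDITED))
                        continue;

                /*
                 * Skip kernel threads: they may run on the lazy
                 * ->active_mm of a previous user task, in which case
                 * rq->membarrier_state is still that task's state.
                 */
                p = rcu_dereference(cpu_rq(cpu)->curr);
                if (p->flags & PF_KTHREAD)
                        continue;

                if (!fallback)
                        __cpumask_set_cpu(cpu, tmpmask);
                else
                        smp_call_function_single(cpu, ipi_mb, NULL, 1);
        }
        rcu_read_unlock();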

does that make sense?

(also, I hate how long all these membarrier names are)