[RFC PATCH 04/11] mm/mempolicy: modify get_mempolicy call stack to take a task argument

From: Gregory Price
Date: Wed Nov 22 2023 - 16:12:34 EST


To make a task's mempolicy fetchable by other tasks, we must first
change the call stack to take a task as an argument.

Modify the following functions to require a task argument:
do_get_mempolicy
kernel_get_mempolicy

The way task->mm is acquired must change slightly to enable this.
Originally, do_get_mempolicy acquired the mm directly via current->mm,
which is unsafe for a task other than current. However, switching
unconditionally to get_task_mm would break the original behavior of
do_get_mempolicy: get_task_mm returns NULL for any PF_KTHREAD task,
even one that has an mm, because of the following check in
get_task_mm:

    if (mm) {
        if (task->flags & PF_KTHREAD)
            mm = NULL;
        else
            mmget(mm);
    }

To retain the original behavior, when task == current we read
task->mm directly; otherwise we use get_task_mm to access the mm
safely.

To simplify the get/put mechanics, we always take a reference to the
mm, even when task == current.

Additionally, since the mempolicy will become externally modifiable,
we need to take the task lock to acquire task->mempolicy safely,
regardless of whether we are operating on current or not.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
mm/mempolicy.c | 43 +++++++++++++++++++++++++++++--------------
1 file changed, 29 insertions(+), 14 deletions(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 9ea3e1bfc002..4519f39b1a07 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -899,8 +899,9 @@ static int lookup_node(struct mm_struct *mm, unsigned long addr)
}

/* Retrieve NUMA policy */
-static long do_get_mempolicy(int *policy, nodemask_t *nmask,
- unsigned long addr, unsigned long flags)
+static long do_get_mempolicy(struct task_struct *task, int *policy,
+ nodemask_t *nmask, unsigned long addr,
+ unsigned long flags)
{
int err;
struct mm_struct *mm;
@@ -915,9 +916,17 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
return -EINVAL;
*policy = 0; /* just so it's initialized */
- task_lock(current);
- *nmask = cpuset_current_mems_allowed;
- task_unlock(current);
+ /*
+ * task_struct::mems_allowed only exists with CONFIG_CPUSETS;
+ * otherwise mirror the cpuset_current_mems_allowed fallback.
+ */
+ task_lock(task);
+#ifdef CONFIG_CPUSETS
+ *nmask = task->mems_allowed;
+#else
+ *nmask = node_states[N_MEMORY];
+#endif
+ task_unlock(task);
return 0;
}

@@ -928,7 +929,28 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
* vma/shared policy at addr is NULL. We
* want to return MPOL_DEFAULT in this case.
*/
- mm = current->mm;
+ if (task == current) {
+ /*
+ * get_task_mm() returns NULL for PF_KTHREAD tasks,
+ * which would stop a kernel thread that has an mm
+ * from querying its own policy; read ->mm directly,
+ * as the original code did.
+ */
+ mm = task->mm;
+ if (mm)
+ mmget(mm);
+ } else {
+ mm = get_task_mm(task);
+ }
+ /* no usable mm: exited task or kernel thread */
+ if (!mm)
+ return -EINVAL;
+ /*
+ * TODO(review): every exit path must drop this
+ * reference with mmput(). Only the MPOL_F_NODE lookup
+ * path below does; the !vma error return and the out:
+ * label appear to leak it - confirm and add mmput().
+ */
mmap_read_lock(mm);
vma = vma_lookup(mm, addr);
if (!vma) {
@@ -947,8 +969,10 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
return -EINVAL;
else {
/* take a reference of the task policy now */
- pol = current->mempolicy;
+ task_lock(task);
+ pol = task->mempolicy;
mpol_get(pol);
+ task_unlock(task);
}

if (!pol) {
@@ -962,12 +986,18 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
vma = NULL;
mmap_read_unlock(mm);
+ /*
+ * NOTE(review): confirm lookup_node() resolves addr in @mm;
+ * if it relies on current->mm (e.g. get_user_pages_fast()),
+ * the result is wrong for task != current.
+ */
err = lookup_node(mm, addr);
+ mmput(mm);
if (err < 0)
goto out;
*policy = err;
- } else if (pol == current->mempolicy &&
+ } else if (pol == task->mempolicy &&
pol->mode == MPOL_INTERLEAVE) {
- *policy = next_node_in(current->il_prev, pol->nodes);
+ *policy = next_node_in(task->il_prev, pol->nodes);
} else {
err = -EINVAL;
goto out;
@@ -987,9 +1017,9 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
if (mpol_store_user_nodemask(pol)) {
*nmask = pol->w.user_nodemask;
} else {
- task_lock(current);
+ task_lock(task);
get_policy_nodemask(pol, nmask);
- task_unlock(current);
+ task_unlock(task);
}
}

@@ -1704,7 +1717,7 @@ SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
}

/* Retrieve NUMA policy */
-static int kernel_get_mempolicy(int __user *policy,
+static int kernel_get_mempolicy(struct task_struct *task, int __user *policy,
unsigned long __user *nmask,
unsigned long maxnode,
unsigned long addr,
@@ -1719,7 +1732,7 @@ static int kernel_get_mempolicy(int __user *policy,

addr = untagged_addr(addr);

- err = do_get_mempolicy(&pval, &nodes, addr, flags);
+ err = do_get_mempolicy(task, &pval, &nodes, addr, flags);

if (err)
return err;
@@ -1737,7 +1750,8 @@ SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
unsigned long __user *, nmask, unsigned long, maxnode,
unsigned long, addr, unsigned long, flags)
{
- return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags);
+ return kernel_get_mempolicy(current, policy, nmask, maxnode,
+ addr, flags);
}

bool vma_migratable(struct vm_area_struct *vma)
--
2.39.1