[PATCH v2 05/11] mm/mempolicy: refactor kernel_get_mempolicy for code re-use

From: Gregory Price
Date: Sat Dec 09 2023 - 02:00:16 EST


Pull operation flag checking out of do_get_mempolicy and into
kernel_get_mempolicy. This lets us flatten the internal code and
break it into separate functions that future syscalls
(get_mempolicy2, process_get_mempolicy) can re-use, even after
additional extensions are made.

The primary change is that the flags argument is treated as the
multiplexer that it actually is. For get_mempolicy, flags selects
one of 3 primary operations:

    if (flags & MPOL_F_MEMS_ALLOWED)
        return task->mems_allowed
    else if (flags & MPOL_F_ADDR)
        return vma mempolicy information
    else
        return task mempolicy information

Plus one behavior-modifying flag:

    if (flags & MPOL_F_NODE)
        change the return value of (int __user *policy)
        based on whether MPOL_F_ADDR was set: with MPOL_F_ADDR it is
        the node backing addr, otherwise the next interleave node.

The original behavior of get_mempolicy is retained, but the new
mempolicy_args structure is used to pass the operation's inputs and
results down the stack. This will allow us to extend the internal
functions without affecting the legacy behavior of get_mempolicy.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
mm/mempolicy.c | 240 ++++++++++++++++++++++++++++++-------------------
1 file changed, 150 insertions(+), 90 deletions(-)

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 324dbf1782df..ce5b7963e9b5 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -895,106 +895,107 @@ static int lookup_node(struct mm_struct *mm, unsigned long addr)
return ret;
}

-/* Retrieve NUMA policy */
-static long do_get_mempolicy(int *policy, nodemask_t *nmask,
- unsigned long addr, unsigned long flags)
+/* Copy the current task's cpuset mems_allowed into @nmask; returns 0 */
+static inline long do_get_mems_allowed(nodemask_t *nmask)
{
- int err;
- struct mm_struct *mm = current->mm;
- struct vm_area_struct *vma = NULL;
- struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
+ task_lock(current);
+ *nmask = cpuset_current_mems_allowed;
+ task_unlock(current);
+ return 0;
+}

- if (flags &
- ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
- return -EINVAL;
+/* Return the next node hinted by @pol's interleave state, or -EINVAL */
+static long do_get_policy_node(struct mempolicy *pol)
+{
+ /*
+ * Only the current task's own policy has a hint: MPOL_INTERLEAVE gives
+ * the node after il_prev; MPOL_WEIGHTED_INTERLEAVE stays on il_prev
+ * while wil.cur_weight is nonzero, otherwise moves to the next node.
+ */
+ if (pol == current->mempolicy && pol->mode == MPOL_INTERLEAVE)
+ return next_node_in(current->il_prev, pol->nodes);

- if (flags & MPOL_F_MEMS_ALLOWED) {
- if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
- return -EINVAL;
- *policy = 0; /* just so it's initialized */
+ if (pol == current->mempolicy &&
+ pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
+ if (pol->wil.cur_weight)
+ return current->il_prev;
+ else
+ return next_node_in(current->il_prev, pol->nodes);
+ }
+ return -EINVAL;
+}
+
+/* Copy @pol's nodemask to @nmask, preferring the stored user nodemask */
+static void do_get_mempolicy_nodemask(struct mempolicy *pol, nodemask_t *nmask)
+{
+ if (mpol_store_user_nodemask(pol)) {
+ *nmask = pol->w.user_nodemask;
+ } else {
task_lock(current);
- *nmask = cpuset_current_mems_allowed;
+ get_policy_nodemask(pol, nmask);
task_unlock(current);
- return 0;
}
+}

- if (flags & MPOL_F_ADDR) {
- pgoff_t ilx; /* ignored here */
- /*
- * Do NOT fall back to task policy if the
- * vma/shared policy at addr is NULL. We
- * want to return MPOL_DEFAULT in this case.
- */
- mmap_read_lock(mm);
- vma = vma_lookup(mm, addr);
- if (!vma) {
- mmap_read_unlock(mm);
- return -EFAULT;
- }
- pol = __get_vma_policy(vma, addr, &ilx);
- } else if (addr)
- return -EINVAL;
+/* Retrieve NUMA policy for a VMA associated with a given address */
+static long do_get_vma_mempolicy(struct mempolicy_args *args)
+{
+ pgoff_t ilx;
+ struct mm_struct *mm = current->mm;
+ struct vm_area_struct *vma;
+ struct mempolicy *pol;

+ mmap_read_lock(mm);
+ vma = vma_lookup(mm, args->addr);
+ if (!vma) {
+ mmap_read_unlock(mm);
+ return -EFAULT;
+ }
+ pol = __get_vma_policy(vma, args->addr, &ilx);
if (!pol)
- pol = &default_policy; /* indicates default behavior */
+ pol = &default_policy;
+ else /* pin pol: vma goes stale once mmap_lock is dropped */
+ mpol_get(pol);
+ mmap_read_unlock(mm);

- if (flags & MPOL_F_NODE) {
- if (flags & MPOL_F_ADDR) {
- /*
- * Take a refcount on the mpol, because we are about to
- * drop the mmap_lock, after which only "pol" remains
- * valid, "vma" is stale.
- */
- pol_refcount = pol;
- vma = NULL;
- mpol_get(pol);
- mmap_read_unlock(mm);
- err = lookup_node(mm, addr);
- if (err < 0)
- goto out;
- *policy = err;
- } else if (pol == current->mempolicy &&
- pol->mode == MPOL_INTERLEAVE) {
- *policy = next_node_in(current->il_prev, pol->nodes);
- } else if (pol == current->mempolicy &&
- (pol->mode == MPOL_WEIGHTED_INTERLEAVE)) {
- if (pol->wil.cur_weight)
- *policy = current->il_prev;
- else
- *policy = next_node_in(current->il_prev,
- pol->nodes);
- } else {
- err = -EINVAL;
- goto out;
- }
- } else {
- *policy = pol == &default_policy ? MPOL_DEFAULT :
- pol->mode;
- /*
- * Internal mempolicy flags must be masked off before exposing
- * the policy to userspace.
- */
- *policy |= (pol->flags & MPOL_MODE_FLAGS);
- }
+ /* Fetch the node for the given address */
+ args->addr_node = lookup_node(mm, args->addr);

- err = 0;
- if (nmask) {
- if (mpol_store_user_nodemask(pol)) {
- *nmask = pol->w.user_nodemask;
- } else {
- task_lock(current);
- get_policy_nodemask(pol, nmask);
- task_unlock(current);
- }
+ args->mode = pol == &default_policy ? MPOL_DEFAULT : pol->mode;
+ args->mode_flags = (pol->flags & MPOL_MODE_FLAGS);
+
+ /* If this policy has extra node info, fetch that */
+ args->policy_node = do_get_policy_node(pol);
+ if (args->policy_nodes)
+ do_get_mempolicy_nodemask(pol, args->policy_nodes);
+
+ /* Drop our pin, then any extra ref held on a shared policy */
+ if (pol != &default_policy) {
+ mpol_put(pol);
+ mpol_cond_put(pol);
}

- out:
- mpol_cond_put(pol);
- if (vma)
- mmap_read_unlock(mm);
- if (pol_refcount)
- mpol_put(pol_refcount);
- return err;
+ return 0;
+}
+
+/* Fill @args from the current task's policy (default_policy if unset) */
+static long do_get_task_mempolicy(struct mempolicy_args *args)
+{
+ struct mempolicy *pol = current->mempolicy;
+
+ if (!pol)
+ pol = &default_policy; /* indicates default behavior */
+
+ args->mode = pol == &default_policy ? MPOL_DEFAULT : pol->mode;
+ /* Internal flags must be masked off before exposing to userspace */
+ args->mode_flags = (pol->flags & MPOL_MODE_FLAGS);
+
+ args->policy_node = do_get_policy_node(pol); /* -EINVAL if no hint */
+
+ if (args->policy_nodes)
+ do_get_mempolicy_nodemask(pol, args->policy_nodes);
+
+ return 0;
}

#ifdef CONFIG_MIGRATION
@@ -1731,16 +1732,75 @@ static int kernel_get_mempolicy(int __user *policy,
unsigned long addr,
unsigned long flags)
{
+ struct mempolicy_args args;
int err;
- int pval;
+ int pval = 0;
nodemask_t nodes;

if (nmask != NULL && maxnode < nr_node_ids)
return -EINVAL;

- addr = untagged_addr(addr);
+ if (flags &
+ ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
+ return -EINVAL;

- err = do_get_mempolicy(&pval, &nodes, addr, flags);
+ /* Ensure any data that may be copied to userland is initialized */
+ memset(&args, 0, sizeof(args));
+ args.policy_nodes = &nodes;
+ args.addr = untagged_addr(addr);
+
+ /*
+ * set_mempolicy was originally multiplexed based on 3 flags:
+ * MPOL_F_MEMS_ALLOWED: fetch task->mems_allowed
+ * MPOL_F_ADDR : operate on vma->mempolicy
+ * MPOL_F_NODE : change return value of *policy
+ *
+ * Split this behavior out here, rather than internal functions,
+ * so that the internal functions can be re-used by future
+ * get_mempolicy2 interfaces and the arg structure made extensible
+ */
+ if (flags & MPOL_F_MEMS_ALLOWED) {
+ if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
+ return -EINVAL;
+ pval = 0; /* just so it's initialized */
+ err = do_get_mems_allowed(&nodes);
+ } else if (flags & MPOL_F_ADDR) {
+ /* If F_ADDR, we operation on a vma policy (or default) */
+ err = do_get_vma_mempolicy(&args);
+ if (err)
+ return err;
+ /* if (F_ADDR | F_NODE), *pval is the address' node */
+ if (flags & MPOL_F_NODE) {
+ /* if we failed to fetch, that's likely an EFAULT */
+ if (args.addr_node < 0)
+ return args.addr_node;
+ pval = args.addr_node;
+ } else
+ pval = args.mode | args.mode_flags;
+ } else {
+ /* if not F_ADDR and addr != null, EINVAL */
+ if (addr)
+ return -EINVAL;
+
+ err = do_get_task_mempolicy(&args);
+ if (err)
+ return err;
+ /*
+ * if F_NODE was set and mode was MPOL_INTERLEAVE
+ * *pval is equal to next interleave node.
+ *
+ * if args.policy_node < 0, this means the mode did
+ * not have a policy. This presently emulates the
+ * original behavior of (F_NODE) & (!MPOL_INTERLEAVE)
+ * producing -EINVAL
+ */
+ if (flags & MPOL_F_NODE) {
+ if (args.policy_node < 0)
+ return args.policy_node;
+ pval = args.policy_node;
+ } else
+ pval = args.mode | args.mode_flags;
+ }

if (err)
return err;
--
2.39.1