[RFC PATCH v2 2/4] mm/mempolicy: Implement set_mempolicy2 and get_mempolicy2

From: Gregory Price
Date: Mon Oct 02 2023 - 20:22:21 EST


sys_set_mempolicy is limited by its current argument structure
(mode, nodes, flags) to implementing policies that can be described
in that manner.

Implement set/get_mempolicy2 with a new mempolicy_args structure
which encapsulates the old behavior, and allows for new mempolicies
which may require additional information.

Signed-off-by: Gregory Price <gregory.price@xxxxxxxxxxxx>
---
arch/x86/entry/syscalls/syscall_32.tbl | 2 +
arch/x86/entry/syscalls/syscall_64.tbl | 2 +
include/linux/syscalls.h | 4 +
include/uapi/asm-generic/unistd.h | 10 +-
include/uapi/linux/mempolicy.h | 29 ++++
mm/mempolicy.c | 196 ++++++++++++++++++++++++-
6 files changed, 241 insertions(+), 2 deletions(-)

diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl
index 2d0b1bd866ea..a72ef588a704 100644
--- a/arch/x86/entry/syscalls/syscall_32.tbl
+++ b/arch/x86/entry/syscalls/syscall_32.tbl
@@ -457,3 +457,5 @@
450 i386 set_mempolicy_home_node sys_set_mempolicy_home_node
451 i386 cachestat sys_cachestat
452 i386 fchmodat2 sys_fchmodat2
+454 i386 set_mempolicy2 sys_set_mempolicy2
+455 i386 get_mempolicy2 sys_get_mempolicy2
diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl
index 1d6eee30eceb..ec54064de8b3 100644
--- a/arch/x86/entry/syscalls/syscall_64.tbl
+++ b/arch/x86/entry/syscalls/syscall_64.tbl
@@ -375,6 +375,8 @@
451 common cachestat sys_cachestat
452 common fchmodat2 sys_fchmodat2
453 64 map_shadow_stack sys_map_shadow_stack
+454 common set_mempolicy2 sys_set_mempolicy2
+455 common get_mempolicy2 sys_get_mempolicy2

#
# Due to a historical design error, certain syscalls are numbered differently
diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h
index 22bc6bc147f8..0c4a71177df9 100644
--- a/include/linux/syscalls.h
+++ b/include/linux/syscalls.h
@@ -813,6 +813,10 @@ asmlinkage long sys_get_mempolicy(int __user *policy,
unsigned long addr, unsigned long flags);
asmlinkage long sys_set_mempolicy(int mode, const unsigned long __user *nmask,
unsigned long maxnode);
+asmlinkage long sys_get_mempolicy2(struct mempolicy_args __user *args,
+				   size_t size);
+asmlinkage long sys_set_mempolicy2(const struct mempolicy_args __user *args,
+				   size_t size);
asmlinkage long sys_migrate_pages(pid_t pid, unsigned long maxnode,
const unsigned long __user *from,
const unsigned long __user *to);
diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h
index abe087c53b4b..397dcf804941 100644
--- a/include/uapi/asm-generic/unistd.h
+++ b/include/uapi/asm-generic/unistd.h
@@ -823,8 +823,16 @@ __SYSCALL(__NR_cachestat, sys_cachestat)
#define __NR_fchmodat2 452
__SYSCALL(__NR_fchmodat2, sys_fchmodat2)

+/* CONFIG_MMU only */
+#ifndef __ARCH_NOMMU
+#define __NR_set_mempolicy2 454
+__SYSCALL(__NR_set_mempolicy2, sys_set_mempolicy2)
+#define __NR_get_mempolicy2 455
+__SYSCALL(__NR_get_mempolicy2, sys_get_mempolicy2)
+#endif
+
#undef __NR_syscalls
-#define __NR_syscalls 453
+#define __NR_syscalls 456

/*
* 32 bit systems traditionally used different
diff --git a/include/uapi/linux/mempolicy.h b/include/uapi/linux/mempolicy.h
index 046d0ccba4cd..ea386872094b 100644
--- a/include/uapi/linux/mempolicy.h
+++ b/include/uapi/linux/mempolicy.h
@@ -23,9 +23,38 @@ enum {
MPOL_INTERLEAVE,
MPOL_LOCAL,
MPOL_PREFERRED_MANY,
+ MPOL_LEGACY, /* set_mempolicy limited to above modes */
MPOL_MAX, /* always last member of enum */
};

+/*
+ * Argument block exchanged with the set_mempolicy2/get_mempolicy2 syscalls.
+ *
+ * NOTE(review): members are native 'unsigned long' values and raw pointers,
+ * so the layout is not the same for 32-bit and 64-bit callers - confirm how
+ * compat tasks are meant to be handled.
+ */
+struct mempolicy_args {
+ /* MPOL_* policy mode (set: requested mode, get: current task mode) */
+ unsigned short mode;
+ /* User buffer for the policy nodemask; NULL skips it on get */
+ unsigned long *nodemask;
+ /* Size argument passed along with nodemask */
+ unsigned long maxnode;
+ /* Mode flags; on get, the policy flags masked by MPOL_MODE_FLAGS */
+ unsigned short flags;
+ /* Extra queries served by get_mempolicy2 */
+ struct {
+ /* Memory allowed: cpuset mems_allowed, copied out if nodemask is set */
+ struct {
+ unsigned long maxnode;
+ unsigned long *nodemask;
+ } allowed;
+ /* Address information: looked up only when addr is non-zero */
+ struct {
+ unsigned long addr;
+ unsigned long node;
+ unsigned short mode;
+ unsigned short flags;
+ } addr;
+ } get;
+ /* Mode specific settings */
+ union {
+ struct {
+ unsigned long next_node; /* get only */
+ } interleave;
+ };
+};
+
/* Flags for set_mempolicy */
#define MPOL_F_STATIC_NODES (1 << 15)
#define MPOL_F_RELATIVE_NODES (1 << 14)
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index ad26f41b91de..936c641f554e 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1478,7 +1478,7 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
*flags = *mode & MPOL_MODE_FLAGS;
*mode &= ~MPOL_MODE_FLAGS;

- if ((unsigned int)(*mode) >= MPOL_MAX)
+ if ((unsigned int)(*mode) >= MPOL_LEGACY)
return -EINVAL;
if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
return -EINVAL;
@@ -1609,6 +1609,200 @@ SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
return kernel_set_mempolicy(mode, nmask, maxnode);
}

+/*
+ * Install a task mempolicy for one of the extended (non-legacy) modes.
+ *
+ * @args: kernel copy of the caller's mempolicy_args.
+ *
+ * Only modes strictly between MPOL_LEGACY and MPOL_MAX are accepted; legacy
+ * modes are dispatched to kernel_set_mempolicy() by the caller.  No extended
+ * mode exists between those two enum values yet, so every request currently
+ * fails with -EINVAL.
+ *
+ * Returns 0 on success or a negative errno.
+ */
+static long do_set_mempolicy2(struct mempolicy_args *args)
+{
+	struct mempolicy *new;
+	nodemask_t nodes;
+	int err;
+
+	if (args->mode <= MPOL_LEGACY || args->mode >= MPOL_MAX)
+		return -EINVAL;
+
+	err = get_nodes(&nodes, args->nodemask, args->maxnode);
+	if (err)
+		return err;
+
+	new = mpol_new(args->mode, args->flags, &nodes);
+	if (IS_ERR(new))
+		return PTR_ERR(new);
+
+	switch (args->mode) {
+	default:
+		/*
+		 * The mode is chosen by userspace: a mode without a handler
+		 * must fail the call, not BUG() the kernel.
+		 */
+		err = -EINVAL;
+		break;
+	}
+
+	if (!err)
+		err = replace_mempolicy(new, &nodes);
+	if (err)
+		mpol_put(new);
+	return err;
+}
+
+/*
+ * True when kargs->mode names an extended mode, i.e. one that lies strictly
+ * between MPOL_LEGACY and MPOL_MAX.  Legacy modes never reach this check:
+ * they are routed through the legacy interface.
+ */
+static bool mempolicy2_args_valid(struct mempolicy_args *kargs)
+{
+	if (kargs->mode <= MPOL_LEGACY)
+		return false;
+
+	return kargs->mode < MPOL_MAX;
+}
+
+/*
+ * Copy a mempolicy_args block in from userspace and apply it.
+ *
+ * @uargs: user pointer to the caller's struct mempolicy_args.
+ * @usize: size of that structure as the caller knows it.
+ *
+ * Modes below MPOL_LEGACY are forwarded to kernel_set_mempolicy() with the
+ * flags OR-ed into the mode word; everything else goes to
+ * do_set_mempolicy2().
+ *
+ * Returns 0 on success, -EINVAL for a short structure, unknown flag bits or
+ * an unsupported mode, or the error reported by copy_struct_from_user() or
+ * by the selected set path.
+ */
+static long kernel_set_mempolicy2(const struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args kargs;
+	int err;
+
+	if (usize < sizeof(kargs))
+		return -EINVAL;
+
+	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
+	if (err)
+		return err;
+
+	/*
+	 * Only mode-flag bits may be set.  The legacy path below ORs flags
+	 * into the mode word, so any other bit would silently change the
+	 * mode that gets installed.
+	 */
+	if (kargs.flags & ~MPOL_MODE_FLAGS)
+		return -EINVAL;
+
+	/* If the mode is legacy, use the legacy path */
+	if (kargs.mode < MPOL_LEGACY) {
+		int legacy_mode = kargs.mode | kargs.flags;
+		const unsigned long __user *lnmask = kargs.nodemask;
+		unsigned long maxnode = kargs.maxnode;
+
+		return kernel_set_mempolicy(legacy_mode, lnmask, maxnode);
+	}
+
+	if (!mempolicy2_args_valid(&kargs))
+		return -EINVAL;
+
+	return do_set_mempolicy2(&kargs);
+}
+
+/*
+ * set_mempolicy2 syscall entry point: passes the user argument block and
+ * its size straight to kernel_set_mempolicy2() and returns its result.
+ */
+SYSCALL_DEFINE2(set_mempolicy2, const struct mempolicy_args __user *, args,
+ size_t, size)
+{
+ return kernel_set_mempolicy2(args, size);
+}
+
+/*
+ * Fill kargs->get.addr with the policy mode, the policy flags and the node
+ * backing the page at kargs->get.addr.addr.
+ *
+ * Does NOT fall back to the task policy: when the vma has no policy of its
+ * own, MPOL_DEFAULT with no flags is reported.
+ *
+ * Returns 0 on success, -EFAULT if no vma covers the address, or the
+ * negative error from lookup_node().
+ */
+static long get_mempolicy2_addr(struct mempolicy_args *kargs)
+{
+	struct mm_struct *mm = current->mm;
+	unsigned long addr = kargs->get.addr.addr;
+	struct mempolicy *addr_pol;
+	struct vm_area_struct *vma;
+	int node;
+
+	mmap_read_lock(mm);
+	vma = vma_lookup(mm, addr);
+	if (!vma) {
+		mmap_read_unlock(mm);
+		return -EFAULT;
+	}
+	if (vma->vm_ops && vma->vm_ops->get_policy)
+		addr_pol = vma->vm_ops->get_policy(vma, addr);
+	else
+		addr_pol = vma->vm_policy;
+
+	if (addr_pol) {
+		kargs->get.addr.mode = addr_pol->mode;
+		/* Mask off internal flags */
+		kargs->get.addr.flags = addr_pol->flags & MPOL_MODE_FLAGS;
+	} else {
+		kargs->get.addr.mode = MPOL_DEFAULT;
+		kargs->get.addr.flags = 0;
+	}
+
+	/*
+	 * Everything needed from the policy was copied while mmap_lock was
+	 * held, so no extra reference is taken across the unlock.  Drop the
+	 * reference a shared-policy lookup may have returned with.
+	 */
+	mpol_cond_put(addr_pol);
+	mmap_read_unlock(mm);
+
+	node = lookup_node(mm, addr);
+	if (node < 0)
+		return node;
+
+	kargs->get.addr.node = node;
+	return 0;
+}
+
+/*
+ * Gets extended mempolicy information for the current task.
+ *
+ * @kargs: kernel copy of the caller's mempolicy_args; result fields are
+ *         written in place and user nodemask buffers are filled directly.
+ *
+ * Always reports the task policy mode and flags.  Optionally reports:
+ *  - the policy nodemask, when kargs->nodemask is set;
+ *  - the cpuset-allowed nodes, when kargs->get.allowed.nodemask is set;
+ *  - the policy and node at an address, when kargs->get.addr.addr is set;
+ *  - the next interleave node, when the task mode is MPOL_INTERLEAVE.
+ *
+ * Returns 0 on success or a negative errno.
+ */
+static long do_get_mempolicy2(struct mempolicy_args *kargs)
+{
+	struct mempolicy *pol = current->mempolicy;
+	nodemask_t knodes;
+	long rc;
+
+	/* current->mempolicy may be NULL: report the default policy then */
+	if (!pol)
+		pol = &default_policy;
+
+	kargs->mode = pol == &default_policy ? MPOL_DEFAULT : pol->mode;
+	/* Mask off internal flags */
+	kargs->flags = pol->flags & MPOL_MODE_FLAGS;
+
+	if (kargs->nodemask) {
+		if (mpol_store_user_nodemask(pol)) {
+			knodes = pol->w.user_nodemask;
+		} else {
+			task_lock(current);
+			get_policy_nodemask(pol, &knodes);
+			task_unlock(current);
+		}
+		rc = copy_nodes_to_user(kargs->nodemask, kargs->maxnode,
+					&knodes);
+		if (rc)
+			return rc;
+	}
+
+	if (kargs->get.allowed.nodemask) {
+		task_lock(current);
+		knodes = cpuset_current_mems_allowed;
+		task_unlock(current);
+		rc = copy_nodes_to_user(kargs->get.allowed.nodemask,
+					kargs->get.allowed.maxnode,
+					&knodes);
+		if (rc)
+			return rc;
+	}
+
+	if (kargs->get.addr.addr) {
+		rc = get_mempolicy2_addr(kargs);
+		if (rc)
+			return rc;
+	}
+
+	switch (kargs->mode) {
+	case MPOL_INTERLEAVE:
+		kargs->interleave.next_node = next_node_in(current->il_prev,
+							   pol->nodes);
+		break;
+	default:
+		/* No mode-specific data to report for the other modes */
+		break;
+	}
+
+	return 0;
+}
+
+/*
+ * Copy a mempolicy_args block in from userspace, fill in the requested
+ * policy information and copy the block back out.
+ *
+ * @uargs: user pointer to the caller's struct mempolicy_args.
+ * @usize: size of that structure as the caller knows it.
+ *
+ * Returns 0 on success, -EINVAL for a short structure, -EFAULT if the
+ * result cannot be written back, or the error reported by
+ * copy_struct_from_user() or do_get_mempolicy2().
+ */
+static long kernel_get_mempolicy2(struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args kargs;
+	int err;
+
+	if (usize < sizeof(kargs))
+		return -EINVAL;
+
+	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
+	if (err)
+		return err;
+
+	/* Get the extended memory policy information */
+	err = do_get_mempolicy2(&kargs);
+	if (err)
+		return err;
+
+	/* copy_to_user() returns the number of bytes left, not an errno */
+	if (copy_to_user(uargs, &kargs, sizeof(kargs)))
+		return -EFAULT;
+
+	return 0;
+}
+
+/*
+ * get_mempolicy2 syscall entry point: passes the user argument block and
+ * its size straight to kernel_get_mempolicy2() and returns its result.
+ */
+SYSCALL_DEFINE2(get_mempolicy2, struct mempolicy_args __user *, policy,
+ size_t, size)
+{
+ return kernel_get_mempolicy2(policy, size);
+}
+
static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
const unsigned long __user *old_nodes,
const unsigned long __user *new_nodes)
--
2.39.1