[RFC PATCH 2/5] mm: Select victim memcg using bpf prog

From: Chuyi Zhou
Date: Thu Jul 27 2023 - 03:44:13 EST


This patch uses a BPF program to bypass the default select_bad_process()
method and select a victim memcg when a global OOM is invoked. Specifically,
we iterate over root_mem_cgroup's children and select the root of the next
iteration through __bpf_run_oom_policy(), repeating until we reach a leaf
memcg in the last layer. Then we use oom_evaluate_task() to find a
victim task in the selected memcg. If there is no suitable process
to kill in that memcg, we fall back to the default method.

Suggested-by: Abel Wu <wuyun.abel@xxxxxxxxxxxxx>
Signed-off-by: Chuyi Zhou <zhouchuyi@xxxxxxxxxxxxx>
---
include/linux/memcontrol.h | 6 +++++
mm/memcontrol.c | 50 ++++++++++++++++++++++++++++++++++++++
mm/oom_kill.c | 17 +++++++++++++
3 files changed, 73 insertions(+)

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 5818af8eca5a..7fedc2521c8b 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1155,6 +1155,7 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
gfp_t gfp_mask,
unsigned long *total_scanned);

+struct mem_cgroup *select_victim_memcg(void);
#else /* CONFIG_MEMCG */

#define MEM_CGROUP_ID_SHIFT 0
@@ -1588,6 +1589,11 @@ unsigned long mem_cgroup_soft_limit_reclaim(pg_data_t *pgdat, int order,
{
return 0;
}
+
+/*
+ * Without CONFIG_MEMCG there is no memcg hierarchy to choose from, so no
+ * victim memcg is ever proposed.
+ */
+static inline struct mem_cgroup *select_victim_memcg(void)
+{
+	struct mem_cgroup *none = NULL;
+
+	return none;
+}
#endif /* CONFIG_MEMCG */

static inline void __inc_lruvec_kmem_state(void *p, enum node_stat_item idx)
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e8ca4bdcb03c..c6b42635f1af 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -64,6 +64,7 @@
#include <linux/psi.h>
#include <linux/seq_buf.h>
#include <linux/sched/isolation.h>
+#include <linux/bpf_oom.h>
#include "internal.h"
#include <net/sock.h>
#include <net/ip.h>
@@ -2638,6 +2639,55 @@ void mem_cgroup_handle_over_high(void)
css_put(&memcg->css);
}

+/*
+ * Return true if @pos should replace @chosen as the preferred OOM victim
+ * among siblings. The first live sibling seen (@chosen == NULL) always wins.
+ * Otherwise the attached BPF policy is asked to compare the two cgroup IDs:
+ * a BPF_OOM_CMP_GREATER verdict keeps @chosen, a BPF_OOM_CMP_EQUAL verdict
+ * is broken by current memory usage (the strictly larger consumer wins), and
+ * any other verdict switches to @pos.
+ * Called under rcu_read_lock() from select_victim_memcg().
+ */
+static bool oom_policy_prefer(struct cgroup_subsys_state *pos,
+			      struct cgroup_subsys_state *chosen)
+{
+	struct mem_cgroup *pos_memcg, *chosen_memcg;
+	int cmp;
+
+	if (!chosen)
+		return true;
+
+	cmp = __bpf_run_oom_policy(cgroup_id(chosen->cgroup),
+				   cgroup_id(pos->cgroup));
+	if (cmp == BPF_OOM_CMP_GREATER)
+		return false;
+	if (cmp != BPF_OOM_CMP_EQUAL)
+		return true;
+
+	pos_memcg = mem_cgroup_from_css(pos);
+	chosen_memcg = mem_cgroup_from_css(chosen);
+	return page_counter_read(&pos_memcg->memory) >
+	       page_counter_read(&chosen_memcg->memory);
+}
+
+/**
+ * select_victim_memcg - pick a leaf memcg for global OOM via the BPF policy
+ *
+ * Starting at the root memcg, repeatedly descend into the live child that
+ * the BPF OOM policy prefers, until a memcg with no live children is
+ * reached. Only supported on the cgroup v2 (default) hierarchy.
+ *
+ * Return: a referenced memcg that the caller must release, or NULL if memcg
+ * is disabled, the hierarchy is not cgroup v2, the root has no live
+ * children, or a reference could not be taken. The root memcg itself is
+ * never returned: mem_cgroup_scan_tasks() must not be called on it, and a
+ * root "victim" means the caller should use the default global scan anyway.
+ */
+struct mem_cgroup *select_victim_memcg(void)
+{
+	struct cgroup_subsys_state *parent, *pos, *victim = NULL;
+	struct mem_cgroup *victim_memcg = NULL;
+
+	if (mem_cgroup_disabled() || !cgroup_subsys_on_dfl(memory_cgrp_subsys))
+		return NULL;
+
+	rcu_read_lock();
+	parent = &root_mem_cgroup->css;
+	while (parent) {
+		struct cgroup_subsys_state *chosen = NULL;
+
+		victim = parent;
+		css_for_each_child(pos, parent) {
+			/*
+			 * Skip memcgs that are being torn down: descending into
+			 * one would only find an empty subtree and force the
+			 * fallback path, defeating the policy.
+			 */
+			if (css_is_dying(pos))
+				continue;
+			if (oom_policy_prefer(pos, chosen))
+				chosen = pos;
+		}
+		parent = chosen;
+	}
+
+	/* Never hand the root back; see the Return: section above. */
+	if (victim != &root_mem_cgroup->css && css_tryget(victim))
+		victim_memcg = mem_cgroup_from_css(victim);
+	rcu_read_unlock();
+
+	return victim_memcg;
+}
+
static int try_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp_mask,
unsigned int nr_pages)
{
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 01af8adaa16c..b88c8c7d4ee4 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -361,6 +361,19 @@ static int oom_evaluate_task(struct task_struct *task, void *arg)
return 1;
}

+/*
+ * Try to pick the OOM victim task from the memcg chosen by the attached BPF
+ * OOM policy. Returns true if a candidate was recorded in @oc->chosen (this
+ * may be the -1 abort marker set during the scan), false if the caller
+ * should fall back to the default whole-system scan.
+ */
+static bool bpf_select_bad_process(struct oom_control *oc)
+{
+	struct mem_cgroup *victim_memcg = select_victim_memcg();
+
+	if (victim_memcg) {
+		/*
+		 * mem_cgroup_scan_tasks() BUG()s on the root memcg; a root
+		 * "victim" just means the default global scan should run.
+		 */
+		if (!mem_cgroup_is_root(victim_memcg))
+			mem_cgroup_scan_tasks(victim_memcg, oom_evaluate_task, oc);
+		/*
+		 * mem_cgroup_put() rather than css_put(&victim_memcg->css):
+		 * struct mem_cgroup is incomplete when CONFIG_MEMCG=n, and
+		 * this file is built in both configurations.
+		 */
+		mem_cgroup_put(victim_memcg);
+	}
+
+	return !!oc->chosen;
+}
+
/*
* Simple selection loop. We choose the process with the highest number of
* 'points'. In case scan was aborted, oc->chosen is set to -1.
@@ -372,6 +385,9 @@ static void select_bad_process(struct oom_control *oc)
if (is_memcg_oom(oc))
mem_cgroup_scan_tasks(oc->memcg, oom_evaluate_task, oc);
else {
+ if (bpf_oom_policy_enabled() && bpf_select_bad_process(oc))
+ return;
+
struct task_struct *p;

rcu_read_lock();
@@ -1426,3 +1442,4 @@ bool bpf_oom_policy_enabled(void)
rcu_read_unlock();
return !empty;
}
+
--
2.20.1