[PATCH v6 2/6] memcontrol: allow mem_cgroup_iter() to check for onlineness

From: Nhat Pham
Date: Mon Nov 27 2023 - 14:37:17 EST


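The new zswap writeback scheme requires a memcg hierarchy traversal that
visits only online memcgs. Add a new boolean parameter, online, to
mem_cgroup_iter(): when it is true, the iterator skips offline memcgs
instead of returning them. The existing callers updated here all pass
false.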

Signed-off-by: Nhat Pham <nphamcs@xxxxxxxxx>
---
include/linux/memcontrol.h | 4 ++--
mm/memcontrol.c | 17 ++++++++++-------
mm/shrinker.c | 4 ++--
mm/vmscan.c | 26 +++++++++++++-------------
4 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 7bdcf3020d7a..86adce081a08 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -832,7 +832,7 @@ static inline void mem_cgroup_put(struct mem_cgroup *memcg)

struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *,
struct mem_cgroup *,
- struct mem_cgroup_reclaim_cookie *);
+ struct mem_cgroup_reclaim_cookie *, bool online);
void mem_cgroup_iter_break(struct mem_cgroup *, struct mem_cgroup *);
void mem_cgroup_scan_tasks(struct mem_cgroup *memcg,
int (*)(struct task_struct *, void *), void *arg);
@@ -1381,7 +1381,7 @@ static inline struct lruvec *folio_lruvec_lock_irqsave(struct folio *folio,
static inline struct mem_cgroup *
mem_cgroup_iter(struct mem_cgroup *root,
struct mem_cgroup *prev,
- struct mem_cgroup_reclaim_cookie *reclaim)
+ struct mem_cgroup_reclaim_cookie *reclaim, bool online)
{
return NULL;
}
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 564aa8f25b71..a1f051adaa15 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -221,14 +221,14 @@ enum res_type {
* be used for reference counting.
*/
#define for_each_mem_cgroup_tree(iter, root) \
- for (iter = mem_cgroup_iter(root, NULL, NULL); \
+ for (iter = mem_cgroup_iter(root, NULL, NULL, false); \
iter != NULL; \
- iter = mem_cgroup_iter(root, iter, NULL))
+ iter = mem_cgroup_iter(root, iter, NULL, false))

#define for_each_mem_cgroup(iter) \
- for (iter = mem_cgroup_iter(NULL, NULL, NULL); \
+ for (iter = mem_cgroup_iter(NULL, NULL, NULL, false); \
iter != NULL; \
- iter = mem_cgroup_iter(NULL, iter, NULL))
+ iter = mem_cgroup_iter(NULL, iter, NULL, false))

static inline bool task_is_dying(void)
{
@@ -1115,6 +1115,7 @@ struct mem_cgroup *get_mem_cgroup_from_current(void)
* @root: hierarchy root
* @prev: previously returned memcg, NULL on first invocation
* @reclaim: cookie for shared reclaim walks, NULL for full walks
+ * @online: if true, skip offline descendants (@root is not checked)
*
* Returns references to children of the hierarchy below @root, or
* @root itself, or %NULL after a full round-trip.
@@ -1129,7 +1130,8 @@ struct mem_cgroup *get_mem_cgroup_from_current(void)
*/
struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
struct mem_cgroup *prev,
- struct mem_cgroup_reclaim_cookie *reclaim)
+ struct mem_cgroup_reclaim_cookie *reclaim,
+ bool online)
{
struct mem_cgroup_reclaim_iter *iter;
struct cgroup_subsys_state *css = NULL;
@@ -1199,7 +1201,8 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
* is provided by the caller, so we know it's alive
* and kicking, and don't take an extra reference.
*/
- if (css == &root->css || css_tryget(css)) {
+ if (css == &root->css || (!online && css_tryget(css)) ||
+ css_tryget_online(css)) {
memcg = mem_cgroup_from_css(css);
break;
}
@@ -1812,7 +1815,7 @@ static int mem_cgroup_soft_reclaim(struct mem_cgroup *root_memcg,
excess = soft_limit_excess(root_memcg);

while (1) {
- victim = mem_cgroup_iter(root_memcg, victim, &reclaim);
+ victim = mem_cgroup_iter(root_memcg, victim, &reclaim, false);
if (!victim) {
loop++;
if (loop >= 2) {
diff --git a/mm/shrinker.c b/mm/shrinker.c
index dd91eab43ed3..54f5d3aa4f27 100644
--- a/mm/shrinker.c
+++ b/mm/shrinker.c
@@ -160,7 +160,7 @@ static int expand_shrinker_info(int new_id)
new_size = shrinker_unit_size(new_nr_max);
old_size = shrinker_unit_size(shrinker_nr_max);

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
ret = expand_one_shrinker_info(memcg, new_size, old_size,
new_nr_max);
@@ -168,7 +168,7 @@ static int expand_shrinker_info(int new_id)
mem_cgroup_iter_break(NULL, memcg);
goto out;
}
- } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
+ } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)) != NULL);
out:
if (!ret)
shrinker_nr_max = new_nr_max;
diff --git a/mm/vmscan.c b/mm/vmscan.c
index d8c3338fee0f..9a65ee3a1bb7 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -397,10 +397,10 @@ static unsigned long drop_slab_node(int nid)
unsigned long freed = 0;
struct mem_cgroup *memcg = NULL;

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
freed += shrink_slab(GFP_KERNEL, nid, memcg, 0);
- } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)) != NULL);
+ } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)) != NULL);

return freed;
}
@@ -3935,7 +3935,7 @@ static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
if (!min_ttl || sc->order || sc->priority == DEF_PRIORITY)
return;

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);

@@ -3945,7 +3945,7 @@ static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
}

cond_resched();
- } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+ } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));

/*
* The main goal is to OOM kill if every generation from all memcgs is
@@ -5037,7 +5037,7 @@ static void lru_gen_change_state(bool enabled)
else
static_branch_disable_cpuslocked(&lru_gen_caps[LRU_GEN_CORE]);

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
int nid;

@@ -5061,7 +5061,7 @@ static void lru_gen_change_state(bool enabled)
}

cond_resched();
- } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+ } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));
unlock:
mutex_unlock(&state_mutex);
put_online_mems();
@@ -5164,7 +5164,7 @@ static void *lru_gen_seq_start(struct seq_file *m, loff_t *pos)
if (!m->private)
return ERR_PTR(-ENOMEM);

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
int nid;

@@ -5172,7 +5172,7 @@ static void *lru_gen_seq_start(struct seq_file *m, loff_t *pos)
if (!nr_to_skip--)
return get_lruvec(memcg, nid);
}
- } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+ } while ((memcg = mem_cgroup_iter(NULL, memcg, NULL, false)));

return NULL;
}
@@ -5195,7 +5195,7 @@ static void *lru_gen_seq_next(struct seq_file *m, void *v, loff_t *pos)

nid = next_memory_node(nid);
if (nid == MAX_NUMNODES) {
- memcg = mem_cgroup_iter(NULL, memcg, NULL);
+ memcg = mem_cgroup_iter(NULL, memcg, NULL, false);
if (!memcg)
return NULL;

@@ -5798,7 +5798,7 @@ static void shrink_node_memcgs(pg_data_t *pgdat, struct scan_control *sc)
struct mem_cgroup *target_memcg = sc->target_mem_cgroup;
struct mem_cgroup *memcg;

- memcg = mem_cgroup_iter(target_memcg, NULL, NULL);
+ memcg = mem_cgroup_iter(target_memcg, NULL, NULL, false);
do {
struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
unsigned long reclaimed;
@@ -5855,7 +5855,7 @@ static void shrink_node_memcgs(pg_data_t *pgdat, struct scan_control *sc)
sc->nr_scanned - scanned,
sc->nr_reclaimed - reclaimed);

- } while ((memcg = mem_cgroup_iter(target_memcg, memcg, NULL)));
+ } while ((memcg = mem_cgroup_iter(target_memcg, memcg, NULL, false)));
}

static void shrink_node(pg_data_t *pgdat, struct scan_control *sc)
@@ -6522,12 +6522,12 @@ static void kswapd_age_node(struct pglist_data *pgdat, struct scan_control *sc)
if (!inactive_is_low(lruvec, LRU_INACTIVE_ANON))
return;

- memcg = mem_cgroup_iter(NULL, NULL, NULL);
+ memcg = mem_cgroup_iter(NULL, NULL, NULL, false);
do {
lruvec = mem_cgroup_lruvec(memcg, pgdat);
shrink_active_list(SWAP_CLUSTER_MAX, lruvec,
sc, LRU_ACTIVE_ANON);
- memcg = mem_cgroup_iter(NULL, memcg, NULL);
+ memcg = mem_cgroup_iter(NULL, memcg, NULL, false);
} while (memcg);
}

--
2.34.1