[PATCH 6.1 07/14] maple_tree: be more cautious about dead nodes

From: Liam R. Howlett
Date: Tue Apr 11 2023 - 11:13:02 EST


From: "Liam R. Howlett" <Liam.Howlett@xxxxxxxxxx>

commit 39d0bd86c499ecd6abae42a9b7112056c5560691 upstream.

ma_pivots() and ma_data_end() may be called with a dead node. Ensure
that the node isn't dead before using the returned values.

This is necessary for RCU mode of the maple tree.

Link: https://lkml.kernel.org/r/20230227173632.3292573-1-surenb@xxxxxxxxxx
Link: https://lkml.kernel.org/r/20230227173632.3292573-2-surenb@xxxxxxxxxx
Fixes: 54a611b60590 ("Maple Tree: add new data structure")
Cc: <Stable@xxxxxxxxxxxxxxx>
Signed-off-by: Liam Howlett <Liam.Howlett@xxxxxxxxxx>
---
lib/maple_tree.c | 52 +++++++++++++++++++++++++++++++++++++++---------
1 file changed, 43 insertions(+), 9 deletions(-)

diff --git a/lib/maple_tree.c b/lib/maple_tree.c
index c50646fcb8ca..7c8225e7df13 100644
--- a/lib/maple_tree.c
+++ b/lib/maple_tree.c
@@ -534,6 +534,7 @@ static inline bool ma_dead_node(const struct maple_node *node)

return (parent == node);
}
+
/*
* mte_dead_node() - check if the @enode is dead.
* @enode: The encoded maple node
@@ -615,6 +616,8 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas)
* @node - the maple node
* @type - the node type
*
+ * In the event of a dead node, this array may be %NULL
+ *
* Return: A pointer to the maple node pivots
*/
static inline unsigned long *ma_pivots(struct maple_node *node,
@@ -1086,8 +1089,11 @@ static int mas_ascend(struct ma_state *mas)
a_type = mas_parent_enum(mas, p_enode);
a_node = mte_parent(p_enode);
a_slot = mte_parent_slot(p_enode);
- pivots = ma_pivots(a_node, a_type);
a_enode = mt_mk_node(a_node, a_type);
+ pivots = ma_pivots(a_node, a_type);
+
+ if (unlikely(ma_dead_node(a_node)))
+ return 1;

if (!set_min && a_slot) {
set_min = true;
@@ -1393,6 +1399,9 @@ static inline unsigned char ma_data_end(struct maple_node *node,
{
unsigned char offset;

+ if (!pivots)
+ return 0;
+
if (type == maple_arange_64)
return ma_meta_end(node, type);

@@ -1428,6 +1437,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas)
return ma_meta_end(node, type);

pivots = ma_pivots(node, type);
+ if (unlikely(ma_dead_node(node)))
+ return 0;
+
offset = mt_pivots[type] - 1;
if (likely(!pivots[offset]))
return ma_meta_end(node, type);
@@ -4499,6 +4511,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
node = mas_mn(mas);
slots = ma_slots(node, mt);
pivots = ma_pivots(node, mt);
+ if (unlikely(ma_dead_node(node)))
+ return 1;
+
mas->max = pivots[offset];
if (offset)
mas->min = pivots[offset - 1] + 1;
@@ -4520,6 +4535,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
slots = ma_slots(node, mt);
pivots = ma_pivots(node, mt);
offset = ma_data_end(node, mt, pivots, mas->max);
+ if (unlikely(ma_dead_node(node)))
+ return 1;
+
if (offset)
mas->min = pivots[offset - 1] + 1;

@@ -4568,6 +4586,7 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
struct maple_enode *enode;
int level = 0;
unsigned char offset;
+ unsigned char node_end;
enum maple_type mt;
void __rcu **slots;

@@ -4591,7 +4610,11 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
node = mas_mn(mas);
mt = mte_node_type(mas->node);
pivots = ma_pivots(node, mt);
- } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max)));
+ node_end = ma_data_end(node, mt, pivots, mas->max);
+ if (unlikely(ma_dead_node(node)))
+ return 1;
+
+ } while (unlikely(offset == node_end));

slots = ma_slots(node, mt);
pivot = mas_safe_pivot(mas, pivots, ++offset, mt);
@@ -4607,6 +4630,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
mt = mte_node_type(mas->node);
slots = ma_slots(node, mt);
pivots = ma_pivots(node, mt);
+ if (unlikely(ma_dead_node(node)))
+ return 1;
+
offset = 0;
pivot = pivots[0];
}
@@ -4653,11 +4679,14 @@ static inline void *mas_next_nentry(struct ma_state *mas,
return NULL;
}

- pivots = ma_pivots(node, type);
slots = ma_slots(node, type);
- mas->index = mas_safe_min(mas, pivots, mas->offset);
+ pivots = ma_pivots(node, type);
count = ma_data_end(node, type, pivots, mas->max);
- if (ma_dead_node(node))
+ if (unlikely(ma_dead_node(node)))
+ return NULL;
+
+ mas->index = mas_safe_min(mas, pivots, mas->offset);
+ if (unlikely(ma_dead_node(node)))
return NULL;

if (mas->index > max)
@@ -4815,6 +4844,11 @@ static inline void *mas_prev_nentry(struct ma_state *mas, unsigned long limit,

slots = ma_slots(mn, mt);
pivots = ma_pivots(mn, mt);
+ if (unlikely(ma_dead_node(mn))) {
+ mas_rewalk(mas, index);
+ goto retry;
+ }
+
if (offset == mt_pivots[mt])
pivot = mas->max;
else
@@ -6617,11 +6651,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
while (likely(!ma_is_leaf(mt))) {
MT_BUG_ON(mas->tree, mte_dead_node(mas->node));
slots = ma_slots(mn, mt);
- pivots = ma_pivots(mn, mt);
- max = pivots[0];
entry = mas_slot(mas, slots, 0);
+ pivots = ma_pivots(mn, mt);
if (unlikely(ma_dead_node(mn)))
return NULL;
+ max = pivots[0];
mas->node = entry;
mn = mas_mn(mas);
mt = mte_node_type(mas->node);
@@ -6641,13 +6675,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
if (likely(entry))
return entry;

- pivots = ma_pivots(mn, mt);
- mas->index = pivots[0] + 1;
mas->offset = 1;
entry = mas_slot(mas, slots, 1);
+ pivots = ma_pivots(mn, mt);
if (unlikely(ma_dead_node(mn)))
return NULL;

+ mas->index = pivots[0] + 1;
if (mas->index > limit)
goto none;

--
2.39.2