[PATCH Fix 2/3] maple_tree: Change spanning store to work on larger trees

From: Liam Howlett
Date: Wed Jun 15 2022 - 10:20:42 EST


The spanning store had an issue which could lead to a double free during
a large tree modification. Fix this by being more careful about how
nodes are added to the to-be-freed and to-be-destroyed lists during this
operation.
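
For background, the two lists mentioned above are ma_topiary lists. A
simplified sketch of their lifecycle during a spanning rebalance follows;
the helpers are the existing internal ones in lib/maple_tree.c, with the
call sites abbreviated:

	MA_TOPIARY(free, mas->tree);	/* consumed nodes, RCU freed as-is */
	MA_TOPIARY(destroy, mas->tree);	/* subtrees needing a destroy walk */

	mat_add(&free, enode);		/* each node may be queued only once */
	...
	/* roughly, after the replacement subtree is visible to readers: */
	mas_mat_free(mas, &free);
	mas_mat_destroy(mas, &destroy);

A node that lands on both lists, or twice on one list, is freed twice;
the rework below makes it explicit which slot ranges are queued for the
destroy walk.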

Reported-by: Qian Cai <quic_qiancai@xxxxxxxxxxx>
Signed-off-by: Liam R. Howlett <Liam.Howlett@xxxxxxxxxx>
---
lib/maple_tree.c | 311 ++++++++++++++++++++++++++++++-----------------
1 file changed, 190 insertions(+), 121 deletions(-)
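
Note for reviewers (below the '---', so not part of the changelog): the
new mas_topiary_range() helper takes an inclusive [start, end] slot range
and skips entries already marked dead, so a caller may pass a range that
still contains already-queued children. A hypothetical call, with the
slot numbers made up purely for illustration:

	/* queue the live children in slots 2..5 of the current node */
	mas_topiary_range(mast->orig_l, mast->destroy, 2, 5);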

diff --git a/lib/maple_tree.c b/lib/maple_tree.c
index a1035963ae0d..f413b6f0da2b 100644
--- a/lib/maple_tree.c
+++ b/lib/maple_tree.c
@@ -1388,7 +1388,6 @@ static inline unsigned char ma_data_end(struct maple_node *node,
/*
* mas_data_end() - Find the end of the data (slot).
* @mas: the maple state
- * @type: the type of maple node
*
* This method is optimized to check the metadata of a node if the node type
* supports data end metadata.
@@ -2272,6 +2271,31 @@ static inline void mas_wr_node_walk(struct ma_wr_state *wr_mas)
wr_mas->offset_end = mas->offset = offset;
}

+/*
+ * mas_topiary_range() - Add a range of slots to the topiary.
+ * @mas: The maple state
+ * @destroy: The topiary to add the slot contents to (usually destroy)
+ * @start: The first slot to add, inclusive
+ * @end: The last slot to add, inclusive
+ */
+static inline void mas_topiary_range(struct ma_state *mas,
+ struct ma_topiary *destroy, unsigned char start, unsigned char end)
+{
+ void __rcu **slots;
+ unsigned char offset;
+
+ MT_BUG_ON(mas->tree, mte_is_leaf(mas->node));
+ slots = ma_slots(mas_mn(mas), mte_node_type(mas->node));
+ for (offset = start; offset <= end; offset++) {
+ struct maple_enode *enode = mas_slot_locked(mas, slots, offset);
+
+ if (mte_dead_node(enode))
+ continue;
+
+ mat_add(destroy, enode);
+ }
+}
+
/*
* mast_topiary() - Add the portions of the tree to the removal list; either to
* be freed or discarded (destroy walk).
@@ -2280,48 +2304,62 @@ static inline void mas_wr_node_walk(struct ma_wr_state *wr_mas)
static inline void mast_topiary(struct maple_subtree_state *mast)
{
MA_WR_STATE(wr_mas, mast->orig_l, NULL);
- unsigned char l_off, r_off, offset;
- unsigned long l_index;
- struct maple_enode *child;
- void __rcu **slots;
+ unsigned char r_start, r_end;
+ unsigned char l_start, l_end;
+ void __rcu **l_slots, **r_slots;

wr_mas.type = mte_node_type(mast->orig_l->node);
- /* The left node is consumed, so add to the free list. */
- l_index = mast->orig_l->index;
mast->orig_l->index = mast->orig_l->last;
mas_wr_node_walk(&wr_mas);
- mast->orig_l->index = l_index;
- l_off = mast->orig_l->offset;
- r_off = mast->orig_r->offset;
- if (mast->orig_l->node == mast->orig_r->node) {
- slots = ma_slots(mte_to_node(mast->orig_l->node), wr_mas.type);
- for (offset = l_off + 1; offset < r_off; offset++)
- mat_add(mast->destroy, mas_slot_locked(mast->orig_l,
- slots, offset));
+ l_start = mast->orig_l->offset + 1;
+ l_end = mas_data_end(mast->orig_l);
+ r_start = 0;
+ r_end = mast->orig_r->offset;
+
+ if (r_end)
+ r_end--;
+
+ l_slots = ma_slots(mas_mn(mast->orig_l),
+ mte_node_type(mast->orig_l->node));
+
+ r_slots = ma_slots(mas_mn(mast->orig_r),
+ mte_node_type(mast->orig_r->node));

+ if ((l_start < l_end) &&
+ mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_start))) {
+ l_start++;
+ }
+
+ if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_end))) {
+ if (r_end)
+ r_end--;
+ }
+
+ if ((l_start > r_end) && (mast->orig_l->node == mast->orig_r->node))
return;
+
+ /* At the node where left and right sides meet, add the parts between */
+ if (mast->orig_l->node == mast->orig_r->node) {
+ return mas_topiary_range(mast->orig_l, mast->destroy,
+ l_start, r_end);
}

/* mast->orig_r is different and consumed. */
if (mte_is_leaf(mast->orig_r->node))
return;

- /* Now destroy l_off + 1 -> end and 0 -> r_off - 1 */
- offset = l_off + 1;
- slots = ma_slots(mte_to_node(mast->orig_l->node), wr_mas.type);
- while (offset < mt_slots[wr_mas.type]) {
- child = mas_slot_locked(mast->orig_l, slots, offset++);
- if (!child)
- break;
+ if (mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_end)))
+ l_end--;

- mat_add(mast->destroy, child);
- }

- slots = ma_slots(mte_to_node(mast->orig_r->node),
- mte_node_type(mast->orig_r->node));
- for (offset = 0; offset < r_off; offset++)
- mat_add(mast->destroy,
- mas_slot_locked(mast->orig_l, slots, offset));
+ if (l_start <= l_end)
+ mas_topiary_range(mast->orig_l, mast->destroy, l_start, l_end);
+
+ if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_start)))
+ r_start++;
+
+ if (r_start <= r_end)
+ mas_topiary_range(mast->orig_r, mast->destroy, r_start, r_end);
}

/*
@@ -2329,19 +2367,12 @@ static inline void mast_topiary(struct maple_subtree_state *mast)
* @mast: The maple subtree state
- * @old_r: The encoded maple node to the right (next node).
*/
-static inline void mast_rebalance_next(struct maple_subtree_state *mast,
- struct maple_enode *old_r, bool free)
+static inline void mast_rebalance_next(struct maple_subtree_state *mast)
{
unsigned char b_end = mast->bn->b_end;

mas_mab_cp(mast->orig_r, 0, mt_slot_count(mast->orig_r->node),
mast->bn, b_end);
- if (free)
- mat_add(mast->free, old_r);
-
mast->orig_r->last = mast->orig_r->max;
- if (old_r == mast->orig_l->node)
- mast->orig_l->node = mast->orig_r->node;
}

/*
@@ -2349,17 +2381,12 @@ static inline void mast_rebalance_next(struct maple_subtree_state *mast,
* @mast: The maple subtree state
- * @old_l: The encoded maple node to the left (previous node)
*/
-static inline void mast_rebalance_prev(struct maple_subtree_state *mast,
- struct maple_enode *old_l)
+static inline void mast_rebalance_prev(struct maple_subtree_state *mast)
{
unsigned char end = mas_data_end(mast->orig_l) + 1;
unsigned char b_end = mast->bn->b_end;

mab_shift_right(mast->bn, end);
mas_mab_cp(mast->orig_l, 0, end - 1, mast->bn, 0);
- mat_add(mast->free, old_l);
- if (mast->orig_r->node == old_l)
- mast->orig_r->node = mast->orig_l->node;
mast->l->min = mast->orig_l->min;
mast->orig_l->index = mast->orig_l->min;
mast->bn->b_end = end + b_end;
@@ -2367,68 +2395,116 @@ static inline void mast_rebalance_prev(struct maple_subtree_state *mast,
}

/*
- * mast_sibling_rebalance_right() - Rebalance from nodes with the same parents.
- * Check the right side, then the left. Data is copied into the @mast->bn.
+ * mast_spanning_rebalance() - Rebalance against the nearest neighbour,
+ * favouring the node to the right. Check the nodes to the right, then the
+ * left, at each level upwards until the root is reached. Free and destroy
+ * nodes as needed. Data is copied into @mast->bn.
* @mast: The maple_subtree_state.
*/
static inline
-bool mast_sibling_rebalance_right(struct maple_subtree_state *mast, bool free)
+bool mast_spanning_rebalance(struct maple_subtree_state *mast)
{
- struct maple_enode *old_r;
- struct maple_enode *old_l;
+ struct ma_state r_tmp = *mast->orig_r;
+ struct ma_state l_tmp = *mast->orig_l;
+ struct maple_enode *ancestor = NULL;
+ unsigned char start, end;
+ unsigned char depth = 0;

- old_r = mast->orig_r->node;
- if (mas_next_sibling(mast->orig_r)) {
- mast_rebalance_next(mast, old_r, free);
- return true;
- }
+ r_tmp = *mast->orig_r;
+ l_tmp = *mast->orig_l;
+ do {
+ mas_ascend(mast->orig_r);
+ mas_ascend(mast->orig_l);
+ depth++;
+ if (!ancestor &&
+ (mast->orig_r->node == mast->orig_l->node)) {
+ ancestor = mast->orig_r->node;
+ end = mast->orig_r->offset - 1;
+ start = mast->orig_l->offset + 1;
+ }

- old_l = mast->orig_l->node;
- if (mas_prev_sibling(mast->orig_l)) {
- mast->bn->type = mte_node_type(mast->orig_l->node);
- mast_rebalance_prev(mast, old_l);
- return true;
- }
+ if (mast->orig_r->offset < mas_data_end(mast->orig_r)) {
+ if (!ancestor) {
+ ancestor = mast->orig_r->node;
+ start = 0;
+ }

- return false;
-}
+ mast->orig_r->offset++;
+ do {
+ mas_descend(mast->orig_r);
+ mast->orig_r->offset = 0;
+ depth--;
+ } while (depth);

-static inline int mas_prev_node(struct ma_state *mas, unsigned long min);
-static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
- unsigned long max);
-/*
- * mast_cousin_rebalance_right() - Rebalance from nodes with different parents.
- * Check the right side, then the left. Data is copied into the @mast->bn.
- * @mast: The maple_subtree_state.
- */
-static inline
-bool mast_cousin_rebalance_right(struct maple_subtree_state *mast, bool free)
-{
- struct maple_enode *old_l = mast->orig_l->node;
- struct maple_enode *old_r = mast->orig_r->node;
+ mast_rebalance_next(mast);
+ do {
+ unsigned char l_off = 0;
+ struct maple_enode *child = r_tmp.node;

- MA_STATE(tmp, mast->orig_r->tree, mast->orig_r->index, mast->orig_r->last);
+ mas_ascend(&r_tmp);
+ if (ancestor == r_tmp.node)
+ l_off = start;

- tmp = *mast->orig_r;
- mas_next_node(mast->orig_r, mas_mn(mast->orig_r), ULONG_MAX);
- if (!mas_is_none(mast->orig_r)) {
- mast_rebalance_next(mast, old_r, free);
- return true;
- }
+ if (r_tmp.offset)
+ r_tmp.offset--;

- *mast->orig_r = *mast->orig_l;
- *mast->r = *mast->l;
- mas_prev_node(mast->orig_l, 0);
- if (mas_is_none(mast->orig_l)) {
- /* Making a new root with the contents of mast->bn */
- *mast->orig_l = *mast->orig_r;
- *mast->orig_r = tmp;
- return false;
- }
+ if (l_off < r_tmp.offset)
+ mas_topiary_range(&r_tmp, mast->destroy,
+ l_off, r_tmp.offset);

- mast->orig_l->offset = 0;
- mast_rebalance_prev(mast, old_l);
- return true;
+ if (l_tmp.node != child)
+ mat_add(mast->free, child);
+
+ } while (r_tmp.node != ancestor);
+
+ *mast->orig_l = l_tmp;
+ return true;
+
+ } else if (mast->orig_l->offset != 0) {
+ if (!ancestor) {
+ ancestor = mast->orig_l->node;
+ end = mas_data_end(mast->orig_l);
+ }
+
+ mast->orig_l->offset--;
+ do {
+ mas_descend(mast->orig_l);
+ mast->orig_l->offset =
+ mas_data_end(mast->orig_l);
+ depth--;
+ } while (depth);
+
+ mast_rebalance_prev(mast);
+ do {
+ unsigned char r_off;
+ struct maple_enode *child = l_tmp.node;
+
+ mas_ascend(&l_tmp);
+ if (ancestor == l_tmp.node)
+ r_off = end;
+ else
+ r_off = mas_data_end(&l_tmp);
+
+ if (l_tmp.offset < r_off)
+ l_tmp.offset++;
+
+ if (l_tmp.offset < r_off)
+ mas_topiary_range(&l_tmp, mast->destroy,
+ l_tmp.offset, r_off);
+
+ if (r_tmp.node != child)
+ mat_add(mast->free, child);
+
+ } while (l_tmp.node != ancestor);
+
+ *mast->orig_r = r_tmp;
+ return true;
+ }
+ } while (!mte_is_root(mast->orig_r->node));
+
+ *mast->orig_r = r_tmp;
+ *mast->orig_l = l_tmp;
+ return false;
}

/*
@@ -2462,18 +2538,16 @@ mast_ascend_free(struct maple_subtree_state *mast)
* The node may not contain the value so set slot to ensure all
* of the nodes contents are freed or destroyed.
*/
- if (mast->orig_r->max < mast->orig_r->last)
- mast->orig_r->offset = mas_data_end(mast->orig_r) + 1;
- else {
- wr_mas.type = mte_node_type(mast->orig_r->node);
- mas_wr_node_walk(&wr_mas);
- }
+ wr_mas.type = mte_node_type(mast->orig_r->node);
+ mas_wr_node_walk(&wr_mas);
/* Set up the left side of things */
mast->orig_l->offset = 0;
mast->orig_l->index = mast->l->min;
wr_mas.mas = mast->orig_l;
wr_mas.type = mte_node_type(mast->orig_l->node);
mas_wr_node_walk(&wr_mas);
+
+ mast->bn->type = wr_mas.type;
}

/*
@@ -2881,7 +2955,7 @@ static int mas_spanning_rebalance(struct ma_state *mas,
struct maple_enode *left = NULL, *middle = NULL, *right = NULL;

MA_STATE(l_mas, mas->tree, mas->index, mas->index);
- MA_STATE(r_mas, mas->tree, mas->index, mas->index);
+ MA_STATE(r_mas, mas->tree, mas->index, mas->last);
MA_STATE(m_mas, mas->tree, mas->index, mas->index);
MA_TOPIARY(free, mas->tree);
MA_TOPIARY(destroy, mas->tree);
@@ -2897,14 +2971,9 @@ static int mas_spanning_rebalance(struct ma_state *mas,
mast->destroy = &destroy;
l_mas.node = r_mas.node = m_mas.node = MAS_NONE;
if (!mas_is_root_limits(mast->orig_l) &&
- unlikely(mast->bn->b_end <= mt_min_slots[mast->bn->type])) {
- /*
- * Do not free the current node as it may be freed in a bulk
- * free.
- */
- if (!mast_sibling_rebalance_right(mast, false))
- mast_cousin_rebalance_right(mast, false);
- }
+ unlikely(mast->bn->b_end <= mt_min_slots[mast->bn->type]))
+ mast_spanning_rebalance(mast);
+
mast->orig_l->depth = 0;

/*
@@ -2948,6 +3017,6 @@ static int mas_spanning_rebalance(struct ma_state *mas,

/* Copy anything necessary out of the right node. */
mast_combine_cp_right(mast);
mast_topiary(mast);
mast->orig_l->last = mast->orig_l->max;

@@ -2961,15 +3039,14 @@ static int mas_spanning_rebalance(struct ma_state *mas,
if (mas_is_root_limits(mast->orig_l))
break;

- /* Try to get enough data for the next iteration. */
- if (!mast_sibling_rebalance_right(mast, true))
- if (!mast_cousin_rebalance_right(mast, true))
- break;
+ if (!mast_spanning_rebalance(mast))
+ break;

/* rebalancing from other nodes may require another loop. */
if (!count)
count++;
}
+
l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)),
mte_node_type(mast->orig_l->node));
mast->orig_l->depth++;
@@ -3042,6 +3119,7 @@ static inline int mas_rebalance(struct ma_state *mas,
mast.orig_l = &l_mas;
mast.orig_r = &r_mas;
mast.bn = b_node;
+ mast.bn->type = mte_node_type(mas->node);

l_mas = r_mas = *mas;

@@ -3855,7 +3933,7 @@ static inline int mas_new_root(struct ma_state *mas, void *entry)
return 1;
}
/*
- * mas_spanning_store() - Create a subtree with the store operation completed
+ * mas_wr_spanning_store() - Create a subtree with the store operation completed
* and new nodes where necessary, then place the sub-tree in the actual tree.
* Note that mas is expected to point to the node which caused the store to
* span.
@@ -3941,6 +4019,6 @@ static inline int mas_wr_spanning_store(struct ma_wr_state *wr_mas)
mast.bn = &b_node;
mast.orig_l = &l_mas;
mast.orig_r = &r_mas;
/* Combine l_mas and r_mas and split them up evenly again. */
return mas_spanning_rebalance(mas, &mast, height + 1);
}
@@ -5387,6 +5472,9 @@ static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
node = mas_mn(mas);
slots = ma_slots(node, mte_node_type(mas->node));
next = mas_slot_locked(mas, slots, 0);
+ if (mte_dead_node(next))
+ next = mas_slot_locked(mas, slots, 1);
+
mte_set_node_dead(mas->node);
node->type = mte_node_type(mas->node);
node->piv_parent = prev;
@@ -5394,6 +5482,7 @@ static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
offset = 0;
prev = mas->node;
} while (!mte_is_leaf(next));
+
return slots;
}

@@ -5427,17 +5516,15 @@ static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags,
mas.node = node->piv_parent;
if (mas_mn(&mas) == node)
goto start_slots_free;
+
type = mte_node_type(mas.node);
slots = ma_slots(mte_to_node(mas.node), type);
- if ((offset < mt_slots[type])) {
- struct maple_enode *next = slots[offset];
+ if ((offset < mt_slots[type]) && mte_node_type(slots[offset]) &&
+ mte_to_node(slots[offset])) {
+ struct maple_enode *parent = mas.node;

- if (mte_node_type(next) && mte_to_node(next)) {
- struct maple_enode *parent = mas.node;
-
- mas.node = mas_slot_locked(&mas, slots, offset);
- slots = mas_destroy_descend(&mas, parent, offset);
- }
+ mas.node = mas_slot_locked(&mas, slots, offset);
+ slots = mas_destroy_descend(&mas, parent, offset);
}
node = mas_mn(&mas);
} while (start != mas.node);
--
2.35.1