[PATCH RFC 5/8] riscv/kaslr: support sparse memory model

From: Zong Li
Date: Tue Mar 24 2020 - 03:31:10 EST


In the sparse memory model, we first select a random memory node, then
pick a random offset within that node. The flat memory model has only
one memory node, so it degenerates to picking an offset in that single
node.
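
For example, given a device tree with two memory nodes (an
illustrative layout, not taken from this patch):

	memory@80000000 {
		device_type = "memory";
		reg = <0x0 0x80000000 0x0 0x40000000>;
	};
	memory@100000000 {
		device_type = "memory";
		reg = <0x1 0x00000000 0x0 0x40000000>;
	};

get_memory_nodes_num() counts two candidates (assuming both are big
enough), and get_legal_offset() then roughly does:

	mem_nodes = get_memory_nodes_num(kernel_size);
	idx = random % mem_nodes;	/* random candidate node */
	kaslr_get_mem_info(&mem_start, &mem_size, kernel_size, idx);
	random_size = min_t(u64, mem_size - (kernel_size * 2),
			    SZ_1G - (kernel_size * 2));
	index = random % (random_size / SZ_2M);	/* random 2M slot */
	offset = get_legal_offset_in_node(index, random_size / SZ_2M,
					  mem_start, kernel_size);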

Signed-off-by: Zong Li <zong.li@xxxxxxxxxx>
---
arch/riscv/kernel/kaslr.c | 139 ++++++++++++++++++++++++++++----------
1 file changed, 105 insertions(+), 34 deletions(-)

diff --git a/arch/riscv/kernel/kaslr.c b/arch/riscv/kernel/kaslr.c
index 9ec2b608eb7f..59001d6fdfc3 100644
--- a/arch/riscv/kernel/kaslr.c
+++ b/arch/riscv/kernel/kaslr.c
@@ -55,8 +55,9 @@ static __init int get_node_addr_size_cells(const char *path, int *addr_cell,

static __init void kaslr_get_mem_info(uintptr_t *mem_start,
- uintptr_t *mem_size)
+ uintptr_t *mem_size,
+ uintptr_t kernel_size, int find_index)
{
- int node, root, addr_cells, size_cells;
+ int node, root, addr_cells, size_cells, idx = 0;
u64 base, size;

/* Get root node's address cells and size cells. */
@@ -81,14 +82,56 @@ static __init void kaslr_get_mem_info(uintptr_t *mem_start,
reg = get_reg_address(addr_cells, reg, &base);
reg = get_reg_address(size_cells, reg, &size);

- *mem_start = base;
- *mem_size = size;
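+ /*
+ * Skip nodes that cannot hold the kernel twice over; keep this
+ * filter in sync with get_memory_nodes_num().
+ */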
+ if (size <= (kernel_size * 2))
+ continue;

- break;
+ if (idx == find_index) {
+ *mem_start = base;
+ *mem_size = size;
+ break;
+ }
+
+ idx++;
}
}
}

+static __init int get_memory_nodes_num(uintptr_t kernel_size)
+{
+ int node, root, addr_cells, size_cells, total_nodes = 0;
+ u64 base, size;
+
+ /* Get root node's address cells and size cells. */
+ root = get_node_addr_size_cells("/", &addr_cells, &size_cells);
+ if (root < 0)
+ return 0;
+
+ /* Count the memory nodes that are usable for randomization. */
+ fdt_for_each_subnode(node, dtb_early_va, root) {
+ const char *dev_type;
+ const u32 *reg;
+
+ dev_type = fdt_getprop(dtb_early_va, node, "device_type", NULL);
+ if (!dev_type)
+ continue;
+
+ if (!strcmp(dev_type, "memory")) {
+ reg = fdt_getprop(dtb_early_va, node, "reg", NULL);
+ if (!reg)
+ return 0;
+
+ reg = get_reg_address(addr_cells, reg, &base);
+ reg = get_reg_address(size_cells, reg, &size);
+
+ /* Only count nodes big enough that the kernel cannot overlap itself. */
+ if (size > kernel_size * 2)
+ total_nodes++;
+ }
+ }
+
+ return total_nodes;
+}
+
/* Return a default seed if there is no HW generator. */
static u64 kaslr_default_seed = ULL(-1);
static __init u64 kaslr_get_seed(void)
@@ -198,10 +241,11 @@ static __init bool has_regions_overlapping(uintptr_t start_addr,
return false;
}

-static inline __init unsigned long get_legal_offset(int random_index,
- int max_index,
- uintptr_t mem_start,
- uintptr_t kernel_size)
+static inline __init unsigned long get_legal_offset_in_node(int random_index,
+ int max_index,
+ uintptr_t mem_start,
+ uintptr_t kernel_size)
{
uintptr_t start_addr, end_addr;
int idx, stop_idx;
@@ -214,7 +258,8 @@ static inline __init unsigned long get_legal_offset(int random_index,

/* Check overlap to other regions. */
if (!has_regions_overlapping(start_addr, end_addr))
- return idx * SZ_2M + kernel_size;
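+ /*
+ * The offset is applied on top of __pa(PAGE_OFFSET), so shift
+ * the node-local result by the node's start address.
+ */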
+ return idx * SZ_2M + kernel_size + (mem_start -
+ __pa(PAGE_OFFSET));

if (idx-- < 0)
idx = max_index;
@@ -224,6 +269,56 @@ static inline __init unsigned long get_legal_offset(int random_index,
return 0;
}

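+/* Memory below the kernel's boot-time physical base is treated as reserved. */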
+#define MEM_RESERVE_START __pa(PAGE_OFFSET)
+static inline __init unsigned long get_legal_offset(u64 random,
+ uintptr_t kernel_size)
+{
+ int mem_nodes, idx, stop_idx, index;
+ uintptr_t mem_start = 0, mem_size = 0, random_size, ret;
+
+ mem_nodes = get_memory_nodes_num(kernel_size);
+ if (!mem_nodes)
+ return 0;
+
+ idx = stop_idx = random % mem_nodes;
+
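+ /* Walk the candidate nodes, starting from a randomly chosen one. */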
+ do {
+ kaslr_get_mem_info(&mem_start, &mem_size, kernel_size, idx);
+
+ if (!mem_size)
+ return 0;
+
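+ /* Candidate ranges must start at or above the reserved boundary. */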
+ if (mem_start < MEM_RESERVE_START) {
+ mem_size -= MEM_RESERVE_START - mem_start;
+ mem_start = MEM_RESERVE_START;
+ }
+
+ /*
+ * Limit the randomization range to 1G so that the early page
+ * table phase can keep using early_pmd/early_pte.
+ */
+ random_size = min_t(u64,
+ mem_size - (kernel_size * 2),
+ SZ_1G - (kernel_size * 2));
+
+ if (random_size < SZ_2M)
+ return 0;
+
+ /* Index of a 2M block within the whole available region. */
+ index = random % (random_size / SZ_2M);
+
+ ret = get_legal_offset_in_node(index, random_size / SZ_2M,
+ mem_start, kernel_size);
+ if (ret)
+ break;
+
+ if (--idx < 0)
+ idx = mem_nodes - 1;
+
+ } while (idx != stop_idx);
+
+ return ret;
+}
+
static inline __init u64 rotate_xor(u64 hash, const void *area, size_t size)
{
size_t i;
@@ -238,12 +333,9 @@ static inline __init u64 rotate_xor(u64 hash, const void *area, size_t size)
return hash;
}

-#define MEM_RESERVE_START __pa(PAGE_OFFSET)
static __init uintptr_t get_random_offset(u64 seed, uintptr_t kernel_size)
{
- uintptr_t mem_start = 0, mem_size= 0, random_size;
uintptr_t kernel_size_align = round_up(kernel_size, SZ_2M);
- int index;
u64 random = 0;
cycles_t time_base;

@@ -261,28 +353,7 @@ static __init uintptr_t get_random_offset(u64 seed, uintptr_t kernel_size)
if (seed)
random = rotate_xor(random, &seed, sizeof(seed));

- kaslr_get_mem_info(&mem_start, &mem_size);
- if (!mem_size)
- return 0;
-
- if (mem_start < MEM_RESERVE_START) {
- mem_size -= MEM_RESERVE_START - mem_start;
- mem_start = MEM_RESERVE_START;
- }
-
- /*
- * Limit randomization range within 1G, so we can exploit
- * early_pmd/early_pte during early page table phase.
- */
- random_size = min_t(u64,
- mem_size - (kernel_size_align * 2),
- SZ_1G - (kernel_size_align * 2));
-
- /* The index of 2M block in whole avaliable region */
- index = random % (random_size / SZ_2M);
-
- return get_legal_offset(index, random_size / SZ_2M,
- mem_start, kernel_size_align);
+ return get_legal_offset(random, kernel_size_align);
}

uintptr_t __init kaslr_early_init(void)
--
2.25.1