[PATCH v2] riscv, bpf: Adapt bpf trampoline to optimized riscv ftrace framework

From: Pu Lehui
Date: Fri Jul 21 2023 - 06:06:43 EST


From: Pu Lehui <pulehui@xxxxxxxxxx>

Commit 6724a76cff85 ("riscv: ftrace: Reduce the detour code size to
half") halves the detour code size of kernel functions by using the T0
register, and the upcoming DYNAMIC_FTRACE_WITH_DIRECT_CALLS support for
riscv is based on this optimization, so the riscv bpf trampoline needs
to be adapted accordingly. The first change is to reduce the detour
code size of bpf programs, and the second is to handle the return
address after the bpf trampoline has executed. In addition, we need to
construct the frame of the parent function; otherwise one layer will be
missed when unwinding. The related tests have passed.

Signed-off-by: Pu Lehui <pulehui@xxxxxxxxxx>
---
v2:
- Target the riscv tree to avoid the crash caused by a wrong merge
sequence.
- Fix the wrong position of the traced function's FP.

arch/riscv/net/bpf_jit_comp64.c | 153 +++++++++++++++++---------------
1 file changed, 82 insertions(+), 71 deletions(-)

diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c
index c648864c8cd1..0ca4f5c0097c 100644
--- a/arch/riscv/net/bpf_jit_comp64.c
+++ b/arch/riscv/net/bpf_jit_comp64.c
@@ -13,6 +13,8 @@
#include <asm/patch.h>
#include "bpf_jit.h"

+#define RV_FENTRY_NINSNS 2
+
#define RV_REG_TCC RV_REG_A6
#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */

@@ -241,7 +243,7 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
if (!is_tail_call)
emit_mv(RV_REG_A0, RV_REG_A5, ctx);
emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
- is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
+ is_tail_call ? (RV_FENTRY_NINSNS + 1) * 4 : 0, /* skip reserved nops and TCC init */
ctx);
}

@@ -618,32 +620,7 @@ static int add_exception_handler(const struct bpf_insn *insn,
return 0;
}

-static int gen_call_or_nops(void *target, void *ip, u32 *insns)
-{
- s64 rvoff;
- int i, ret;
- struct rv_jit_context ctx;
-
- ctx.ninsns = 0;
- ctx.insns = (u16 *)insns;
-
- if (!target) {
- for (i = 0; i < 4; i++)
- emit(rv_nop(), &ctx);
- return 0;
- }
-
- rvoff = (s64)(target - (ip + 4));
- emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
- ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
- if (ret)
- return ret;
- emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
-
- return 0;
-}
-
-static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
+static int gen_jump_or_nops(void *target, void *ip, u32 *insns, bool is_call)
{
s64 rvoff;
struct rv_jit_context ctx;
@@ -658,38 +635,35 @@ static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
}

rvoff = (s64)(target - ip);
- return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
+ return emit_jump_and_link(is_call ? RV_REG_T0 : RV_REG_ZERO, rvoff, false, &ctx);
}

int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
void *old_addr, void *new_addr)
{
- u32 old_insns[4], new_insns[4];
+ u32 old_insns[RV_FENTRY_NINSNS], new_insns[RV_FENTRY_NINSNS];
bool is_call = poke_type == BPF_MOD_CALL;
- int (*gen_insns)(void *target, void *ip, u32 *insns);
- int ninsns = is_call ? 4 : 2;
int ret;

- if (!is_bpf_text_address((unsigned long)ip))
+ if (!is_kernel_text((unsigned long)ip) &&
+ !is_bpf_text_address((unsigned long)ip))
return -ENOTSUPP;

- gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
-
- ret = gen_insns(old_addr, ip, old_insns);
+ ret = gen_jump_or_nops(old_addr, ip, old_insns, is_call);
if (ret)
return ret;

- if (memcmp(ip, old_insns, ninsns * 4))
+ if (memcmp(ip, old_insns, RV_FENTRY_NINSNS * 4))
return -EFAULT;

- ret = gen_insns(new_addr, ip, new_insns);
+ ret = gen_jump_or_nops(new_addr, ip, new_insns, is_call);
if (ret)
return ret;

cpus_read_lock();
mutex_lock(&text_mutex);
- if (memcmp(ip, new_insns, ninsns * 4))
- ret = patch_text(ip, new_insns, ninsns);
+ if (memcmp(ip, new_insns, RV_FENTRY_NINSNS * 4))
+ ret = patch_text(ip, new_insns, RV_FENTRY_NINSNS);
mutex_unlock(&text_mutex);
cpus_read_unlock();

@@ -787,8 +761,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
int i, ret, offset;
int *branches_off = NULL;
int stack_size = 0, nregs = m->nr_args;
- int retaddr_off, fp_off, retval_off, args_off;
- int nregs_off, ip_off, run_ctx_off, sreg_off;
+ int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
@@ -796,13 +769,27 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
bool save_ret;
u32 insn;

- /* Generated trampoline stack layout:
+ /* Two types of generated trampoline stack layout:
+ *
+ * 1. trampoline called from function entry
+ * --------------------------------------
+ * FP + 8 [ RA to parent func ] return address to parent
+ * function
+ * FP + 0 [ FP of parent func ] frame pointer of parent
+ * function
+ * FP - 8 [ T0 to traced func ] return address of traced
+ * function
+ * FP - 16 [ FP of traced func ] frame pointer of traced
+ * function
+ * --------------------------------------
*
- * FP - 8 [ RA of parent func ] return address of parent
+ * 2. trampoline called directly
+ * --------------------------------------
+ * FP - 8 [ RA to caller func ] return address to caller
* function
- * FP - retaddr_off [ RA of traced func ] return address of traced
+ * FP - 16 [ FP of caller func ] frame pointer of caller
* function
- * FP - fp_off [ FP of parent func ]
+ * --------------------------------------
*
* FP - retval_off [ return value ] BPF_TRAMP_F_CALL_ORIG or
* BPF_TRAMP_F_RET_FENTRY_RET
@@ -833,14 +820,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (nregs > 8)
return -ENOTSUPP;

- /* room for parent function return address */
- stack_size += 8;
-
- stack_size += 8;
- retaddr_off = stack_size;
-
- stack_size += 8;
- fp_off = stack_size;
+ /* room for trampoline frame to store return address and frame pointer */
+ stack_size += 16;

save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
if (save_ret) {
@@ -867,12 +848,29 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

stack_size = round_up(stack_size, 16);

- emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
-
- emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
- emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
-
- emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ if (func_addr) {
+ /* For the trampoline called from function entry,
+ * the frame of traced function and the frame of
+ * trampoline need to be considered.
+ */
+ emit_addi(RV_REG_SP, RV_REG_SP, -16, ctx);
+ emit_sd(RV_REG_SP, 8, RV_REG_RA, ctx);
+ emit_sd(RV_REG_SP, 0, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, 16, ctx);
+
+ emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+ emit_sd(RV_REG_SP, stack_size - 8, RV_REG_T0, ctx);
+ emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ } else {
+ /* For the trampoline called directly, just handle
+ * the frame of trampoline.
+ */
+ emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
+ emit_sd(RV_REG_SP, stack_size - 8, RV_REG_RA, ctx);
+ emit_sd(RV_REG_SP, stack_size - 16, RV_REG_FP, ctx);
+ emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
+ }

/* callee saved register S1 to pass start time */
emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
@@ -890,7 +888,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

/* skip to actual body of traced function */
if (flags & BPF_TRAMP_F_SKIP_FRAME)
- orig_call += 16;
+ orig_call += RV_FENTRY_NINSNS * 4;

if (flags & BPF_TRAMP_F_CALL_ORIG) {
emit_imm(RV_REG_A0, (const s64)im, ctx);
@@ -967,17 +965,30 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,

emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);

- if (flags & BPF_TRAMP_F_SKIP_FRAME)
- /* return address of parent function */
- emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
- else
- /* return address of traced function */
- emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
+ if (func_addr) {
+ /* trampoline called from function entry */
+ emit_ld(RV_REG_T0, stack_size - 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);

- emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
- emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+ emit_ld(RV_REG_RA, 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, 0, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, 16, ctx);

- emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ if (flags & BPF_TRAMP_F_SKIP_FRAME)
+ /* return to parent function */
+ emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ else
+ /* return to traced function */
+ emit_jalr(RV_REG_ZERO, RV_REG_T0, 0, ctx);
+ } else {
+ /* trampoline called directly */
+ emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
+ emit_ld(RV_REG_FP, stack_size - 16, RV_REG_SP, ctx);
+ emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
+
+ emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
+ }

ret = ctx->ninsns;
out:
@@ -1691,8 +1702,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)

store_offset = stack_adjust - 8;

- /* reserve 4 nop insns */
- for (i = 0; i < 4; i++)
+ /* nops reserved for auipc+jalr pair */
+ for (i = 0; i < RV_FENTRY_NINSNS; i++)
emit(rv_nop(), ctx);

/* First instruction is always setting the tail-call-counter
--
2.25.1