Re: bpf: shift-out-of-bounds in tnum_rshift()

From: Hao Sun
Date: Wed Nov 01 2023 - 05:53:08 EST


On Fri, Oct 27, 2023 at 7:51 PM Andrii Nakryiko
<andrii.nakryiko@xxxxxxxxx> wrote:
>
> On Wed, Oct 25, 2023 at 10:34 AM Eduard Zingerman <eddyz87@xxxxxxxxx> wrote:
> >
> > On Tue, 2023-10-24 at 14:40 +0200, Hao Sun wrote:
> > > Hi,
> > >
> > > The following program can trigger a shift-out-of-bounds in
> > > tnum_rshift(), called by scalar32_min_max_rsh():
> > >
> > > 0: (bc) w0 = w1
> > > 1: (bf) r2 = r0
> > > 2: (18) r3 = 0xd
> > > 4: (bc) w4 = w0
> > > 5: (bf) r5 = r0
> > > 6: (bf) r7 = r3
> > > 7: (bf) r8 = r4
> > > 8: (2f) r8 *= r5
> > > 9: (cf) r5 s>>= r5
> > > 10: (a6) if w8 < 0xfffffffb goto pc+10
> > > 11: (1f) r7 -= r5
> > > 12: (71) r6 = *(u8 *)(r1 +17)
> > > 13: (5f) r3 &= r8
> > > 14: (74) w2 >>= 30
> > > 15: (1f) r7 -= r5
> > > 16: (5d) if r8 != r6 goto pc+4
> > > 17: (c7) r8 s>>= 5
> > > 18: (cf) r0 s>>= r0
> > > 19: (7f) r0 >>= r0
> > > 20: (7c) w5 >>= w8 # shift-out-of-bounds here
> > > 21: exit
> >
> > Here is a simplified example:
> >
> > SEC("?tp")
> > __success __retval(0)
> > __naked void large_shifts(void)
> > {
> > 	asm volatile ("				\
> > 	call %[bpf_get_prandom_u32];		\n\
> > 	r8 = r0;				\n\
> > 	r6 = r0;				\n\
> > 	r6 &= 0xf;				\n\
> > 	if w8 < 0xffffffff goto +2;		\n\
> > 	if r8 != r6 goto +1;			\n\
> > 	w0 >>= w8; /* shift-out-of-bounds here */	\n\
> > 	exit;					\n\
> > "	:
> > 	: __imm(bpf_get_prandom_u32)
> > 	: __clobber_all);
> > }
> >
>
> With my changes, the verifier correctly derives that r8 != r6 is
> always true, and thus skips w0 >>= w8. But the test itself with

Even with your JNE/JEQ patch applied, a similar issue can still be
triggered.

In the following case, the verifier would shift out of bounds:
// 0: r0 = -2
BPF_MOV64_IMM(BPF_REG_0, -2),
// 1: r0 /= 1
BPF_ALU64_IMM(BPF_DIV, BPF_REG_0, 1),
// 2: r8 = r0
BPF_MOV64_REG(BPF_REG_8, BPF_REG_0),
// 3: if w8 != 0xfffffffe goto+4
BPF_JMP32_IMM(BPF_JNE, BPF_REG_8, 0xfffffffe, 4),
// 4: if r8 s> 0xd goto+3
BPF_JMP_IMM(BPF_JSGT, BPF_REG_8, 0xd, 3),
// 5: r4 = 0x2
BPF_MOV64_IMM(BPF_REG_4, 0x2),
// 6: if r8 s<= r4 goto+1
BPF_JMP_REG(BPF_JSLE, BPF_REG_8, BPF_REG_4, 1),
// 7: w8 s>>= w0 # shift out of bounds here
BPF_ALU32_REG(BPF_ARSH, BPF_REG_8, BPF_REG_0),
// 8: exit
BPF_EXIT_INSN(),
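
To double-check that insn #7 is dead code at runtime, here is a hand
translation of insns 2-6 into C. It is only a sketch, and the function
name reaches_insn7() is made up for illustration. The verifier loses the
constant -2 after the division (R0_w=scalar()), so the function takes r0
as a parameter. To get past insn #3 the low word must be 0xfffffffe; any
64-bit value with that low word is either at least 0xfffffffe (insn #4
jumps) or negative (insn #6 jumps), so it can never return true.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical hand translation of insns 2-6 of the reproducer above.
 * Returns true only if execution would fall through to insn 7 (the ARSH).
 * Every taken branch in the original program jumps to insn 8 (exit). */
static bool reaches_insn7(uint64_t r0)
{
	uint64_t r8 = r0;			/* 2: r8 = r0 */

	if ((uint32_t)r8 != 0xfffffffeu)	/* 3: if w8 != 0xfffffffe goto +4 */
		return false;
	if ((int64_t)r8 > 0xd)			/* 4: if r8 s> 0xd goto +3 */
		return false;

	int64_t r4 = 2;				/* 5: r4 = 0x2 */

	if ((int64_t)r8 <= r4)			/* 6: if r8 s<= r4 goto +1 */
		return false;
	return true;				/* 7: w8 s>>= w0 */
}
```

For the value the program actually computes, (uint64_t)-2, the low word
matches, insn #4 falls through (-2 is not s> 13) and insn #6 jumps, so
insn #7 is never executed.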

-------- Verifier Log --------
func#0 @0
0: R1=ctx(off=0,imm=0) R10=fp0
0: (b7) r0 = -2 ; R0_w=-2
1: (37) r0 /= 1 ; R0_w=scalar()
2: (bf) r8 = r0 ; R0_w=scalar(id=1) R8_w=scalar(id=1)
3: (56) if w8 != 0xfffffffe goto pc+4 ; R8_w=scalar(id=1,smin=-9223372032559808514,smax=9223372036854775806,umin=umin32=4294967294,umax=18446744073709551614,smin32=-2,smax32=-2,umax32=4294967294,var_off=(0xfffffffe; 0xffffffff00000000))
4: (65) if r8 s> 0xd goto pc+3 ; R8_w=scalar(id=1,smin=-9223372032559808514,smax=13,umin=umin32=4294967294,umax=18446744073709551614,smin32=-2,smax32=-2,umax32=4294967294,var_off=(0xfffffffe; 0xffffffff00000000))
5: (b7) r4 = 2 ; R4_w=2
6: (dd) if r8 s<= r4 goto pc+1 ; R4_w=2 R8_w=4294967294
7: (cc) w8 s>>= w0 ; R0=4294967294 R8=4294967295
8: (95) exit

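Here, after insn #6, the register range is incorrect: R8 is reported as
the constant 4294967294, although insn #4 had already bounded it by
smax=13. This seems to be an issue in the JSLE case of
is_branch_taken(). Is it fixed in your patch series?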
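
For reference, here is why an out-of-range shift amount is a hard bug
rather than just imprecision. A tnum (struct tnum in
include/linux/tnum.h) is a value/mask pair, and tnum_rshift() shifts both
halves by the given amount. If an amount at or beyond the operand width
ever reaches such a helper (compare the shift amount R0=4294967294 at
insn #7 in the log above), the C shift itself is undefined behaviour,
which UBSAN reports as shift-out-of-bounds. Below is a minimal sketch of
a guarded shift. It is a simplified illustration, not the kernel's code;
tnum_rshift_checked() and unknown_tnum are made-up names.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Simplified stand-in for struct tnum: bits set in mask are unknown, all
 * other bits are known and equal to the corresponding bits of value. */
struct tnum {
	uint64_t value;
	uint64_t mask;
};

/* Hypothetical: fully unknown tnum, used when a shift cannot be modelled. */
static const struct tnum unknown_tnum = { .value = 0, .mask = UINT64_MAX };

/* Hypothetical guarded tnum logical right shift for a 32- or 64-bit
 * operand. Shifting by a count >= the operand width is undefined in C, so
 * instead of shifting, this sketch returns a fully unknown tnum. */
static struct tnum tnum_rshift_checked(struct tnum a, uint64_t shift,
				       unsigned int width)
{
	if ((width != 32 && width != 64) || shift >= width)
		return unknown_tnum;
	if (width == 32) {
		/* Only the low 32 bits take part in a 32-bit shift. */
		uint32_t v = (uint32_t)a.value;
		uint32_t m = (uint32_t)a.mask;

		return (struct tnum){ .value = v >> shift, .mask = m >> shift };
	}
	return (struct tnum){ .value = a.value >> shift,
			      .mask = a.mask >> shift };
}
```

Giving up to "unknown" is always sound for the verifier's purposes; it
only costs precision on paths that should be unreachable anyway.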

> __retval(0) is not a valid test, so it would be good to construct
> something that will correctly return 0 at runtime (or use some other
> check). So I won't put this test into my patch set and will leave it as
> a follow-up for someone. But here's the log for anyone curious:
>
> VERIFIER LOG:
> =============
> func#0 @0
> 0: R1=ctx(off=0,imm=0) R10=fp0
> ; asm volatile (" \
> 0: (85) call bpf_get_prandom_u32#7 ; R0_w=scalar()
> 1: (bf) r8 = r0 ; R0_w=scalar(id=1) R8_w=scalar(id=1)
> 2: (bf) r6 = r0 ; R0_w=scalar(id=1) R6_w=scalar(id=1)
> 3: (57) r6 &= 15 ; R6_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=15,var_off=(0x0; 0xf))
> 4: (a6) if w8 < 0xffffffff goto pc+2 ; R8_w=scalar(id=1,smin=-9223372032559808513,umin=umin32=4294967295,smin32=-1,smax32=-1,var_off=(0xffffffff; 0xffffffff00000000))
> 5: (5d) if r8 != r6 goto pc+1
> mark_precise: frame0: last_idx 5 first_idx 0 subseq_idx -1
> mark_precise: frame0: regs=r0,r8 stack= before 4: (a6) if w8 < 0xffffffff goto pc+2
> mark_precise: frame0: regs=r0,r8 stack= before 3: (57) r6 &= 15
> mark_precise: frame0: regs=r0,r8 stack= before 2: (bf) r6 = r0
> mark_precise: frame0: regs=r0,r8 stack= before 1: (bf) r8 = r0
> mark_precise: frame0: regs=r0 stack= before 0: (85) call bpf_get_prandom_u32#7
> mark_precise: frame0: last_idx 5 first_idx 0 subseq_idx -1
> mark_precise: frame0: regs=r6 stack= before 4: (a6) if w8 < 0xffffffff goto pc+2
> mark_precise: frame0: regs=r6 stack= before 3: (57) r6 &= 15
> mark_precise: frame0: regs=r6 stack= before 2: (bf) r6 = r0
> mark_precise: frame0: regs=r0 stack= before 1: (bf) r8 = r0
> mark_precise: frame0: regs=r0 stack= before 0: (85) call bpf_get_prandom_u32#7
> 5: R6_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=15,var_off=(0x0; 0xf)) R8_w=scalar(id=1,smin=-9223372032559808513,umin=umin32=4294967295,smin32=-1,smax32=-1,var_off=(0xffffffff; 0xffffffff00000000))
> 7: (95) exit
>
> from 4 to 7: R0=scalar(id=1,smax=9223372036854775806,umax=18446744073709551614,umax32=4294967294)
> R6=scalar(smin=smin32=0,smax=umax=smax32=umax32=15,var_off=(0x0; 0xf))
> R8=scalar(id=1,smax=9223372036854775806,umax=18446744073709551614,umax32=4294967294)
> R10=fp0
> 7: R0=scalar(id=1,smax=9223372036854775806,umax=18446744073709551614,umax32=4294967294)
> R6=scalar(smin=smin32=0,smax=umax=smax32=umax32=15,var_off=(0x0; 0xf))
> R8=scalar(id=1,smax=9223372036854775806,umax=18446744073709551614,umax32=4294967294)
> R10=fp0
> 7: (95) exit
> processed 8 insns (limit 1000000) max_states_per_insn 0 total_states 1 peak_states 1 mark_read 1
> =============
>
> At insn #4, when simulating the FALSE branch, the verifier knows that r6
> is in [0, 15] while w8 is exactly 0xffffffff, so at insn #5 it can tell
> that 0xffffffff can never be equal to a value in the [0, 15] range, and
> thus skips the shift instruction.
>
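
The reasoning above about insn #5 can be checked by hand with the
register states from the log: r8 has var_off=(0xffffffff;
0xffffffff00000000) and umin=4294967295 (no umax is printed, so the
sketch treats it as U64_MAX), while r6 has var_off=(0x0; 0xf) and
umax=15. Below is a minimal sketch of two independent ways to conclude
that r8 != r6 must hold, one from the tnums and one from the unsigned
ranges. The helper names are hypothetical, and this is not how
is_branch_taken() is actually written.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Simplified stand-in for struct tnum: bits set in mask are unknown. */
struct tnum {
	uint64_t value;
	uint64_t mask;
};

/* Hypothetical helper: true if some bit is known in both tnums and the
 * known values differ there, so no concrete pair of values can be equal. */
static bool tnums_never_equal(struct tnum a, struct tnum b)
{
	uint64_t known_in_both = ~(a.mask | b.mask);

	return ((a.value ^ b.value) & known_in_both) != 0;
}

/* Hypothetical helper: true if the unsigned ranges [amin, amax] and
 * [bmin, bmax] do not overlap, so the two values can never be equal. */
static bool ranges_disjoint(uint64_t amin, uint64_t amax,
			    uint64_t bmin, uint64_t bmax)
{
	return amax < bmin || bmax < amin;
}
```

For r8 and r6, bits 4-31 are known in both tnums (1 in r8, 0 in r6), and
r6's upper bound 15 is below r8's lower bound 0xffffffff, so either check
alone is enough to treat "if r8 != r6" as always taken.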