Re: [Bug Report] bpf: incorrectly pruning runtime execution path

From: Hao Sun
Date: Wed Dec 13 2023 - 05:25:26 EST


On Wed, Dec 13, 2023 at 1:51 AM Andrii Nakryiko
<andrii.nakryiko@xxxxxxxxx> wrote:
>
> On Mon, Dec 11, 2023 at 7:31 AM Hao Sun <sunhao.th@xxxxxxxxx> wrote:
> >
> > Hi,
> >
> > The verifier incorrectly prunes a path expected to be executed at
> > runtime. In the following program, the execution path is:
> > from 6 to 8 (taken) -> from 11 to 15 (taken) -> from 18 to 22
> > (taken) -> from 26 to 27 (fall-through) -> from 29 to 30
> > (fall-through)
> > The verifier prunes the checking path at #26, skipping the actual
> > execution path.
> >
> > 0: (18) r2 = 0x1a000000be
> > 2: (bf) r5 = r1
> > 3: (bf) r8 = r2
> > 4: (bc) w4 = w5
> > 5: (85) call bpf_get_current_cgroup_id#680112
> > 6: (36) if w8 >= 0x69 goto pc+1
> > 7: (95) exit
> > 8: (18) r4 = 0x52
> > 10: (84) w4 = -w4
> > 11: (45) if r0 & 0xfffffffe goto pc+3
> > 12: (1f) r8 -= r4
> > 13: (0f) r0 += r0
> > 14: (2f) r4 *= r4
> > 15: (18) r3 = 0x1f00000034
> > 17: (c4) w4 s>>= 29
> > 18: (56) if w8 != 0xf goto pc+3
> > 19: (d7) r3 = bswap32 r3
> > 20: (18) r2 = 0x1c
> > 22: (67) r4 <<= 2
> > 23: (bf) r5 = r8
> > 24: (18) r2 = 0x4
> > 26: (7e) if w8 s>= w0 goto pc+5
> > 27: (4f) r8 |= r8
> > 28: (0f) r8 += r8
> > 29: (d6) if w5 s<= 0x1d goto pc+2
> > 30: (18) r0 = 0x4 ; incorrectly pruned here
> > 32: (95) exit
> >
> > -------- Verifier Log --------
> > func#0 @0
> > 0: R1=ctx() R10=fp0
> > 0: (18) r2 = 0x1a000000be ; R2_w=0x1a000000be
> > 2: (bf) r5 = r1 ; R1=ctx() R5_w=ctx()
> > 3: (bf) r8 = r2 ; R2_w=0x1a000000be R8_w=0x1a000000be
> > 4: (bc) w4 = w5 ;
> > R4_w=scalar(smin=0,smax=umax=0xffffffff,var_off=(0x0; 0xffffffff))
> > R5_w=ctx()
> > 5: (85) call bpf_get_current_cgroup_id#80 ; R0_w=scalar()
> > 6: (36) if w8 >= 0x69 goto pc+1
> > mark_precise: frame0: last_idx 6 first_idx 0 subseq_idx -1
> > mark_precise: frame0: regs=r8 stack= before 5: (85) call
> > bpf_get_current_cgroup_id#80
> > mark_precise: frame0: regs=r8 stack= before 4: (bc) w4 = w5
> > mark_precise: frame0: regs=r8 stack= before 3: (bf) r8 = r2
> > mark_precise: frame0: regs=r2 stack= before 2: (bf) r5 = r1
> > mark_precise: frame0: regs=r2 stack= before 0: (18) r2 = 0x1a000000be
> > 6: R8_w=0x1a000000be
> > 8: (18) r4 = 0x52 ; R4_w=82
> > 10: (84) w4 = -w4 ; R4=scalar()
> > 11: (45) if r0 & 0xfffffffe goto pc+3 ;
> > R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=1,var_off=(0x0; 0x1))
> > 12: (1f) r8 -= r4 ; R4=scalar() R8_w=scalar()
> > 13: (0f) r0 += r0 ;
> > R0_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3))
> > 14: (2f) r4 *= r4 ; R4_w=scalar()
> > 15: (18) r3 = 0x1f00000034 ; R3_w=0x1f00000034
> > 17: (c4) w4 s>>= 29 ;
> > R4_w=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff))
> > 18: (56) if w8 != 0xf goto pc+3 ;
> > R8_w=scalar(smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000))
> > 19: (d7) r3 = bswap32 r3 ; R3_w=scalar()
> > 20: (18) r2 = 0x1c ; R2=28
> > 22: (67) r4 <<= 2 ;
> > R4_w=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc))
> > 23: (bf) r5 = r8 ;
> > R5_w=scalar(id=1,smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000))
> > R8=scalar(id=1,smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000))
> > 24: (18) r2 = 0x4 ; R2_w=4
> > 26: (7e) if w8 s>= w0 goto pc+5
>
> so here w8=15 and w0=[0,2], always taken, right?
>
> > mark_precise: frame0: last_idx 26 first_idx 22 subseq_idx -1
> > mark_precise: frame0: regs=r5,r8 stack= before 24: (18) r2 = 0x4
> > mark_precise: frame0: regs=r5,r8 stack= before 23: (bf) r5 = r8
> > mark_precise: frame0: regs=r8 stack= before 22: (67) r4 <<= 2
> > mark_precise: frame0: parent state regs=r8 stack=:
> > R0_rw=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R2_w=28 R3_w=scalar()
> > R4_rw=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_rw=Pscalar(smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000)) R10=fp0
> > mark_precise: frame0: last_idx 20 first_idx 11 subseq_idx 22
> > mark_precise: frame0: regs=r8 stack= before 20: (18) r2 = 0x1c
> > mark_precise: frame0: regs=r8 stack= before 19: (d7) r3 = bswap32 r3
> > mark_precise: frame0: regs=r8 stack= before 18: (56) if w8 != 0xf goto pc+3
> > mark_precise: frame0: regs=r8 stack= before 17: (c4) w4 s>>= 29
> > mark_precise: frame0: regs=r8 stack= before 15: (18) r3 = 0x1f00000034
> > mark_precise: frame0: regs=r8 stack= before 14: (2f) r4 *= r4
> > mark_precise: frame0: regs=r8 stack= before 13: (0f) r0 += r0
> > mark_precise: frame0: regs=r8 stack= before 12: (1f) r8 -= r4
> > mark_precise: frame0: regs=r4,r8 stack= before 11: (45) if r0 &
> > 0xfffffffe goto pc+3
> > mark_precise: frame0: parent state regs=r4,r8 stack=: R0_rw=scalar()
> > R4_rw=Pscalar() R8_rw=P0x1a000000be R10=fp0
> > mark_precise: frame0: last_idx 10 first_idx 0 subseq_idx 11
> > mark_precise: frame0: regs=r4,r8 stack= before 10: (84) w4 = -w4
> > mark_precise: frame0: regs=r4,r8 stack= before 8: (18) r4 = 0x52
> > mark_precise: frame0: regs=r8 stack= before 6: (36) if w8 >= 0x69 goto pc+1
> > mark_precise: frame0: regs=r8 stack= before 5: (85) call
> > bpf_get_current_cgroup_id#80
> > mark_precise: frame0: regs=r8 stack= before 4: (bc) w4 = w5
> > mark_precise: frame0: regs=r8 stack= before 3: (bf) r8 = r2
> > mark_precise: frame0: regs=r2 stack= before 2: (bf) r5 = r1
> > mark_precise: frame0: regs=r2 stack= before 0: (18) r2 = 0x1a000000be
> > mark_precise: frame0: last_idx 26 first_idx 22 subseq_idx -1
> > mark_precise: frame0: regs=r0 stack= before 24: (18) r2 = 0x4
> > mark_precise: frame0: regs=r0 stack= before 23: (bf) r5 = r8
> > mark_precise: frame0: regs=r0 stack= before 22: (67) r4 <<= 2
> > mark_precise: frame0: parent state regs=r0 stack=:
> > R0_rw=Pscalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R2_w=28 R3_w=scalar()
> > R4_rw=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_rw=Pscalar(smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000)) R10=fp0
> > mark_precise: frame0: last_idx 20 first_idx 11 subseq_idx 22
> > mark_precise: frame0: regs=r0 stack= before 20: (18) r2 = 0x1c
> > mark_precise: frame0: regs=r0 stack= before 19: (d7) r3 = bswap32 r3
> > mark_precise: frame0: regs=r0 stack= before 18: (56) if w8 != 0xf goto pc+3
> > mark_precise: frame0: regs=r0 stack= before 17: (c4) w4 s>>= 29
> > mark_precise: frame0: regs=r0 stack= before 15: (18) r3 = 0x1f00000034
> > mark_precise: frame0: regs=r0 stack= before 14: (2f) r4 *= r4
> > mark_precise: frame0: regs=r0 stack= before 13: (0f) r0 += r0
> > mark_precise: frame0: regs=r0 stack= before 12: (1f) r8 -= r4
> > mark_precise: frame0: regs=r0 stack= before 11: (45) if r0 &
> > 0xfffffffe goto pc+3
> > mark_precise: frame0: parent state regs=r0 stack=: R0_rw=Pscalar()
> > R4_rw=Pscalar() R8_rw=P0x1a000000be R10=fp0
> > mark_precise: frame0: last_idx 10 first_idx 0 subseq_idx 11
> > mark_precise: frame0: regs=r0 stack= before 10: (84) w4 = -w4
> > mark_precise: frame0: regs=r0 stack= before 8: (18) r4 = 0x52
> > mark_precise: frame0: regs=r0 stack= before 6: (36) if w8 >= 0x69 goto pc+1
> > mark_precise: frame0: regs=r0 stack= before 5: (85) call
> > bpf_get_current_cgroup_id#80
> > 26: R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R8=scalar(id=1,smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> > 0xffffffff00000000))
> > 32: (95) exit
> >
> > from 18 to 22: R0_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R3_w=0x1f00000034
> > R4_w=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_w=scalar() R10=fp0
> > 22: R0_w=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R3_w=0x1f00000034
> > R4_w=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_w=scalar() R10=fp0
> > 22: (67) r4 <<= 2 ;
> > R4_w=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc))
> > 23: (bf) r5 = r8 ; R5_w=scalar(id=2) R8_w=scalar(id=2)
> > 24: (18) r2 = 0x4 ; R2=4
> > 26: (7e) if w8 s>= w0 goto pc+5 ;
> > R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0; 0x3))
> > R8=scalar(id=2,smax32=1)
>
> we didn't prune here, assuming w8 < w0, so w8=w5 is at most 1 (because
> r0 is [0, 2])
>
> > 27: (4f) r8 |= r8 ; R8_w=scalar()
>
> here r5 and r8 are disassociated
>
> > 28: (0f) r8 += r8 ; R8_w=scalar()
> > 29: (d6) if w5 s<= 0x1d goto pc+2
>
> w5 is at most 1 (signed), so this is always true, so we jump to exit;
> 30: is still never visited
>
> > mark_precise: frame0: last_idx 29 first_idx 26 subseq_idx -1
> > mark_precise: frame0: regs=r5 stack= before 28: (0f) r8 += r8
> > mark_precise: frame0: regs=r5 stack= before 27: (4f) r8 |= r8
> > mark_precise: frame0: regs=r5 stack= before 26: (7e) if w8 s>= w0 goto pc+5
> > mark_precise: frame0: parent state regs=r5 stack=:
> > R0_rw=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R2_w=4 R3_w=0x1f00000034
> > R4_w=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc)) R5_rw=Pscalar(id=2) R8_rw=scalar(id=2) R10=fp0
> > mark_precise: frame0: last_idx 24 first_idx 11 subseq_idx 26
> > mark_precise: frame0: regs=r5,r8 stack= before 24: (18) r2 = 0x4
> > mark_precise: frame0: regs=r5,r8 stack= before 23: (bf) r5 = r8
> > mark_precise: frame0: regs=r8 stack= before 22: (67) r4 <<= 2
> > mark_precise: frame0: regs=r8 stack= before 18: (56) if w8 != 0xf goto pc+3
> > mark_precise: frame0: regs=r8 stack= before 17: (c4) w4 s>>= 29
> > mark_precise: frame0: regs=r8 stack= before 15: (18) r3 = 0x1f00000034
> > mark_precise: frame0: regs=r8 stack= before 14: (2f) r4 *= r4
> > mark_precise: frame0: regs=r8 stack= before 13: (0f) r0 += r0
> > mark_precise: frame0: regs=r8 stack= before 12: (1f) r8 -= r4
> > mark_precise: frame0: regs=r4,r8 stack= before 11: (45) if r0 &
> > 0xfffffffe goto pc+3
> > mark_precise: frame0: parent state regs= stack=: R0_rw=Pscalar()
> > R4_rw=Pscalar() R8_rw=P0x1a000000be R10=fp0
> > 29: R5=scalar(id=2,smax32=1)
> > 32: (95) exit
> >
> > from 26 to 32: R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R2=4 R3=0x1f00000034
> > R4=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc)) R5=scalar(id=2,smax=0x7fffffff7fffffff,umax=0xffffffff7fffffff,smin32=0,umax32=0x7fffffff,var_off=(0x0;
> > 0xffffffff7fffffff))
> > R8=scalar(id=2,smax=0x7fffffff7fffffff,umax=0xffffffff7fffffff,smin32=0,umax32=0x7fffffff,var_off=(0x0;
> > 0xffffffff7fffffff)) R10=fp0
> > 32: R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> > 0x3)) R2=4 R3=0x1f00000034
> > R4=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc)) R5=scalar(id=2,smax=0x7fffffff7fffffff,umax=0xffffffff7fffffff,smin32=0,umax32=0x7fffffff,var_off=(0x0;
> > 0xffffffff7fffffff))
> > R8=scalar(id=2,smax=0x7fffffff7fffffff,umax=0xffffffff7fffffff,smin32=0,umax32=0x7fffffff,var_off=(0x0;
> > 0xffffffff7fffffff)) R10=fp0
> > 32: (95) exit
>
> here we also skipped 30:, and w8 was in [0,0x7fffffff] range, r0 is
> [0,2], but its precision doesn't matter as we didn't do any pruning
>
> NOTE this one.
>
> >
> > from 11 to 15: R0=scalar() R4=scalar() R8=0x1a000000be R10=fp0
> > 15: R0=scalar() R4=scalar() R8=0x1a000000be R10=fp0
> > 15: (18) r3 = 0x1f00000034 ; R3_w=0x1f00000034
> > 17: (c4) w4 s>>= 29 ;
> > R4=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff))
> > 18: (56) if w8 != 0xf goto pc+3
>
> known true, always taken
>
> > mark_precise: frame0: last_idx 18 first_idx 18 subseq_idx -1
> > mark_precise: frame0: parent state regs=r8 stack=: R0=scalar()
> > R3_w=0x1f00000034
> > R4_w=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_r=P0x1a000000be R10=fp0
> > mark_precise: frame0: last_idx 17 first_idx 11 subseq_idx 18
> > mark_precise: frame0: regs=r8 stack= before 17: (c4) w4 s>>= 29
> > mark_precise: frame0: regs=r8 stack= before 15: (18) r3 = 0x1f00000034
> > mark_precise: frame0: regs=r8 stack= before 11: (45) if r0 &
> > 0xfffffffe goto pc+3
> > mark_precise: frame0: parent state regs= stack=: R0_rw=Pscalar()
> > R4_rw=Pscalar() R8_rw=P0x1a000000be R10=fp0
> > 18: R8=0x1a000000be
> > 22: (67) r4 <<= 2 ;
> > R4_w=scalar(smin=0,smax=umax=0x3fffffffc,smax32=0x7ffffffc,umax32=0xfffffffc,var_off=(0x0;
> > 0x3fffffffc))
> > 23: (bf) r5 = r8 ; R5_w=0x1a000000be R8=0x1a000000be
> > 24: (18) r2 = 0x4
> > frame 0: propagating r5
> > mark_precise: frame0: last_idx 26 first_idx 18 subseq_idx -1
> > mark_precise: frame0: regs=r5 stack= before 24: (18) r2 = 0x4
> > mark_precise: frame0: regs=r5 stack= before 23: (bf) r5 = r8
> > mark_precise: frame0: regs=r8 stack= before 22: (67) r4 <<= 2
> > mark_precise: frame0: regs=r8 stack= before 18: (56) if w8 != 0xf goto pc+3
> > mark_precise: frame0: parent state regs= stack=: R0_r=scalar()
> > R3_w=0x1f00000034
> > R4_rw=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> > 0xffffffff)) R8_r=P0x1a000000be R10=fp0
> > 26: safe
>
> and here we basically need to evaluate
>
> if w8 s>= w0 goto pc+5
>
> w8 is precisely known to be 0x000000be, while w0 is unknown. Now go
> back to "NOTE this one" mark above. w8 is inside [0, 0xffffffff]
> range, right? And w0 is unknown, while up in "NOTE this one" w0 didn't
> matter, so it stayed imprecise. This is a match. It seems correct.
>

Thanks for the detailed explanation.
This is the exact point where I got confused. w8 and w5 share the same id
at #26, so the range of w5 is also updated according to w0. Even though
r5 and r8 are disassociated later, w0 actually matters.

If the verifier did not prune at this point, w5 would be unknown
afterwards and #30 would be explored. The branch "from 29 to 30" is the
path actually taken at runtime, see below.
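
To make the dependency on w0 concrete, here is a rough Python sketch of the
bound involved. It is not verifier code; the helper names are made up for
illustration, and it only models the signed 32-bit upper bound that the
fall-through of #26 gives to w8 (and to w5, which shares its id).

```python
S32_MAX = 0x7FFFFFFF

def smax32_after_fallthrough(w0_smax32):
    """Upper bound for w8 (and w5, same id) when "if w8 s>= w0"
    falls through, i.e. when w8 s< w0 holds."""
    return w0_smax32 - 1

def insn29_always_taken(w5_smax32):
    """#29 is "if w5 s<= 0x1d goto pc+2": it is always taken
    iff smax32(w5) <= 0x1d."""
    return w5_smax32 <= 0x1D

# State explored first: w0 is in [0, 2], so w5 gets smax32=1 and
# #30 looks unreachable.
explored = smax32_after_fallthrough(2)
# State pruned at #26: w0 is unknown, so no useful bound follows.
pruned = smax32_after_fallthrough(S32_MAX)

print(explored, insn29_always_taken(explored))    # 1 True
print(hex(pruned), insn29_always_taken(pruned))   # 0x7ffffffe False
```

The first result matches R5=scalar(id=2,smax32=1) in the log. In the pruned
state w5 is the constant 0xbe, which is greater than 0x1d, so #29 is not
taken there.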

>
[...]
>
> r0 is marked precise at 26:
>
> mark_precise: frame0: last_idx 26 first_idx 22 subseq_idx -1
> mark_precise: frame0: regs=r0 stack= before 24: (18) r2 = 0x4
> mark_precise: frame0: regs=r0 stack= before 23: (bf) r5 = r8
> mark_precise: frame0: regs=r0 stack= before 22: (67) r4 <<= 2
> mark_precise: frame0: parent state regs=r0 stack=:
> R0_rw=Pscalar(smin=smin32=0,smax=umax=smax32=umax32=2,var_off=(0x0;
> 0x3)) R2_w=28 R3_w=scalar()
> R4_rw=scalar(smin=0,smax=umax=0xffffffff,smin32=-4,smax32=3,var_off=(0x0;
> 0xffffffff)) R8_rw=Pscalar(smin=0x800000000000000f,smax=0x7fffffff0000000f,umin=smin32=umin32=15,umax=0xffffffff0000000f,smax32=umax32=15,var_off=(0xf;
> 0xffffffff00000000)) R10=fp0
>
[...]
> > However, seems it's not, so the next time when the verifier checks
> > #26, R0 is incorrectly ignored.
> > We have mark_precise_scalar_ids(), but it's called before calculating
> > the mask once.
>
> I'm not following the remark about mark_precise_scalar_ids(). That
> works fine, but has nothing to do with r0. mark_precise_scalar_ids()
> identifies that r8 and r5 are linked together, and you can see from
> the log that we mark both r5 and r8 as precise.
>
> > I investigated for quite a while, but mark_chain_precision() is really
> > hard to follow.
> >
> > Here is a reduced C repro, maybe someone else can shed some light on this.
> > C repro: https://pastebin.com/raw/chrshhGQ
>
> So your claim is that
>
> 30: (18) r0 = 0x4 ; incorrectly pruned here
>
>
> Can you please show a detailed code path in which we actually do reach
> 30? I might have missed it, but so far it looks like the verifier is
> doing everything right.
>

I tried to convert the repro to a valid test case in inline asm, but it
seems that JSET (if r0 & 0xfffffffe goto pc+3) is currently not supported
in clang-17. I will try again after clang-18 is released.

#30 is expected to be executed, see below where everything after ";" is
the runtime value:
...
6: (36) if w8 >= 0x69 goto pc+1 ; w8 = 0xbe, always taken
...
11: (45) if r0 & 0xfffffffe goto pc+3 ; r0 = 0x616, taken
...
18: (56) if w8 != 0xf goto pc+3 ; w8 not touched, taken
...
23: (bf) r5 = r8 ; w5 = 0xbe
24: (18) r2 = 0x4
26: (7e) if w8 s>= w0 goto pc+5 ; non-taken
27: (4f) r8 |= r8
28: (0f) r8 += r8
29: (d6) if w5 s<= 0x1d goto pc+2 ; non-taken
30: (18) r0 = 0x4 ; executed

Since the verifier prunes at #26, #30 is considered dead and is eliminated.
With the dead code rewrite pass manually commented out, #30 is executed at
runtime.
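
The runtime trace above can be replayed with a small Python sketch. It is
only an illustration, not a BPF interpreter: it models just the registers
that influence the branches (r0, r4, r5, r8), and the function name replay
is made up. The JSET immediate is assumed to be sign-extended to 64 bits,
which does not change the outcome for r0 = 0x616.

```python
MASK64 = (1 << 64) - 1
MASK32 = (1 << 32) - 1

def s32(x):
    """Interpret the low 32 bits of x as a signed 32-bit value."""
    x &= MASK32
    return x - (1 << 32) if x & 0x80000000 else x

def replay(r0):
    """Return the branch/exit instructions visited for a given
    return value of bpf_get_current_cgroup_id()."""
    visited = [6]
    r8 = 0x1A000000BE                      # 0, 3: r8 = r2
    if not ((r8 & MASK32) >= 0x69):        # 6: if w8 >= 0x69 goto pc+1
        visited.append(7)                  # 7: exit
        return visited
    r4 = (-0x52) & MASK32                  # 8, 10: r4 = 0x52; w4 = -w4
    visited.append(11)
    if not (r0 & 0xFFFFFFFFFFFFFFFE):      # 11: fall through if no bit set
        r8 = (r8 - r4) & MASK64            # 12: r8 -= r4
        r0 = (r0 + r0) & MASK64            # 13: r0 += r0
    visited.append(18)                     # 14-22 only touch r2, r3, r4
    r5 = r8                                # 23: r5 = r8
    visited.append(26)
    if s32(r8) >= s32(r0):                 # 26: if w8 s>= w0 goto pc+5
        visited.append(32)
        return visited
    visited.append(29)                     # 27, 28 only touch r8
    if s32(r5) <= 0x1D:                    # 29: if w5 s<= 0x1d goto pc+2
        visited.append(32)
        return visited
    visited += [30, 32]                    # 30: r0 = 0x4, then exit
    return visited

print(replay(0x616))   # [6, 11, 18, 26, 29, 30, 32]
```

With r0 = 0x616 the replay reaches #30, matching the trace above.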

From my understanding, r0 should be marked as precise when we first
backtrack from #29, because the range of r5 at this point depends on w0,
as r8 and r5 share the same id at #26.
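
As a toy model of this suggestion (Python, not the kernel's backtracking
code), the sketch below computes which registers would need to be precise
before a conditional jump. The function name precise_before_cond_jump and
the dict of register ids are hypothetical and exist only for illustration.

```python
def precise_before_cond_jump(needed, dst, src, reg_ids):
    """needed:  registers required to be precise after the jump
    dst/src: the two operands of the conditional jump
    reg_ids: register name -> scalar id, for registers that have one"""
    # Registers sharing an id with a needed register have the same range.
    ids = {reg_ids[r] for r in needed if reg_ids.get(r)}
    linked = {r for r, i in reg_ids.items() if i in ids}
    out = set(needed) | linked
    # If one operand is in the set, the jump narrowed its range using
    # the other operand, so both operands matter.
    if dst in out or src in out:
        out |= {dst, src}
    return out

# Backtracking from #29 needs r5; at #26 r5 and r8 both have id=2 and
# the jump compares w8 against w0.
print(sorted(precise_before_cond_jump({"r5"}, "r8", "r0",
                                      {"r5": 2, "r8": 2})))
# ['r0', 'r5', 'r8']
```

In the log, the first backtrack from #29 carries only r5 across #26 and
adds r8 at the parent state, but never r0.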