[PATCH bpf-next v2 5/6] bpf, net: Support SO_REUSEPORT sockets with bpf_sk_assign

From: Lorenz Bauer
Date: Tue Jun 13 2023 - 06:15:30 EST


Currently the bpf_sk_assign helper in tc BPF context refuses SO_REUSEPORT
sockets. This means we can't use the helper to steer traffic to Envoy,
which configures SO_REUSEPORT on its sockets. In turn, we're blocked
from removing TPROXY from our setup.

The reason that bpf_sk_assign refuses such sockets is that the
bpf_sk_lookup helpers don't execute SK_REUSEPORT programs. Instead,
one of the reuseport sockets is selected by hash. This could cause
dispatch to the "wrong" socket:

sk = bpf_sk_lookup_tcp(...) // select SO_REUSEPORT by hash
bpf_sk_assign(skb, sk) // SK_REUSEPORT wasn't executed

Unfortunately, fixing this isn't as simple as invoking SK_REUSEPORT from
the lookup helpers. In the tc context, L2 headers are at the start of the
skb, while SK_REUSEPORT expects L3 headers instead.

Instead, we execute the SK_REUSEPORT program when the assigned socket
is pulled out of the skb, further up the stack. This creates some
trickiness with regard to refcounting, as bpf_sk_assign will put both
refcounted and RCU-freed sockets in skb->sk. Reuseport sockets are
RCU-freed. We can infer that the sk_assigned socket is RCU-freed if the
reuseport lookup succeeds, but convincing yourself of this fact isn't
straightforward. Therefore we defensively check refcounting on the
sk_assign sock even though it's probably not required in practice.

Fixes: 8e368dc72e86 ("bpf: Fix use of sk->sk_reuseport from sk_assign")
Fixes: cf7fbe660f2d ("bpf: Add socket assign support")
Co-developed-by: Daniel Borkmann <daniel@xxxxxxxxxxxxx>
Signed-off-by: Daniel Borkmann <daniel@xxxxxxxxxxxxx>
Signed-off-by: Lorenz Bauer <lmb@xxxxxxxxxxxxx>
Cc: Joe Stringer <joe@xxxxxxxxx>
Link: https://lore.kernel.org/bpf/CACAyw98+qycmpQzKupquhkxbvWK4OFyDuuLMBNROnfWMZxUWeA@xxxxxxxxxxxxxx/
---
include/net/inet6_hashtables.h | 59 ++++++++++++++++++++++++++++++++++++++----
include/net/inet_hashtables.h | 52 +++++++++++++++++++++++++++++++++++--
include/net/sock.h | 7 +++--
include/uapi/linux/bpf.h | 3 ---
net/core/filter.c | 2 --
net/ipv4/udp.c | 8 ++++--
net/ipv6/udp.c | 10 ++++---
tools/include/uapi/linux/bpf.h | 3 ---
8 files changed, 122 insertions(+), 22 deletions(-)

diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index 4d2a1a3c0be7..4d300af6ccb6 100644
--- a/include/net/inet6_hashtables.h
+++ b/include/net/inet6_hashtables.h
@@ -103,6 +103,49 @@ static inline struct sock *__inet6_lookup(struct net *net,
daddr, hnum, dif, sdif);
}

+static inline
+struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+ const struct in6_addr *saddr, const __be16 sport,
+ const struct in6_addr *daddr, const __be16 dport,
+ bool *refcounted, inet6_ehashfn_t ehashfn)
+{
+ struct sock *sk, *reuse_sk;
+ bool prefetched;
+
+ sk = skb_steal_sock(skb, refcounted, &prefetched);
+ if (!sk)
+ return NULL;
+
+ if (!prefetched)
+ return sk;
+
+ if (sk->sk_protocol == IPPROTO_TCP) {
+ if (sk->sk_state != TCP_LISTEN)
+ return sk;
+ } else if (sk->sk_protocol == IPPROTO_UDP) {
+ if (sk->sk_state != TCP_CLOSE)
+ return sk;
+ } else {
+ return sk;
+ }
+
+ reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
+ saddr, sport, daddr, ntohs(dport),
+ ehashfn);
+ if (!reuse_sk || reuse_sk == sk)
+ return sk;
+
+ /* We've chosen a new reuseport sock which is never refcounted.
+ * sk might be refcounted however, drop the reference if necessary.
+ */
+ if (*refcounted) {
+ sock_put(sk);
+ *refcounted = false;
+ }
+
+ return reuse_sk;
+}
+
static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const __be16 sport,
@@ -110,14 +153,20 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
int iif, int sdif,
bool *refcounted)
{
- struct sock *sk = skb_steal_sock(skb, refcounted);
-
+ struct net *net = dev_net(skb_dst(skb)->dev);
+ const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+ struct sock *sk;
+
+ sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
+ refcounted, inet6_ehashfn);
+ if (IS_ERR(sk))
+ return NULL;
if (sk)
return sk;

- return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
- doff, &ipv6_hdr(skb)->saddr, sport,
- &ipv6_hdr(skb)->daddr, ntohs(dport),
+ return __inet6_lookup(net, hashinfo, skb,
+ doff, &ip6h->saddr, sport,
+ &ip6h->daddr, ntohs(dport),
iif, sdif, refcounted);
}

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index aa02f1db1f86..2c405d9df300 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -449,6 +449,49 @@ static inline struct sock *inet_lookup(struct net *net,
return sk;
}

+static inline
+struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+ const __be32 saddr, const __be16 sport,
+ const __be32 daddr, const __be16 dport,
+ bool *refcounted, inet_ehashfn_t ehashfn)
+{
+ struct sock *sk, *reuse_sk;
+ bool prefetched;
+
+ sk = skb_steal_sock(skb, refcounted, &prefetched);
+ if (!sk)
+ return NULL;
+
+ if (!prefetched)
+ return sk;
+
+ if (sk->sk_protocol == IPPROTO_TCP) {
+ if (sk->sk_state != TCP_LISTEN)
+ return sk;
+ } else if (sk->sk_protocol == IPPROTO_UDP) {
+ if (sk->sk_state != TCP_CLOSE)
+ return sk;
+ } else {
+ return sk;
+ }
+
+ reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
+ saddr, sport, daddr, ntohs(dport),
+ ehashfn);
+ if (!reuse_sk || reuse_sk == sk)
+ return sk;
+
+ /* We've chosen a new reuseport sock which is never refcounted.
+ * sk might be refcounted however, drop the reference if necessary.
+ */
+ if (*refcounted) {
+ sock_put(sk);
+ *refcounted = false;
+ }
+
+ return reuse_sk;
+}
+
static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
struct sk_buff *skb,
int doff,
@@ -457,13 +500,18 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
const int sdif,
bool *refcounted)
{
- struct sock *sk = skb_steal_sock(skb, refcounted);
+ struct net *net = dev_net(skb_dst(skb)->dev);
const struct iphdr *iph = ip_hdr(skb);
+ struct sock *sk;

+ sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
+ refcounted, inet_ehashfn);
+ if (IS_ERR(sk))
+ return NULL;
if (sk)
return sk;

- return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+ return __inet_lookup(net, hashinfo, skb,
doff, iph->saddr, sport,
iph->daddr, dport, inet_iif(skb), sdif,
refcounted);
diff --git a/include/net/sock.h b/include/net/sock.h
index 656ea89f60ff..5645570c2a64 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -2806,20 +2806,23 @@ sk_is_refcounted(struct sock *sk)
* skb_steal_sock - steal a socket from an sk_buff
* @skb: sk_buff to steal the socket from
* @refcounted: is set to true if the socket is reference-counted
+ * @prefetched: is set to true if the socket was assigned from bpf
*/
static inline struct sock *
-skb_steal_sock(struct sk_buff *skb, bool *refcounted)
+skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
{
if (skb->sk) {
struct sock *sk = skb->sk;

*refcounted = true;
- if (skb_sk_is_prefetched(skb))
+ *prefetched = skb_sk_is_prefetched(skb);
+ if (*prefetched)
*refcounted = sk_is_refcounted(sk);
skb->destructor = NULL;
skb->sk = NULL;
return sk;
}
+ *prefetched = false;
*refcounted = false;
return NULL;
}
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index a7b5e91dd768..d6fb6f43b0f3 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -4158,9 +4158,6 @@ union bpf_attr {
* **-EOPNOTSUPP** if the operation is not supported, for example
* a call from outside of TC ingress.
*
- * **-ESOCKTNOSUPPORT** if the socket type is not supported
- * (reuseport).
- *
* long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
* Description
* Helper is overloaded depending on BPF program type. This
diff --git a/net/core/filter.c b/net/core/filter.c
index 428df050d021..d4be0a1d754c 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -7278,8 +7278,6 @@ BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags)
return -EOPNOTSUPP;
if (unlikely(dev_net(skb->dev) != sock_net(sk)))
return -ENETUNREACH;
- if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
- return -ESOCKTNOSUPPORT;
if (sk_is_refcounted(sk) &&
unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
return -ENOENT;
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 2500e92050a0..a3fa781432b9 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2373,7 +2373,11 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
if (udp4_csum_init(skb, uh, proto))
goto csum_error;

- sk = skb_steal_sock(skb, &refcounted);
+ sk = inet_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+ &refcounted, udp_ehashfn);
+ if (IS_ERR(sk))
+ goto no_sk;
+
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
@@ -2394,7 +2398,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
if (sk)
return udp_unicast_rcv_skb(sk, skb, uh);
-
+no_sk:
if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
goto drop;
nf_reset_ct(skb);
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 961b7e61f02c..0a90f34696ad 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -910,9 +910,9 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED;
const struct in6_addr *saddr, *daddr;
struct net *net = dev_net(skb->dev);
+ bool refcounted;
struct udphdr *uh;
struct sock *sk;
- bool refcounted;
u32 ulen = 0;

if (!pskb_may_pull(skb, sizeof(struct udphdr)))
@@ -949,7 +949,11 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto csum_error;

/* Check if the socket is already available, e.g. due to early demux */
- sk = skb_steal_sock(skb, &refcounted);
+ sk = inet6_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+ &refcounted, udp6_ehashfn);
+ if (IS_ERR(sk))
+ goto no_sk;
+
if (sk) {
struct dst_entry *dst = skb_dst(skb);
int ret;
@@ -983,7 +987,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
goto report_csum_error;
return udp6_unicast_rcv_skb(sk, skb, uh);
}
-
+no_sk:
reason = SKB_DROP_REASON_NO_SOCKET;

if (!uh->check)
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index a7b5e91dd768..d6fb6f43b0f3 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -4158,9 +4158,6 @@ union bpf_attr {
* **-EOPNOTSUPP** if the operation is not supported, for example
* a call from outside of TC ingress.
*
- * **-ESOCKTNOSUPPORT** if the socket type is not supported
- * (reuseport).
- *
* long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
* Description
* Helper is overloaded depending on BPF program type. This

--
2.40.1