[PATCH 5.10 257/545] inet: add READ_ONCE(sk->sk_bound_dev_if) in INET_MATCH()

From: Greg Kroah-Hartman
Date: Fri Aug 19 2022 - 12:19:03 EST


From: Eric Dumazet <edumazet@xxxxxxxxxx>

[ Upstream commit 4915d50e300e96929d2462041d6f6c6f061167fd ]

INET_MATCH() runs without holding a lock on the socket.

We probably need to annotate most of the socket field reads it
performs (e.g. with READ_ONCE()).

This patch makes INET_MATCH() an inline function
to make these changes easier.

v2:

We remove the 32-bit version of it, as modern compilers
should generate the same code anyway; there is no need to
try to be smarter.

Also make 'struct net *net' the first argument.

Signed-off-by: Eric Dumazet <edumazet@xxxxxxxxxx>
Signed-off-by: David S. Miller <davem@xxxxxxxxxxxxx>
Signed-off-by: Sasha Levin <sashal@xxxxxxxxxx>
---
include/net/inet_hashtables.h | 35 ++++++++++++++++-------------------
include/net/sock.h | 3 ---
net/ipv4/inet_hashtables.c | 15 +++++----------
net/ipv4/udp.c | 3 +--
4 files changed, 22 insertions(+), 34 deletions(-)

diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index d4d611064a76..816851807fa8 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -289,7 +289,6 @@ static inline struct sock *inet_lookup_listener(struct net *net,
((__force __portpair)(((__u32)(__dport) << 16) | (__force __u32)(__be16)(__sport)))
#endif

-#if (BITS_PER_LONG == 64)
#ifdef __BIG_ENDIAN
#define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
const __addrpair __name = (__force __addrpair) ( \
@@ -301,24 +300,22 @@ static inline struct sock *inet_lookup_listener(struct net *net,
(((__force __u64)(__be32)(__daddr)) << 32) | \
((__force __u64)(__be32)(__saddr)))
#endif /* __BIG_ENDIAN */
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
- (((__sk)->sk_portpair == (__ports)) && \
- ((__sk)->sk_addrpair == (__cookie)) && \
- (((__sk)->sk_bound_dev_if == (__dif)) || \
- ((__sk)->sk_bound_dev_if == (__sdif))) && \
- net_eq(sock_net(__sk), (__net)))
-#else /* 32-bit arch */
-#define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
- const int __name __deprecated __attribute__((unused))
-
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
- (((__sk)->sk_portpair == (__ports)) && \
- ((__sk)->sk_daddr == (__saddr)) && \
- ((__sk)->sk_rcv_saddr == (__daddr)) && \
- (((__sk)->sk_bound_dev_if == (__dif)) || \
- ((__sk)->sk_bound_dev_if == (__sdif))) && \
- net_eq(sock_net(__sk), (__net)))
-#endif /* 64-bit arch */
+
+static inline bool INET_MATCH(struct net *net, const struct sock *sk,
+ const __addrpair cookie, const __portpair ports,
+ int dif, int sdif)
+{
+ int bound_dev_if;
+
+ if (!net_eq(sock_net(sk), net) ||
+ sk->sk_portpair != ports ||
+ sk->sk_addrpair != cookie)
+ return false;
+
+ /* Paired with WRITE_ONCE() from sock_bindtoindex_locked() */
+ bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
+ return bound_dev_if == dif || bound_dev_if == sdif;
+}

/* Sockets in TCP_CLOSE state are _always_ taken out of the hash, so we need
* not check it for lookups anymore, thanks Alexey. -DaveM
diff --git a/include/net/sock.h b/include/net/sock.h
index c72b0fc4c752..333131f47ac1 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -160,9 +160,6 @@ typedef __u64 __bitwise __addrpair;
* for struct sock and struct inet_timewait_sock.
*/
struct sock_common {
- /* skc_daddr and skc_rcv_saddr must be grouped on a 8 bytes aligned
- * address on 64bit arches : cf INET_MATCH()
- */
union {
__addrpair skc_addrpair;
struct {
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index f38b71cc3edb..7dbe80e30b9d 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -410,13 +410,11 @@ struct sock *__inet_lookup_established(struct net *net,
sk_nulls_for_each_rcu(sk, node, &head->chain) {
if (sk->sk_hash != hash)
continue;
- if (likely(INET_MATCH(sk, net, acookie,
- saddr, daddr, ports, dif, sdif))) {
+ if (likely(INET_MATCH(net, sk, acookie, ports, dif, sdif))) {
if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
goto out;
- if (unlikely(!INET_MATCH(sk, net, acookie,
- saddr, daddr, ports,
- dif, sdif))) {
+ if (unlikely(!INET_MATCH(net, sk, acookie,
+ ports, dif, sdif))) {
sock_gen_put(sk);
goto begin;
}
@@ -465,8 +463,7 @@ static int __inet_check_established(struct inet_timewait_death_row *death_row,
if (sk2->sk_hash != hash)
continue;

- if (likely(INET_MATCH(sk2, net, acookie,
- saddr, daddr, ports, dif, sdif))) {
+ if (likely(INET_MATCH(net, sk2, acookie, ports, dif, sdif))) {
if (sk2->sk_state == TCP_TIME_WAIT) {
tw = inet_twsk(sk2);
if (twsk_unique(sk, sk2, twp))
@@ -532,9 +529,7 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
if (esk->sk_hash != sk->sk_hash)
continue;
if (sk->sk_family == AF_INET) {
- if (unlikely(INET_MATCH(esk, net, acookie,
- sk->sk_daddr,
- sk->sk_rcv_saddr,
+ if (unlikely(INET_MATCH(net, esk, acookie,
ports, dif, sdif))) {
return true;
}
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 6056d5609167..e498c7666ec6 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2490,8 +2490,7 @@ static struct sock *__udp4_lib_demux_lookup(struct net *net,
struct sock *sk;

udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
- if (INET_MATCH(sk, net, acookie, rmt_addr,
- loc_addr, ports, dif, sdif))
+ if (INET_MATCH(net, sk, acookie, ports, dif, sdif))
return sk;
/* Only check first socket in chain */
break;
--
2.35.1