[PATCH] tls: fix missing memory barrier in tls_init

From: Dae R. Jeong
Date: Fri Nov 10 2023 - 13:09:05 EST


In tls_init(), a write memory barrier is missing, so store-store
reordering may cause a NULL pointer dereference in
tls_{setsockopt,getsockopt}.

CPU0                                        CPU1
-----                                       -----
// In tls_init()
// In tls_ctx_create()
ctx = kzalloc()
ctx->sk_proto = READ_ONCE(sk->sk_prot) -(1)

// In update_sk_prot()
WRITE_ONCE(sk->sk_prot, tls_prots)     -(2)

                                            // In sock_common_setsockopt()
                                            READ_ONCE(sk->sk_prot)->setsockopt()

                                            // In tls_{setsockopt,getsockopt}()
                                            ctx->sk_proto->setsockopt()    -(3)

In the above scenario, when (1) and (2) are reordered, (3) can observe
a NULL ctx->sk_proto, causing a NULL pointer dereference.

To fix it, we rely on rcu_assign_pointer(), which has release
semantics. By moving rcu_assign_pointer() after ctx is fully
initialized, we ensure that all fields of ctx are visible by the time
sk->sk_prot is changed.

Also, as Sabrina suggested, this patch removes tls_ctx_create() and
moves its body into tls_init().

Signed-off-by: Dae R. Jeong <threeearcat@xxxxxxxxx>
---
net/tls/tls_main.c | 32 +++++++++++++++-----------------
1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 1c2c6800949d..235fa93dc7ef 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -806,22 +806,6 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
return do_tls_setsockopt(sk, optname, optval, optlen);
}

-struct tls_context *tls_ctx_create(struct sock *sk)
-{
- struct inet_connection_sock *icsk = inet_csk(sk);
- struct tls_context *ctx;
-
- ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
- if (!ctx)
- return NULL;
-
- mutex_init(&ctx->tx_lock);
- rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
- ctx->sk_proto = READ_ONCE(sk->sk_prot);
- ctx->sk = sk;
- return ctx;
-}
-
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
const struct proto_ops *base)
{
@@ -933,6 +917,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],

static int tls_init(struct sock *sk)
{
+ struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx;
int rc = 0;

@@ -954,14 +939,27 @@ static int tls_init(struct sock *sk)

/* allocate tls context */
write_lock_bh(&sk->sk_callback_lock);
- ctx = tls_ctx_create(sk);
+ ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
if (!ctx) {
rc = -ENOMEM;
goto out;
}

+ mutex_init(&ctx->tx_lock);
+ ctx->sk_proto = READ_ONCE(sk->sk_prot);
+ ctx->sk = sk;
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
+ /* rcu_assign_pointer() should be called after initialization of
+ * all fields of ctx. It ensures that all fields of ctx are
+ * visible before changing sk->sk_prot, and prevents reading of
+ * uninitialized fields in tls_{getsockopt,setsockopt}. Note that
+ * we do not need a read barrier in tls_{getsockopt,setsockopt} as
+ * there is an address dependency between
+ * sk->sk_proto->{getsockopt,setsockopt} and ctx->sk_proto.
+ */
+ rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
+
update_sk_prot(sk, ctx);
out:
write_unlock_bh(&sk->sk_callback_lock);
--
2.42.1