[PATCH 3.2 61/95] l2tp: fix race in l2tp_recv_common()

From: Ben Hutchings
Date: Sun Jul 16 2017 - 10:59:00 EST


3.2.91-rc1 review patch. If anyone has any objections, please let me know.

------------------

From: Guillaume Nault <g.nault@xxxxxxxxxxxx>

commit 61b9a047729bb230978178bca6729689d0c50ca2 upstream.

Taking a reference on sessions in l2tp_recv_common() is racy; this
has to be done by the callers.

To this end, a new function is required (l2tp_session_get()) to
atomically look up a session and take a reference on it. Callers then
have to manually drop this reference.

Fixes: fd558d186df2 ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts")
Signed-off-by: Guillaume Nault <g.nault@xxxxxxxxxxxx>
Signed-off-by: David S. Miller <davem@xxxxxxxxxxxxx>
[bwh: Backported to 3.2:
- Drop changes to l2tp_ip6.c
- Add 'pos' parameter to hlist_for_each_entry{,_rcu}() calls
- Adjust context]
Signed-off-by: Ben Hutchings <ben@xxxxxxxxxxxxxxx>
---
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -225,6 +225,56 @@ struct l2tp_session *l2tp_session_find(s
}
EXPORT_SYMBOL_GPL(l2tp_session_find);

+/* Like l2tp_session_find() but takes a reference on the returned session.
+ * Optionally calls session->ref() too if do_ref is true.
+ */
+struct l2tp_session *l2tp_session_get(struct net *net,
+ struct l2tp_tunnel *tunnel,
+ u32 session_id, bool do_ref)
+{
+ struct hlist_head *session_list;
+ struct l2tp_session *session;
+ struct hlist_node *walk;
+
+ if (!tunnel) {
+ struct l2tp_net *pn = l2tp_pernet(net);
+
+ session_list = l2tp_session_id_hash_2(pn, session_id);
+
+ rcu_read_lock_bh();
+ hlist_for_each_entry_rcu(session, walk, session_list, global_hlist) {
+ if (session->session_id == session_id) {
+ l2tp_session_inc_refcount(session);
+ if (do_ref && session->ref)
+ session->ref(session);
+ rcu_read_unlock_bh();
+
+ return session;
+ }
+ }
+ rcu_read_unlock_bh();
+
+ return NULL;
+ }
+
+ session_list = l2tp_session_id_hash(tunnel, session_id);
+ read_lock_bh(&tunnel->hlist_lock);
+ hlist_for_each_entry(session, walk, session_list, hlist) {
+ if (session->session_id == session_id) {
+ l2tp_session_inc_refcount(session);
+ if (do_ref && session->ref)
+ session->ref(session);
+ read_unlock_bh(&tunnel->hlist_lock);
+
+ return session;
+ }
+ }
+ read_unlock_bh(&tunnel->hlist_lock);
+
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(l2tp_session_get);
+
struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
{
int hash;
@@ -524,6 +574,9 @@ static inline int l2tp_verify_udp_checks
* a data (not control) frame before coming here. Fields up to the
* session-id have already been parsed and ptr points to the data
* after the session-id.
+ *
+ * session->ref() must have been called prior to l2tp_recv_common().
+ * session->deref() will be called automatically after skb is processed.
*/
void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -533,14 +586,6 @@ void l2tp_recv_common(struct l2tp_sessio
int offset;
u32 ns, nr;

- /* The ref count is increased since we now hold a pointer to
- * the session. Take care to decrement the refcnt when exiting
- * this function from now on...
- */
- l2tp_session_inc_refcount(session);
- if (session->ref)
- (*session->ref)(session);
-
/* Parse and check optional cookie */
if (session->peer_cookie_len > 0) {
if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -711,8 +756,6 @@ void l2tp_recv_common(struct l2tp_sessio
/* Try to dequeue as many skbs from reorder_q as we can. */
l2tp_recv_dequeue(session);

- l2tp_session_dec_refcount(session);
-
return;

discard:
@@ -721,8 +764,6 @@ discard:

if (session->deref)
(*session->deref)(session);
-
- l2tp_session_dec_refcount(session);
}
EXPORT_SYMBOL(l2tp_recv_common);

@@ -818,8 +859,14 @@ static int l2tp_udp_recv_core(struct l2t
}

/* Find the session context */
- session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+ session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
if (!session || !session->recv_skb) {
+ if (session) {
+ if (session->deref)
+ session->deref(session);
+ l2tp_session_dec_refcount(session);
+ }
+
/* Not found? Pass to userspace to deal with */
PRINTK(tunnel->debug, L2TP_MSG_DATA, KERN_INFO,
"%s: no session found (%u/%u). Passing up.\n",
@@ -828,6 +875,7 @@ static int l2tp_udp_recv_core(struct l2t
}

l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
+ l2tp_session_dec_refcount(session);

return 0;

--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -222,6 +222,9 @@ out:
return tunnel;
}

+struct l2tp_session *l2tp_session_get(struct net *net,
+ struct l2tp_tunnel *tunnel,
+ u32 session_id, bool do_ref);
extern struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunnel, u32 session_id);
extern struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth);
extern struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname);
--- a/net/l2tp/l2tp_ip.c
+++ b/net/l2tp/l2tp_ip.c
@@ -148,19 +148,19 @@ static int l2tp_ip_recv(struct sk_buff *
}

/* Ok, this is a data packet. Lookup the session. */
- session = l2tp_session_find(&init_net, NULL, session_id);
- if (session == NULL)
+ session = l2tp_session_get(&init_net, NULL, session_id, true);
+ if (!session)
goto discard;

tunnel = session->tunnel;
- if (tunnel == NULL)
- goto discard;
+ if (!tunnel)
+ goto discard_sess;

/* Trace packet contents, if enabled */
if (tunnel->debug & L2TP_MSG_DATA) {
length = min(32u, skb->len);
if (!pskb_may_pull(skb, length))
- goto discard;
+ goto discard_sess;

/* Point to L2TP header */
optr = ptr = skb->data;
@@ -176,6 +176,7 @@ static int l2tp_ip_recv(struct sk_buff *
}

l2tp_recv_common(session, skb, ptr, optr, 0, skb->len, tunnel->recv_payload_hook);
+ l2tp_session_dec_refcount(session);

return 0;

@@ -211,6 +212,12 @@ pass_up:

return sk_receive_skb(sk, skb, 1);

+discard_sess:
+ if (session->deref)
+ session->deref(session);
+ l2tp_session_dec_refcount(session);
+ goto discard;
+
discard_put:
sock_put(sk);