Re: [PATCHv2] vsock: Fix blocking ops call in prepare_to_wait

From: Claudio Imbrenda
Date: Fri Mar 11 2016 - 07:39:37 EST


Hello,

I think I found a problem with the patch submitted by Laura Abbott
( https://lkml.org/lkml/2016/2/4/711 ): we might miss wakeups.
Since the condition is not checked between the prepare_to_wait and the
schedule(), if a wakeup happens after the condition is checked but before
the sleep happens, we miss it. (A description of the problem can be
found here: http://www.makelinux.net/ldd3/chp-6-sect-2 ).

My solution (see patch below) is to shrink the area influenced by
prepare_to_wait while keeping it around the fragile section that checks the
condition, and to keep the rest of the code in the "normal" running state.
This way the sleep is
correct and the other functions don't need to worry. The only caveat here
is that the function(s) called to verify the conditions are really not
allowed to sleep, so if you need synchronization in the backend of e.g.
vsock_stream_has_space(), you should use spinlocks and not mutexes.

In case we want to be able to sleep while waiting for conditions, we can
consider this instead: https://lwn.net/Articles/628628/ .


I stumbled on this problem while working on fixing the upcoming virtio
backend for vsock. Below is the patch I had prepared, with the original
message.

---
When a thread is prepared for waiting by calling prepare_to_wait, sleeping
is not allowed until either the wait has taken place or finish_wait has
been called. The existing code in af_vsock imposed unnecessary no-sleep
assumptions on a broad list of backend functions.
This patch shrinks the influence of prepare_to_wait to the area where it
is strictly needed, therefore relaxing the no-sleep restriction there.

Signed-off-by: Claudio Imbrenda <imbrenda@xxxxxxxxxxxxxxxxxx>
---
net/vmw_vsock/af_vsock.c | 158 +++++++++++++++++++++++++----------------------
1 file changed, 85 insertions(+), 73 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 9783a38..112fa8b 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1209,10 +1209,14 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,

if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- goto out_wait_error;
+ sk->sk_state = SS_UNCONNECTED;
+ sock->state = SS_UNCONNECTED;
+ goto out_wait;
} else if (timeout == 0) {
err = -ETIMEDOUT;
- goto out_wait_error;
+ sk->sk_state = SS_UNCONNECTED;
+ sock->state = SS_UNCONNECTED;
+ goto out_wait;
}

prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
@@ -1220,20 +1224,17 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,

if (sk->sk_err) {
err = -sk->sk_err;
- goto out_wait_error;
- } else
+ sk->sk_state = SS_UNCONNECTED;
+ sock->state = SS_UNCONNECTED;
+ } else {
err = 0;
+ }

out_wait:
finish_wait(sk_sleep(sk), &wait);
out:
release_sock(sk);
return err;
-
-out_wait_error:
- sk->sk_state = SS_UNCONNECTED;
- sock->state = SS_UNCONNECTED;
- goto out_wait;
}

static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
@@ -1270,18 +1271,20 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
listener->sk_err == 0) {
release_sock(listener);
timeout = schedule_timeout(timeout);
+ finish_wait(sk_sleep(listener), &wait);
lock_sock(listener);

if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- goto out_wait;
+ goto out;
} else if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ goto out;
}

prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE);
}
+ finish_wait(sk_sleep(listener), &wait);

if (listener->sk_err)
err = -listener->sk_err;
@@ -1301,19 +1304,15 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags)
*/
if (err) {
vconnected->rejected = true;
- release_sock(connected);
- sock_put(connected);
- goto out_wait;
+ } else {
+ newsock->state = SS_CONNECTED;
+ sock_graft(connected, newsock);
}

- newsock->state = SS_CONNECTED;
- sock_graft(connected, newsock);
release_sock(connected);
sock_put(connected);
}

-out_wait:
- finish_wait(sk_sleep(listener), &wait);
out:
release_sock(listener);
return err;
@@ -1557,11 +1556,11 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
if (err < 0)
goto out;

- prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);

while (total_written < len) {
ssize_t written;

+ prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
while (vsock_stream_has_space(vsk) == 0 &&
sk->sk_err == 0 &&
!(sk->sk_shutdown & SEND_SHUTDOWN) &&
@@ -1570,27 +1569,33 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
/* Don't wait for non-blocking sockets. */
if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ finish_wait(sk_sleep(sk), &wait);
+ goto out_err;
}

err = transport->notify_send_pre_block(vsk, &send_data);
- if (err < 0)
- goto out_wait;
+ if (err < 0) {
+ finish_wait(sk_sleep(sk), &wait);
+ goto out_err;
+ }

release_sock(sk);
timeout = schedule_timeout(timeout);
lock_sock(sk);
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- goto out_wait;
+ finish_wait(sk_sleep(sk), &wait);
+ goto out_err;
} else if (timeout == 0) {
err = -EAGAIN;
- goto out_wait;
+ finish_wait(sk_sleep(sk), &wait);
+ goto out_err;
}

prepare_to_wait(sk_sleep(sk), &wait,
TASK_INTERRUPTIBLE);
}
+ finish_wait(sk_sleep(sk), &wait);

/* These checks occur both as part of and after the loop
* conditional since we need to check before and after
@@ -1598,16 +1603,16 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
*/
if (sk->sk_err) {
err = -sk->sk_err;
- goto out_wait;
+ goto out_err;
} else if ((sk->sk_shutdown & SEND_SHUTDOWN) ||
(vsk->peer_shutdown & RCV_SHUTDOWN)) {
err = -EPIPE;
- goto out_wait;
+ goto out_err;
}

err = transport->notify_send_pre_enqueue(vsk, &send_data);
if (err < 0)
- goto out_wait;
+ goto out_err;

/* Note that enqueue will only write as many bytes as are free
* in the produce queue, so we don't need to ensure len is
@@ -1620,7 +1625,7 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
len - total_written);
if (written < 0) {
err = -ENOMEM;
- goto out_wait;
+ goto out_err;
}

total_written += written;
@@ -1628,14 +1633,13 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
err = transport->notify_send_post_enqueue(
vsk, written, &send_data);
if (err < 0)
- goto out_wait;
+ goto out_err;

}

-out_wait:
+out_err:
if (total_written > 0)
err = total_written;
- finish_wait(sk_sleep(sk), &wait);
out:
release_sock(sk);
return err;
@@ -1716,21 +1720,61 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (err < 0)
goto out;

- prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);

while (1) {
- s64 ready = vsock_stream_has_data(vsk);
+ s64 ready;

- if (ready < 0) {
- /* Invalid queue pair content. XXX This should be
- * changed to a connection reset in a later change.
- */
+ prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+ ready = vsock_stream_has_data(vsk);

- err = -ENOMEM;
- goto out_wait;
- } else if (ready > 0) {
+ if (ready == 0) {
+ if (sk->sk_err != 0 ||
+ (sk->sk_shutdown & RCV_SHUTDOWN) ||
+ (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+ finish_wait(sk_sleep(sk), &wait);
+ break;
+ }
+ /* Don't wait for non-blocking sockets. */
+ if (timeout == 0) {
+ err = -EAGAIN;
+ finish_wait(sk_sleep(sk), &wait);
+ break;
+ }
+
+ err = transport->notify_recv_pre_block(
+ vsk, target, &recv_data);
+ if (err < 0) {
+ finish_wait(sk_sleep(sk), &wait);
+ break;
+ }
+ release_sock(sk);
+ timeout = schedule_timeout(timeout);
+ lock_sock(sk);
+
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeout);
+ finish_wait(sk_sleep(sk), &wait);
+ break;
+ } else if (timeout == 0) {
+ err = -EAGAIN;
+ finish_wait(sk_sleep(sk), &wait);
+ break;
+ }
+ } else {
ssize_t read;

+ finish_wait(sk_sleep(sk), &wait);
+
+ if (ready < 0) {
+ /* Invalid queue pair content. XXX This should
+ * be changed to a connection reset in a later
+ * change.
+ */
+
+ err = -ENOMEM;
+ goto out;
+ }
+
err = transport->notify_recv_pre_dequeue(
vsk, target, &recv_data);
if (err < 0)
@@ -1750,42 +1794,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
vsk, target, read,
!(flags & MSG_PEEK), &recv_data);
if (err < 0)
- goto out_wait;
+ goto out;

if (read >= target || flags & MSG_PEEK)
break;

target -= read;
- } else {
- if (sk->sk_err != 0 || (sk->sk_shutdown & RCV_SHUTDOWN)
- || (vsk->peer_shutdown & SEND_SHUTDOWN)) {
- break;
- }
- /* Don't wait for non-blocking sockets. */
- if (timeout == 0) {
- err = -EAGAIN;
- break;
- }
-
- err = transport->notify_recv_pre_block(
- vsk, target, &recv_data);
- if (err < 0)
- break;
-
- release_sock(sk);
- timeout = schedule_timeout(timeout);
- lock_sock(sk);
-
- if (signal_pending(current)) {
- err = sock_intr_errno(timeout);
- break;
- } else if (timeout == 0) {
- err = -EAGAIN;
- break;
- }
-
- prepare_to_wait(sk_sleep(sk), &wait,
- TASK_INTERRUPTIBLE);
}
}

@@ -1816,8 +1830,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
err = copied;
}

-out_wait:
- finish_wait(sk_sleep(sk), &wait);
out:
release_sock(sk);
return err;
--
1.9.1