Re: [PATCH RFC net-next v2 3/4] vsock: Add lockless sendmsg() support

From: Bobby Eshleman
Date: Thu Apr 20 2023 - 02:21:41 EST


On Wed, Apr 19, 2023 at 11:30:53AM +0200, Stefano Garzarella wrote:
> On Fri, Apr 14, 2023 at 12:25:59AM +0000, Bobby Eshleman wrote:
> > Because the dgram sendmsg() path for AF_VSOCK acquires the socket lock
> > it does not scale when many senders share a socket.
> >
> > Prior to this patch the socket lock is used to protect the local_addr,
> > remote_addr, transport, and buffer size variables. What follows are the
> > new protection schemes for the various protected fields that ensure a
> > race-free multi-sender sendmsg() path for vsock dgrams.
> >
> > - local_addr
> > local_addr changes as a result of binding a socket. The write path
> > for local_addr is bind() and various vsock_auto_bind() call sites.
> > After a socket has been bound via vsock_auto_bind() or bind(), subsequent
> > calls to bind()/vsock_auto_bind() do not write to local_addr again. bind()
> > rejects the user request and vsock_auto_bind() early exits.
> > Therefore, local_addr cannot change while a parallel thread is
> > in sendmsg(), and lock-free reads of local_addr in sendmsg() are safe.
> > Change: only acquire lock for auto-binding as-needed in sendmsg().
> >
> > - vsk->transport
> > Updated upon socket creation and it doesn't change again until the
>
> This is true only for dgram, right?
>

Yes.

> How do we decide which transport to assign for dgram?
>

The transport is assigned in proto->create() [vsock_create()]. It is
assigned there *only* for dgrams, whereas for streams/seqpackets it is
assigned in connect(). vsock_create() sets transport to
'transport_dgram' if sock->type == SOCK_DGRAM.

vsock_sk_destruct() then eventually sets vsk->transport to NULL.

Neither destruct nor create can occur in parallel with sendmsg().
create() hasn't yet returned the sockfd, so the user has nothing to call
sendmsg() on, and AFAICT destruct is only called after the final socket
reference is released. That only happens after the socket no longer
exists in the fd lookup table, so sendmsg() fails the lookup before it
ever has the chance to race.
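
To make the three write points concrete, here is a rough userspace model
of them. It is only a sketch, not the kernel code: every model_* name is
made up for illustration, and the functions simply mirror the behaviour
described above (create assigns the transport for dgrams only, connect
leaves dgram sockets alone, destruct clears the pointer).

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-ins; none of these names exist in the kernel. */
enum model_sock_type { MODEL_SOCK_STREAM, MODEL_SOCK_DGRAM };

struct model_transport {
	const char *name;
};

struct model_vsock {
	enum model_sock_type type;
	const struct model_transport *transport;
};

const struct model_transport model_transport_dgram = { "dgram" };
const struct model_transport model_transport_stream = { "stream" };

/* create(): only dgram sockets get their transport at this point. */
void model_create(struct model_vsock *vsk, enum model_sock_type type)
{
	assert(vsk != NULL);
	vsk->type = type;
	vsk->transport = (type == MODEL_SOCK_DGRAM) ?
			 &model_transport_dgram : NULL;
}

/* connect(): assigns a transport for streams, leaves dgrams untouched. */
void model_connect(struct model_vsock *vsk)
{
	assert(vsk != NULL);
	if (vsk->type != MODEL_SOCK_DGRAM)
		vsk->transport = &model_transport_stream;
}

/* destruct(): the only other write; runs after the last reference drops. */
void model_destruct(struct model_vsock *vsk)
{
	assert(vsk != NULL);
	vsk->transport = NULL;
}
```

In this model a dgram socket's transport is written exactly twice, once
before the caller can reach the socket and once after nobody can, which
is the window in which sendmsg() reads it.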

> > socket is destroyed, which only happens after the socket refcnt reaches
> > zero. This prevents any sendmsg() call from being entered because the
> > sockfd lookup fails beforehand. That is, sendmsg() and vsk->transport
> > writes cannot execute in parallel. Additionally, connect() doesn't
> > update vsk->transport for dgrams as it does for streams. Therefore
> > vsk->transport is also safe to access lock-free in the sendmsg() path.
> > No change.
> >
> > - buffer size variables
> > Not used by dgram, so they do not need protection. No change.
>
> Is this true because for dgram we use the socket buffer?
> Is it the same for VMCI?

Yes. The buf_alloc derived from buffer_size is also always ignored after
being initialized once since credits aren't used.

My reading of the VMCI code is that the buffer_size and
buffer_{min,max}_size variables are used only in connectible calls (like
connect and recv_listen), but not for datagrams.

>
> >
> > - remote_addr
> > Needs additional protection because before this patch the
> > remote_addr (consisting of several fields such as cid, port, and flags)
> > only changed atomically under socket lock context. By acquiring the
> > socket lock to read the structure, the changes made by connect() were
> > always made visible to sendmsg() atomically. Consequently, to retain
> > atomicity of updates but offer lock-free access, this patch
> > redesigns this field as an RCU-protected pointer.
> >
> > Writers are still synchronized using the socket lock, but readers
> > only read inside RCU read-side critical sections.
> >
> > Helpers are introduced for accessing and updating the new pointer.
> >
> > The remote_addr structure is wrapped together with an rcu_head into a
> > sockaddr_vm_rcu structure so that kfree_rcu() can be used. This removes
> > the need for writers to call synchronize_rcu() before freeing old
> > structures, which is simply more efficient and reduces code churn where
> > remote_addr is already being updated inside read-side sections.
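
For readers less familiar with the pattern, the copy/publish/defer scheme
described above can be modelled in userspace roughly as follows. This is
a sketch under loud assumptions, not the patch's implementation: all
model_* names are invented, C11 atomics stand in for
rcu_assign_pointer()/rcu_dereference(), and a plain "retired" list stands
in for kfree_rcu(). There is no grace-period detection here, so
model_reclaim() may only be called when the caller knows no reader still
holds an old pointer.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdlib.h>

struct model_addr {
	unsigned int cid;
	unsigned int port;
};

/* Address plus a link used for deferred freeing (stands in for rcu_head). */
struct model_addr_rcu {
	struct model_addr addr;
	struct model_addr_rcu *next_retired;
};

struct model_sock {
	_Atomic(struct model_addr_rcu *) remote_addr;
	struct model_addr_rcu *retired;	/* touched by the writer only */
};

int model_sock_init(struct model_sock *sk, unsigned int cid, unsigned int port)
{
	struct model_addr_rcu *a = malloc(sizeof(*a));

	assert(sk != NULL);
	if (!a)
		return -1;
	a->addr.cid = cid;
	a->addr.port = port;
	a->next_retired = NULL;
	atomic_init(&sk->remote_addr, a);
	sk->retired = NULL;
	return 0;
}

/* Reader: one pointer load, then a copy of a snapshot that never changes. */
int model_remote_addr_copy(struct model_sock *sk, struct model_addr *dest)
{
	struct model_addr_rcu *cur =
		atomic_load_explicit(&sk->remote_addr, memory_order_acquire);

	if (!cur)
		return -1;
	*dest = cur->addr;
	return 0;
}

/* Writer: callers must serialize (the socket lock, in the patch). */
int model_remote_addr_update(struct model_sock *sk, unsigned int cid,
			     unsigned int port)
{
	struct model_addr_rcu *fresh = malloc(sizeof(*fresh));
	struct model_addr_rcu *old;

	if (!fresh)
		return -1;
	fresh->addr.cid = cid;
	fresh->addr.port = port;
	fresh->next_retired = NULL;

	/* Publish the fully initialized copy in a single step. */
	old = atomic_exchange_explicit(&sk->remote_addr, fresh,
				       memory_order_acq_rel);
	if (old) {
		/* Defer the free instead of waiting for readers here. */
		old->next_retired = sk->retired;
		sk->retired = old;
	}
	return 0;
}

/* Frees retired copies; returns how many were freed. */
size_t model_reclaim(struct model_sock *sk)
{
	size_t freed = 0;

	while (sk->retired) {
		struct model_addr_rcu *dead = sk->retired;

		sk->retired = dead->next_retired;
		free(dead);
		freed++;
	}
	return freed;
}

void model_sock_destroy(struct model_sock *sk)
{
	free(atomic_exchange(&sk->remote_addr, (struct model_addr_rcu *)NULL));
	model_reclaim(sk);
}
```

The point of the model is that a reader sees either the old cid/port
pair or the new one, never a mix, because the fields are only ever
written before the structure is published.
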
> >
> > Only virtio has been tested, but updates were necessary to the VMCI and
> > hyperv code. Unfortunately the author does not have access to
> > VMCI/hyperv systems so those changes are untested.
> >
> > Perf Tests
> > vCPUS: 16
> > Threads: 16
> > Payload: 4KB
> > Test Runs: 5
> > Type: SOCK_DGRAM
> >
> > Before: 245.2 MB/s
> > After: 509.2 MB/s (+107%)
> >
> > Notably, on the same test system, vsock dgram even outperforms
> > multi-threaded UDP over virtio-net with vhost and MQ support enabled.
>
> Cool!
>
> This patch is quite large, so I need to review it carefully in future
> versions, but in general it makes sense to me.
>
> Thanks,
> Stefano
>

Thanks for the initial comments!

Best,
Bobby

> >
> > Throughput metrics for single-threaded SOCK_DGRAM and
> > single/multi-threaded SOCK_STREAM showed no statistically significant
> > throughput changes (lowest p-value reaching 0.27), with the mean
> > difference ranging from -5% to +1%.
> >
> > Signed-off-by: Bobby Eshleman <bobby.eshleman@xxxxxxxxxxxxx>
> > ---
> > drivers/vhost/vsock.c | 12 +-
> > include/net/af_vsock.h | 19 ++-
> > net/vmw_vsock/af_vsock.c | 261 ++++++++++++++++++++++++++++----
> > net/vmw_vsock/diag.c | 10 +-
> > net/vmw_vsock/hyperv_transport.c | 15 +-
> > net/vmw_vsock/virtio_transport_common.c | 22 ++-
> > net/vmw_vsock/vmci_transport.c | 70 ++++++---
> > 7 files changed, 344 insertions(+), 65 deletions(-)
> >
> > diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
> > index 028cf079225e..da105cb856ac 100644
> > --- a/drivers/vhost/vsock.c
> > +++ b/drivers/vhost/vsock.c
> > @@ -296,13 +296,17 @@ static int
> > vhost_transport_cancel_pkt(struct vsock_sock *vsk)
> > {
> > struct vhost_vsock *vsock;
> > + unsigned int cid;
> > int cnt = 0;
> > int ret = -ENODEV;
> >
> > rcu_read_lock();
> > + ret = vsock_remote_addr_cid(vsk, &cid);
> > + if (ret < 0)
> > + goto out;
> >
> > /* Find the vhost_vsock according to guest context id */
> > - vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
> > + vsock = vhost_vsock_get(cid);
> > if (!vsock)
> > goto out;
> >
> > @@ -686,6 +690,10 @@ static void vhost_vsock_flush(struct vhost_vsock *vsock)
> > static void vhost_vsock_reset_orphans(struct sock *sk)
> > {
> > struct vsock_sock *vsk = vsock_sk(sk);
> > + unsigned int cid;
> > +
> > + if (vsock_remote_addr_cid(vsk, &cid) < 0)
> > + return;
> >
> > /* vmci_transport.c doesn't take sk_lock here either. At least we're
> > * under vsock_table_lock so the sock cannot disappear while we're
> > @@ -693,7 +701,7 @@ static void vhost_vsock_reset_orphans(struct sock *sk)
> > */
> >
> > /* If the peer is still valid, no need to reset connection */
> > - if (vhost_vsock_get(vsk->remote_addr.svm_cid))
> > + if (vhost_vsock_get(cid))
> > return;
> >
> > /* If the close timeout is pending, let it expire. This avoids races
> > diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
> > index 57af28fede19..c02fd6ad0047 100644
> > --- a/include/net/af_vsock.h
> > +++ b/include/net/af_vsock.h
> > @@ -25,12 +25,17 @@ extern spinlock_t vsock_table_lock;
> > #define vsock_sk(__sk) ((struct vsock_sock *)__sk)
> > #define sk_vsock(__vsk) (&(__vsk)->sk)
> >
> > +struct sockaddr_vm_rcu {
> > + struct sockaddr_vm addr;
> > + struct rcu_head rcu;
> > +};
> > +
> > struct vsock_sock {
> > /* sk must be the first member. */
> > struct sock sk;
> > const struct vsock_transport *transport;
> > struct sockaddr_vm local_addr;
> > - struct sockaddr_vm remote_addr;
> > + struct sockaddr_vm_rcu __rcu *remote_addr;
> > /* Links for the global tables of bound and connected sockets. */
> > struct list_head bound_table;
> > struct list_head connected_table;
> > @@ -206,7 +211,7 @@ void vsock_release_pending(struct sock *pending);
> > void vsock_add_pending(struct sock *listener, struct sock *pending);
> > void vsock_remove_pending(struct sock *listener, struct sock *pending);
> > void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
> > -void vsock_insert_connected(struct vsock_sock *vsk);
> > +int vsock_insert_connected(struct vsock_sock *vsk);
> > void vsock_remove_bound(struct vsock_sock *vsk);
> > void vsock_remove_connected(struct vsock_sock *vsk);
> > struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
> > @@ -244,4 +249,14 @@ static inline void __init vsock_bpf_build_proto(void)
> > {}
> > #endif
> >
> > +/* RCU-protected remote addr helpers */
> > +int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid);
> > +int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port);
> > +int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
> > + unsigned int *port);
> > +int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest);
> > +bool vsock_remote_addr_bound(struct vsock_sock *vsk);
> > +bool vsock_remote_addr_equals(struct vsock_sock *vsk, struct sockaddr_vm *other);
> > +int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port);
> > +
> > #endif /* __AF_VSOCK_H__ */
> > diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
> > index 46b3f35e3adc..93b4abbf20b4 100644
> > --- a/net/vmw_vsock/af_vsock.c
> > +++ b/net/vmw_vsock/af_vsock.c
> > @@ -145,6 +145,139 @@ static const struct vsock_transport *transport_local;
> > static DEFINE_MUTEX(vsock_register_mutex);
> >
> > /**** UTILS ****/
> > +bool vsock_remote_addr_bound(struct vsock_sock *vsk)
> > +{
> > + struct sockaddr_vm_rcu *remote_addr;
> > + bool ret;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + rcu_read_unlock();
> > + return false;
> > + }
> > +
> > + ret = vsock_addr_bound(&remote_addr->addr);
> > + rcu_read_unlock();
> > +
> > + return ret;
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_bound);
> > +
> > +int vsock_remote_addr_copy(struct vsock_sock *vsk, struct sockaddr_vm *dest)
> > +{
> > + struct sockaddr_vm_rcu *remote_addr;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + rcu_read_unlock();
> > + return -EINVAL;
> > + }
> > + memcpy(dest, &remote_addr->addr, sizeof(*dest));
> > + rcu_read_unlock();
> > +
> > + return 0;
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_copy);
> > +
> > +int vsock_remote_addr_cid(struct vsock_sock *vsk, unsigned int *cid)
> > +{
> > + return vsock_remote_addr_cid_port(vsk, cid, NULL);
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_cid);
> > +
> > +int vsock_remote_addr_port(struct vsock_sock *vsk, unsigned int *port)
> > +{
> > + return vsock_remote_addr_cid_port(vsk, NULL, port);
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_port);
> > +
> > +int vsock_remote_addr_cid_port(struct vsock_sock *vsk, unsigned int *cid,
> > + unsigned int *port)
> > +{
> > + struct sockaddr_vm_rcu *remote_addr;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + rcu_read_unlock();
> > + return -EINVAL;
> > + }
> > +
> > + if (cid)
> > + *cid = remote_addr->addr.svm_cid;
> > + if (port)
> > + *port = remote_addr->addr.svm_port;
> > +
> > + rcu_read_unlock();
> > + return 0;
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_cid_port);
> > +
> > +/* The socket lock must be held by the caller */
> > +int vsock_remote_addr_update_cid_port(struct vsock_sock *vsk, u32 cid, u32 port)
> > +{
> > + struct sockaddr_vm_rcu *old, *new;
> > +
> > + new = kmalloc(sizeof(*new), GFP_KERNEL);
> > + if (!new)
> > + return -ENOMEM;
> > +
> > + rcu_read_lock();
> > + old = rcu_dereference(vsk->remote_addr);
> > + if (!old) {
> > + rcu_read_unlock();
> > + kfree(new);
> > + return -EINVAL;
> > + }
> > + memcpy(&new->addr, &old->addr, sizeof(new->addr));
> > + rcu_read_unlock();
> > +
> > + new->addr.svm_cid = cid;
> > + new->addr.svm_port = port;
> > +
> > + old = rcu_replace_pointer(vsk->remote_addr, new, lockdep_sock_is_held(sk_vsock(vsk)));
> > + kfree_rcu(old, rcu);
> > +
> > + return 0;
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_update_cid_port);
> > +
> > +/* The socket lock must be held by the caller */
> > +int vsock_remote_addr_update(struct vsock_sock *vsk, struct sockaddr_vm *src)
> > +{
> > + struct sockaddr_vm_rcu *old, *new;
> > +
> > + new = kmalloc(sizeof(*new), GFP_KERNEL);
> > + if (!new)
> > + return -ENOMEM;
> > +
> > + memcpy(&new->addr, src, sizeof(new->addr));
> > + old = rcu_replace_pointer(vsk->remote_addr, new, lockdep_sock_is_held(sk_vsock(vsk)));
> > + kfree_rcu(old, rcu);
> > +
> > + return 0;
> > +}
> > +
> > +bool vsock_remote_addr_equals(struct vsock_sock *vsk,
> > + struct sockaddr_vm *other)
> > +{
> > + struct sockaddr_vm_rcu *remote_addr;
> > + bool equals;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + rcu_read_unlock();
> > + return false;
> > + }
> > +
> > + equals = vsock_addr_equals_addr(&remote_addr->addr, other);
> > + rcu_read_unlock();
> > +
> > + return equals;
> > +}
> > +EXPORT_SYMBOL_GPL(vsock_remote_addr_equals);
> >
> > /* Each bound VSocket is stored in the bind hash table and each connected
> > * VSocket is stored in the connected hash table.
> > @@ -254,10 +387,16 @@ static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
> >
> > list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
> > connected_table) {
> > - if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
> > + struct sockaddr_vm_rcu *remote_addr;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (vsock_addr_equals_addr(src, &remote_addr->addr) &&
> > dst->svm_port == vsk->local_addr.svm_port) {
> > + rcu_read_unlock();
> > return sk_vsock(vsk);
> > }
> > + rcu_read_unlock();
> > }
> >
> > return NULL;
> > @@ -270,14 +409,25 @@ static void vsock_insert_unbound(struct vsock_sock *vsk)
> > spin_unlock_bh(&vsock_table_lock);
> > }
> >
> > -void vsock_insert_connected(struct vsock_sock *vsk)
> > +int vsock_insert_connected(struct vsock_sock *vsk)
> > {
> > - struct list_head *list = vsock_connected_sockets(
> > - &vsk->remote_addr, &vsk->local_addr);
> > + struct list_head *list;
> > + struct sockaddr_vm_rcu *remote_addr;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + rcu_read_unlock();
> > + return -EINVAL;
> > + }
> > + list = vsock_connected_sockets(&remote_addr->addr, &vsk->local_addr);
> > + rcu_read_unlock();
> >
> > spin_lock_bh(&vsock_table_lock);
> > __vsock_insert_connected(list, vsk);
> > spin_unlock_bh(&vsock_table_lock);
> > +
> > + return 0;
> > }
> > EXPORT_SYMBOL_GPL(vsock_insert_connected);
> >
> > @@ -438,10 +588,17 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
> > {
> > const struct vsock_transport *new_transport;
> > struct sock *sk = sk_vsock(vsk);
> > - unsigned int remote_cid = vsk->remote_addr.svm_cid;
> > + struct sockaddr_vm remote_addr;
> > + unsigned int remote_cid;
> > __u8 remote_flags;
> > int ret;
> >
> > + ret = vsock_remote_addr_copy(vsk, &remote_addr);
> > + if (ret < 0)
> > + return ret;
> > +
> > + remote_cid = remote_addr.svm_cid;
> > +
> > /* If the packet is coming with the source and destination CIDs higher
> > * than VMADDR_CID_HOST, then a vsock channel where all the packets are
> > * forwarded to the host should be established. Then the host will
> > @@ -451,10 +608,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
> > * the connect path the flag can be set by the user space application.
> > */
> > if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST &&
> > - vsk->remote_addr.svm_cid > VMADDR_CID_HOST)
> > - vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
> > + remote_addr.svm_cid > VMADDR_CID_HOST) {
> > + remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST;
> > +
> > + ret = vsock_remote_addr_update(vsk, &remote_addr);
> > + if (ret < 0)
> > + return ret;
> > + }
> >
> > - remote_flags = vsk->remote_addr.svm_flags;
> > + remote_flags = remote_addr.svm_flags;
> >
> > switch (sk->sk_type) {
> > case SOCK_DGRAM:
> > @@ -742,6 +904,7 @@ static struct sock *__vsock_create(struct net *net,
> > unsigned short type,
> > int kern)
> > {
> > + struct sockaddr_vm_rcu *remote_addr;
> > struct sock *sk;
> > struct vsock_sock *psk;
> > struct vsock_sock *vsk;
> > @@ -761,7 +924,14 @@ static struct sock *__vsock_create(struct net *net,
> >
> > vsk = vsock_sk(sk);
> > vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > +
> > + remote_addr = kmalloc(sizeof(*remote_addr), GFP_KERNEL);
> > + if (!remote_addr) {
> > + sk_free(sk);
> > + return NULL;
> > + }
> > + vsock_addr_init(&remote_addr->addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > + rcu_assign_pointer(vsk->remote_addr, remote_addr);
> >
> > sk->sk_destruct = vsock_sk_destruct;
> > sk->sk_backlog_rcv = vsock_queue_rcv_skb;
> > @@ -845,6 +1015,7 @@ static void __vsock_release(struct sock *sk, int level)
> > static void vsock_sk_destruct(struct sock *sk)
> > {
> > struct vsock_sock *vsk = vsock_sk(sk);
> > + struct sockaddr_vm_rcu *remote_addr;
> >
> > vsock_deassign_transport(vsk);
> >
> > @@ -852,8 +1023,8 @@ static void vsock_sk_destruct(struct sock *sk)
> > * possibly register the address family with the kernel.
> > */
> > vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > -
> > + remote_addr = rcu_replace_pointer(vsk->remote_addr, NULL, 1);
> > + kfree_rcu(remote_addr, rcu);
> > put_cred(vsk->owner);
> > }
> >
> > @@ -943,6 +1114,7 @@ static int vsock_getname(struct socket *sock,
> > struct sock *sk;
> > struct vsock_sock *vsk;
> > struct sockaddr_vm *vm_addr;
> > + struct sockaddr_vm_rcu *rcu_ptr;
> >
> > sk = sock->sk;
> > vsk = vsock_sk(sk);
> > @@ -951,11 +1123,17 @@ static int vsock_getname(struct socket *sock,
> > lock_sock(sk);
> >
> > if (peer) {
> > + rcu_read_lock();
> > if (sock->state != SS_CONNECTED) {
> > err = -ENOTCONN;
> > goto out;
> > }
> > - vm_addr = &vsk->remote_addr;
> > + rcu_ptr = rcu_dereference(vsk->remote_addr);
> > + if (!rcu_ptr) {
> > + err = -EINVAL;
> > + goto out;
> > + }
> > + vm_addr = &rcu_ptr->addr;
> > } else {
> > vm_addr = &vsk->local_addr;
> > }
> > @@ -975,6 +1153,8 @@ static int vsock_getname(struct socket *sock,
> > err = sizeof(*vm_addr);
> >
> > out:
> > + if (peer)
> > + rcu_read_unlock();
> > release_sock(sk);
> > return err;
> > }
> > @@ -1161,7 +1341,7 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > int err;
> > struct sock *sk;
> > struct vsock_sock *vsk;
> > - struct sockaddr_vm *remote_addr;
> > + struct sockaddr_vm stack_addr, *remote_addr;
> > const struct vsock_transport *transport;
> >
> > if (msg->msg_flags & MSG_OOB)
> > @@ -1172,15 +1352,26 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > sk = sock->sk;
> > vsk = vsock_sk(sk);
> >
> > - lock_sock(sk);
> > + /* If auto-binding is required, acquire the slock to avoid potential
> > + * race conditions. Otherwise, do not acquire the lock.
> > + *
> > + * We know that the first check of local_addr is racy (indicated by
> > + * data_race()). By acquiring the lock and then subsequently checking
> > + * again if local_addr is bound (inside vsock_auto_bind()), we can
> > + * ensure there are no real data races.
> > + *
> > + * This technique is borrowed from inet_send_prepare().
> > + */
> > + if (data_race(!vsock_addr_bound(&vsk->local_addr))) {
> > + lock_sock(sk);
> > + err = vsock_auto_bind(vsk);
> > + release_sock(sk);
> > + if (err)
> > + return err;
> > + }
> >
> > transport = vsk->transport;
> >
> > - err = vsock_auto_bind(vsk);
> > - if (err)
> > - goto out;
> > -
> > -
> > /* If the provided message contains an address, use that. Otherwise
> > * fall back on the socket's remote handle (if it has been connected).
> > */
> > @@ -1199,18 +1390,26 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > goto out;
> > }
> > } else if (sock->state == SS_CONNECTED) {
> > - remote_addr = &vsk->remote_addr;
> > + err = vsock_remote_addr_copy(vsk, &stack_addr);
> > + if (err < 0)
> > + goto out;
> >
> > - if (remote_addr->svm_cid == VMADDR_CID_ANY)
> > - remote_addr->svm_cid = transport->get_local_cid();
> > + if (stack_addr.svm_cid == VMADDR_CID_ANY) {
> > + stack_addr.svm_cid = transport->get_local_cid();
> > + lock_sock(sk_vsock(vsk));
> > + vsock_remote_addr_update(vsk, &stack_addr);
> > + release_sock(sk_vsock(vsk));
> > + }
> >
> > /* XXX Should connect() or this function ensure remote_addr is
> > * bound?
> > */
> > - if (!vsock_addr_bound(&vsk->remote_addr)) {
> > + if (!vsock_addr_bound(&stack_addr)) {
> > err = -EINVAL;
> > goto out;
> > }
> > +
> > + remote_addr = &stack_addr;
> > } else {
> > err = -EINVAL;
> > goto out;
> > @@ -1225,7 +1424,6 @@ static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg,
> > err = transport->dgram_enqueue(vsk, remote_addr, msg, len);
> >
> > out:
> > - release_sock(sk);
> > return err;
> > }
> >
> > @@ -1243,8 +1441,7 @@ static int vsock_dgram_connect(struct socket *sock,
> > err = vsock_addr_cast(addr, addr_len, &remote_addr);
> > if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) {
> > lock_sock(sk);
> > - vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY,
> > - VMADDR_PORT_ANY);
> > + vsock_remote_addr_update_cid_port(vsk, VMADDR_CID_ANY, VMADDR_PORT_ANY);
> > sock->state = SS_UNCONNECTED;
> > release_sock(sk);
> > return 0;
> > @@ -1263,7 +1460,10 @@ static int vsock_dgram_connect(struct socket *sock,
> > goto out;
> > }
> >
> > - memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr));
> > + err = vsock_remote_addr_update(vsk, remote_addr);
> > + if (err < 0)
> > + goto out;
> > +
> > sock->state = SS_CONNECTED;
> >
> > /* sock map disallows redirection of non-TCP sockets with sk_state !=
> > @@ -1399,8 +1599,9 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
> > }
> >
> > /* Set the remote address that we are connecting to. */
> > - memcpy(&vsk->remote_addr, remote_addr,
> > - sizeof(vsk->remote_addr));
> > + err = vsock_remote_addr_update(vsk, remote_addr);
> > + if (err)
> > + goto out;
> >
> > err = vsock_assign_transport(vsk, NULL);
> > if (err)
> > @@ -1831,7 +2032,7 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
> > goto out;
> > }
> >
> > - if (!vsock_addr_bound(&vsk->remote_addr)) {
> > + if (!vsock_remote_addr_bound(vsk)) {
> > err = -EDESTADDRREQ;
> > goto out;
> > }
> > diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c
> > index a2823b1c5e28..f843bae86b32 100644
> > --- a/net/vmw_vsock/diag.c
> > +++ b/net/vmw_vsock/diag.c
> > @@ -15,8 +15,14 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
> > u32 portid, u32 seq, u32 flags)
> > {
> > struct vsock_sock *vsk = vsock_sk(sk);
> > + struct sockaddr_vm remote_addr;
> > struct vsock_diag_msg *rep;
> > struct nlmsghdr *nlh;
> > + int err;
> > +
> > + err = vsock_remote_addr_copy(vsk, &remote_addr);
> > + if (err < 0)
> > + return err;
> >
> > nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
> > flags);
> > @@ -36,8 +42,8 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
> > rep->vdiag_shutdown = sk->sk_shutdown;
> > rep->vdiag_src_cid = vsk->local_addr.svm_cid;
> > rep->vdiag_src_port = vsk->local_addr.svm_port;
> > - rep->vdiag_dst_cid = vsk->remote_addr.svm_cid;
> > - rep->vdiag_dst_port = vsk->remote_addr.svm_port;
> > + rep->vdiag_dst_cid = remote_addr.svm_cid;
> > + rep->vdiag_dst_port = remote_addr.svm_port;
> > rep->vdiag_ino = sock_i_ino(sk);
> >
> > sock_diag_save_cookie(sk, rep->vdiag_cookie);
> > diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
> > index 7cb1a9d2cdb4..462b2ec3e6e9 100644
> > --- a/net/vmw_vsock/hyperv_transport.c
> > +++ b/net/vmw_vsock/hyperv_transport.c
> > @@ -336,9 +336,11 @@ static void hvs_open_connection(struct vmbus_channel *chan)
> > hvs_addr_init(&vnew->local_addr, if_type);
> >
> > /* Remote peer is always the host */
> > - vsock_addr_init(&vnew->remote_addr,
> > - VMADDR_CID_HOST, VMADDR_PORT_ANY);
> > - vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
> > + ret = vsock_remote_addr_update_cid_port(vnew, VMADDR_CID_HOST,
> > + get_port_by_srv_id(if_instance));
> > + if (ret < 0)
> > + goto out;
> > +
> > ret = vsock_assign_transport(vnew, vsock_sk(sk));
> > /* Transport assigned (looking at remote_addr) must be the
> > * same where we received the request.
> > @@ -459,13 +461,18 @@ static int hvs_connect(struct vsock_sock *vsk)
> > {
> > union hvs_service_id vm, host;
> > struct hvsock *h = vsk->trans;
> > + int err;
> >
> > vm.srv_id = srv_id_template;
> > vm.svm_port = vsk->local_addr.svm_port;
> > h->vm_srv_id = vm.srv_id;
> >
> > host.srv_id = srv_id_template;
> > - host.svm_port = vsk->remote_addr.svm_port;
> > +
> > + err = vsock_remote_addr_port(vsk, &host.svm_port);
> > + if (err < 0)
> > + return err;
> > +
> > h->host_srv_id = host.srv_id;
> >
> > return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
> > diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> > index 925acface893..1b87704e516a 100644
> > --- a/net/vmw_vsock/virtio_transport_common.c
> > +++ b/net/vmw_vsock/virtio_transport_common.c
> > @@ -258,8 +258,9 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
> > src_cid = t_ops->transport.get_local_cid();
> > src_port = vsk->local_addr.svm_port;
> > if (!info->remote_cid) {
> > - dst_cid = vsk->remote_addr.svm_cid;
> > - dst_port = vsk->remote_addr.svm_port;
> > + ret = vsock_remote_addr_cid_port(vsk, &dst_cid, &dst_port);
> > + if (ret < 0)
> > + return ret;
> > } else {
> > dst_cid = info->remote_cid;
> > dst_port = info->remote_port;
> > @@ -1169,7 +1170,9 @@ virtio_transport_recv_connecting(struct sock *sk,
> > case VIRTIO_VSOCK_OP_RESPONSE:
> > sk->sk_state = TCP_ESTABLISHED;
> > sk->sk_socket->state = SS_CONNECTED;
> > - vsock_insert_connected(vsk);
> > + err = vsock_insert_connected(vsk);
> > + if (err)
> > + goto destroy;
> > sk->sk_state_change(sk);
> > break;
> > case VIRTIO_VSOCK_OP_INVALID:
> > @@ -1403,9 +1406,8 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
> > vchild = vsock_sk(child);
> > vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
> > le32_to_cpu(hdr->dst_port));
> > - vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
> > - le32_to_cpu(hdr->src_port));
> > -
> > + vsock_remote_addr_update_cid_port(vchild, le64_to_cpu(hdr->src_cid),
> > + le32_to_cpu(hdr->src_port));
> > ret = vsock_assign_transport(vchild, vsk);
> > /* Transport assigned (looking at remote_addr) must be the same
> > * where we received the request.
> > @@ -1420,7 +1422,13 @@ virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
> > if (virtio_transport_space_update(child, skb))
> > child->sk_write_space(child);
> >
> > - vsock_insert_connected(vchild);
> > + ret = vsock_insert_connected(vchild);
> > + if (ret) {
> > + release_sock(child);
> > + virtio_transport_reset_no_sock(t, skb);
> > + sock_put(child);
> > + return ret;
> > + }
> > vsock_enqueue_accept(sk, child);
> > virtio_transport_send_response(vchild, skb);
> >
> > diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
> > index b370070194fa..c0c445e7d925 100644
> > --- a/net/vmw_vsock/vmci_transport.c
> > +++ b/net/vmw_vsock/vmci_transport.c
> > @@ -283,18 +283,25 @@ vmci_transport_send_control_pkt(struct sock *sk,
> > u16 proto,
> > struct vmci_handle handle)
> > {
> > + struct sockaddr_vm addr_stack;
> > + struct sockaddr_vm *remote_addr = &addr_stack;
> > struct vsock_sock *vsk;
> > + int err;
> >
> > vsk = vsock_sk(sk);
> >
> > if (!vsock_addr_bound(&vsk->local_addr))
> > return -EINVAL;
> >
> > - if (!vsock_addr_bound(&vsk->remote_addr))
> > + if (!vsock_remote_addr_bound(vsk))
> > return -EINVAL;
> >
> > + err = vsock_remote_addr_copy(vsk, &addr_stack);
> > + if (err < 0)
> > + return err;
> > +
> > return vmci_transport_alloc_send_control_pkt(&vsk->local_addr,
> > - &vsk->remote_addr,
> > + remote_addr,
> > type, size, mode,
> > wait, proto, handle);
> > }
> > @@ -317,6 +324,7 @@ static int vmci_transport_send_reset(struct sock *sk,
> > struct sockaddr_vm *dst_ptr;
> > struct sockaddr_vm dst;
> > struct vsock_sock *vsk;
> > + int err;
> >
> > if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
> > return 0;
> > @@ -326,13 +334,16 @@ static int vmci_transport_send_reset(struct sock *sk,
> > if (!vsock_addr_bound(&vsk->local_addr))
> > return -EINVAL;
> >
> > - if (vsock_addr_bound(&vsk->remote_addr)) {
> > - dst_ptr = &vsk->remote_addr;
> > + if (vsock_remote_addr_bound(vsk)) {
> > + err = vsock_remote_addr_copy(vsk, &dst);
> > + if (err < 0)
> > + return err;
> > } else {
> > vsock_addr_init(&dst, pkt->dg.src.context,
> > pkt->src_port);
> > - dst_ptr = &dst;
> > }
> > + dst_ptr = &dst;
> > +
> > return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr,
> > VMCI_TRANSPORT_PACKET_TYPE_RST,
> > 0, 0, NULL, VSOCK_PROTO_INVALID,
> > @@ -490,7 +501,7 @@ static struct sock *vmci_transport_get_pending(
> >
> > list_for_each_entry(vpending, &vlistener->pending_links,
> > pending_links) {
> > - if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
> > + if (vsock_remote_addr_equals(vpending, &src) &&
> > pkt->dst_port == vpending->local_addr.svm_port) {
> > pending = sk_vsock(vpending);
> > sock_hold(pending);
> > @@ -1015,8 +1026,8 @@ static int vmci_transport_recv_listen(struct sock *sk,
> >
> > vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
> > pkt->dst_port);
> > - vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
> > - pkt->src_port);
> > + vsock_remote_addr_update_cid_port(vpending, pkt->dg.src.context,
> > + pkt->src_port);
> >
> > err = vsock_assign_transport(vpending, vsock_sk(sk));
> > /* Transport assigned (looking at remote_addr) must be the same
> > @@ -1133,6 +1144,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
> > {
> > struct vsock_sock *vpending;
> > struct vmci_handle handle;
> > + unsigned int vpending_remote_cid;
> > struct vmci_qp *qpair;
> > bool is_local;
> > u32 flags;
> > @@ -1189,8 +1201,13 @@ vmci_transport_recv_connecting_server(struct sock *listener,
> > /* vpending->local_addr always has a context id so we do not need to
> > * worry about VMADDR_CID_ANY in this case.
> > */
> > - is_local =
> > - vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
> > + err = vsock_remote_addr_cid(vpending, &vpending_remote_cid);
> > + if (err < 0) {
> > + skerr = EPROTO;
> > + goto destroy;
> > + }
> > +
> > + is_local = vpending_remote_cid == vpending->local_addr.svm_cid;
> > flags = VMCI_QPFLAG_ATTACH_ONLY;
> > flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;
> >
> > @@ -1203,7 +1220,7 @@ vmci_transport_recv_connecting_server(struct sock *listener,
> > flags,
> > vmci_transport_is_trusted(
> > vpending,
> > - vpending->remote_addr.svm_cid));
> > + vpending_remote_cid));
> > if (err < 0) {
> > vmci_transport_send_reset(pending, pkt);
> > skerr = -err;
> > @@ -1306,9 +1323,20 @@ vmci_transport_recv_connecting_client(struct sock *sk,
> > break;
> > case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
> > case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
> > + struct sockaddr_vm_rcu *remote_addr;
> > +
> > + rcu_read_lock();
> > + remote_addr = rcu_dereference(vsk->remote_addr);
> > + if (!remote_addr) {
> > + skerr = EPROTO;
> > + err = -EINVAL;
> > + rcu_read_unlock();
> > + goto destroy;
> > + }
> > +
> > if (pkt->u.size == 0
> > - || pkt->dg.src.context != vsk->remote_addr.svm_cid
> > - || pkt->src_port != vsk->remote_addr.svm_port
> > + || pkt->dg.src.context != remote_addr->addr.svm_cid
> > + || pkt->src_port != remote_addr->addr.svm_port
> > || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
> > || vmci_trans(vsk)->qpair
> > || vmci_trans(vsk)->produce_size != 0
> > @@ -1316,9 +1344,10 @@ vmci_transport_recv_connecting_client(struct sock *sk,
> > || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
> > skerr = EPROTO;
> > err = -EINVAL;
> > -
> > + rcu_read_unlock();
> > goto destroy;
> > }
> > + rcu_read_unlock();
> >
> > err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
> > if (err) {
> > @@ -1379,6 +1408,7 @@ static int vmci_transport_recv_connecting_client_negotiate(
> > int err;
> > struct vsock_sock *vsk;
> > struct vmci_handle handle;
> > + unsigned int remote_cid;
> > struct vmci_qp *qpair;
> > u32 detach_sub_id;
> > bool is_local;
> > @@ -1449,19 +1479,23 @@ static int vmci_transport_recv_connecting_client_negotiate(
> >
> > /* Make VMCI select the handle for us. */
> > handle = VMCI_INVALID_HANDLE;
> > - is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
> > +
> > + err = vsock_remote_addr_cid(vsk, &remote_cid);
> > + if (err < 0)
> > + goto destroy;
> > +
> > + is_local = remote_cid == vsk->local_addr.svm_cid;
> > flags = is_local ? VMCI_QPFLAG_LOCAL : 0;
> >
> > err = vmci_transport_queue_pair_alloc(&qpair,
> > &handle,
> > pkt->u.size,
> > pkt->u.size,
> > - vsk->remote_addr.svm_cid,
> > + remote_cid,
> > flags,
> > vmci_transport_is_trusted(
> > vsk,
> > - vsk->
> > - remote_addr.svm_cid));
> > + remote_cid));
> > if (err < 0)
> > goto destroy;
> >
> >
> > --
> > 2.30.2
> >
>