Re: [RFC PATCH v3 04/11] virtio/vsock: add transport zerocopy callback

From: Arseniy Krasnov
Date: Thu Nov 10 2022 - 06:16:15 EST


On 06.11.2022 22:41, Arseniy Krasnov wrote:
> This adds a transport callback which processes the rx queue of a
> socket and, instead of copying data to a user-provided buffer,
> inserts the data pages of each packet into the user's VM area.
>
> Signed-off-by: Arseniy Krasnov <AVKrasnov@xxxxxxxxxxxxxx>
> ---
> include/linux/virtio_vsock.h | 7 +
> include/uapi/linux/virtio_vsock.h | 14 ++
> net/vmw_vsock/virtio_transport_common.c | 244 +++++++++++++++++++++++-
> 3 files changed, 261 insertions(+), 4 deletions(-)
>
> diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> index c1be40f89a89..d10fdfd8d144 100644
> --- a/include/linux/virtio_vsock.h
> +++ b/include/linux/virtio_vsock.h
> @@ -37,6 +37,7 @@ struct virtio_vsock_sock {
> u32 buf_alloc;
> struct list_head rx_queue;
> u32 msg_count;
> + struct page *usr_poll_page;
> };
>
> struct virtio_vsock_pkt {
> @@ -51,6 +52,7 @@ struct virtio_vsock_pkt {
> bool reply;
> bool tap_delivered;
> bool slab_buf;
> + bool split;
> };
>
> struct virtio_vsock_pkt_info {
> @@ -131,6 +133,11 @@ int virtio_transport_dgram_bind(struct vsock_sock *vsk,
> struct sockaddr_vm *addr);
> bool virtio_transport_dgram_allow(u32 cid, u32 port);
>
> +int virtio_transport_zerocopy_init(struct vsock_sock *vsk,
> + struct vm_area_struct *vma);
> +int virtio_transport_zerocopy_dequeue(struct vsock_sock *vsk,
> + struct page **pages,
> + unsigned long *pages_num);
> int virtio_transport_connect(struct vsock_sock *vsk);
>
> int virtio_transport_shutdown(struct vsock_sock *vsk, int mode);
> diff --git a/include/uapi/linux/virtio_vsock.h b/include/uapi/linux/virtio_vsock.h
> index 64738838bee5..2a0e4f309918 100644
> --- a/include/uapi/linux/virtio_vsock.h
> +++ b/include/uapi/linux/virtio_vsock.h
> @@ -66,6 +66,20 @@ struct virtio_vsock_hdr {
> __le32 fwd_cnt;
> } __attribute__((packed));
>
> +struct virtio_vsock_usr_hdr {
> + u32 flags;
> + u32 len;
> +} __attribute__((packed));
> +
> +#define VIRTIO_VSOCK_USR_POLL_NO_DATA 0
> +#define VIRTIO_VSOCK_USR_POLL_HAS_DATA 1
> +#define VIRTIO_VSOCK_USR_POLL_SHUTDOWN (~0U)
> +
> +struct virtio_vsock_usr_hdr_pref {
> + u32 poll_value;
> + u32 hdr_num;
> +} __attribute__((packed));
> +
> enum virtio_vsock_type {
> VIRTIO_VSOCK_TYPE_STREAM = 1,
> VIRTIO_VSOCK_TYPE_SEQPACKET = 2,
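
For context, here is a sketch of how userspace is expected to read the
polling page described by these structs. The mapping itself (mmap() on the
vsock socket) is set up elsewhere in the series, so take this as an
illustration against this patch only; the structs are repeated locally
instead of including the patched uapi header:

/* Hypothetical userspace sketch: inspect the polling page that
 * 'virtio_transport_zerocopy_init()' inserts at the start of the
 * receive mapping. Layout mirrors the uapi structs above.
 */
#include <stdint.h>

struct virtio_vsock_usr_hdr {
	uint32_t flags;
	uint32_t len;
} __attribute__((packed));

struct virtio_vsock_usr_hdr_pref {
	uint32_t poll_value;	/* NO_DATA, HAS_DATA or SHUTDOWN */
	uint32_t hdr_num;	/* number of valid headers that follow */
} __attribute__((packed));

#define USR_POLL_NO_DATA	0u
#define USR_POLL_HAS_DATA	1u
#define USR_POLL_SHUTDOWN	(~0u)

/* 'area' is the start of the mapping; returns 1 if data can be
 * dequeued, -1 if the peer has shut down, 0 if nothing to read yet.
 */
static int rx_poll_state(const void *area)
{
	const volatile struct virtio_vsock_usr_hdr_pref *pref = area;

	if (pref->poll_value == USR_POLL_SHUTDOWN)
		return -1;

	return pref->poll_value == USR_POLL_HAS_DATA ? 1 : 0;
}
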
> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> index 444764869670..fa4a2688a5d5 100644
> --- a/net/vmw_vsock/virtio_transport_common.c
> +++ b/net/vmw_vsock/virtio_transport_common.c
> @@ -12,6 +12,7 @@
> #include <linux/ctype.h>
> #include <linux/list.h>
> #include <linux/virtio_vsock.h>
> +#include <linux/mm.h>
> #include <uapi/linux/vsockmon.h>
>
> #include <net/sock.h>
> @@ -241,6 +242,14 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
> static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
> struct virtio_vsock_pkt *pkt)
> {
> + if (vvs->usr_poll_page) {
> + struct virtio_vsock_usr_hdr_pref *hdr;
> +
> + hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> + hdr->poll_value = VIRTIO_VSOCK_USR_POLL_HAS_DATA;
> + }
> +
> if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
> return false;
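^^^ Note on ordering: 'poll_value' is set to HAS_DATA above, before this
'buf_alloc' check, so if the packet is rejected here userspace may still
see HAS_DATA. It seems this update should happen only once the packet is
actually accepted.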
>
> @@ -253,6 +262,14 @@ static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
> {
> vvs->rx_bytes -= pkt->len;
> vvs->fwd_cnt += pkt->len;
> +
> + if (!vvs->rx_bytes && vvs->usr_poll_page) {
> + struct virtio_vsock_usr_hdr_pref *hdr;
> +
> + hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> + hdr->poll_value = VIRTIO_VSOCK_USR_POLL_NO_DATA;
> + }
> }
>
> void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
> @@ -347,6 +364,191 @@ virtio_transport_stream_do_peek(struct vsock_sock *vsk,
> return err;
> }
>
> +int virtio_transport_zerocopy_init(struct vsock_sock *vsk,
> + struct vm_area_struct *vma)
> +{
> + struct virtio_vsock_sock *vvs;
> + int err = 0;
> +
> + if (vma->vm_end - vma->vm_start < 2 * PAGE_SIZE)
> + return -EINVAL;
> +
> + vvs = vsk->trans;
> +
> + spin_lock_bh(&vvs->rx_lock);
> +
> + if (!vvs->usr_poll_page) {
> + /* GFP_ATOMIC because of spinlock. */
> + vvs->usr_poll_page = alloc_page(GFP_KERNEL | GFP_ATOMIC);
^^^ Oops: only GFP_ATOMIC is needed here. This allocation happens under
'rx_lock', so the GFP_KERNEL flag (which may sleep) must be dropped.
> +
> + if (!vvs->usr_poll_page) {
> + err = -ENOMEM;
> + } else {
> + struct virtio_vsock_usr_hdr_pref *usr_hdr_pref;
> + unsigned long one_page = 1;
> +
> + usr_hdr_pref = page_to_virt(vvs->usr_poll_page);
> +
> + if (vsk->peer_shutdown & SHUTDOWN_MASK) {
> + usr_hdr_pref->poll_value = VIRTIO_VSOCK_USR_POLL_SHUTDOWN;
> + } else {
> + usr_hdr_pref->poll_value = vvs->rx_bytes ?
> + VIRTIO_VSOCK_USR_POLL_HAS_DATA :
> + VIRTIO_VSOCK_USR_POLL_NO_DATA;
> + }
> +
> + usr_hdr_pref->hdr_num = 0;
> +
> + err = vm_insert_pages(vma, vma->vm_start,
> + &vvs->usr_poll_page,
> + &one_page);
> +
> + if (one_page)
> + err = -EINVAL;
> + }
> + } else {
> + err = -EINVAL;
> + }
> +
> + spin_unlock_bh(&vvs->rx_lock);
> +
> + return err;
> +}
> +EXPORT_SYMBOL_GPL(virtio_transport_zerocopy_init);
> +
> +int virtio_transport_zerocopy_dequeue(struct vsock_sock *vsk,
> + struct page **pages,
> + unsigned long *pages_num)
> +{
> + struct virtio_vsock_usr_hdr_pref *usr_hdr_pref;
> + struct virtio_vsock_usr_hdr *usr_hdr_buffer;
> + struct virtio_vsock_sock *vvs;
> + unsigned long max_usr_hdrs;
> + struct page *usr_hdr_page;
> + int pages_cnt;
> +
> + if (*pages_num < 2)
> + return -EINVAL;
> +
> + vvs = vsk->trans;
> +
> + max_usr_hdrs = (PAGE_SIZE - sizeof(*usr_hdr_pref)) / sizeof(*usr_hdr_buffer);
> + *pages_num = min(max_usr_hdrs, *pages_num);
> + pages_cnt = 0;
> +
> + spin_lock_bh(&vvs->rx_lock);
> +
> + if (!vvs->usr_poll_page) {
> + spin_unlock_bh(&vvs->rx_lock);
> + return -EINVAL;
> + }
> +
> + usr_hdr_page = vvs->usr_poll_page;
> + usr_hdr_pref = page_to_virt(usr_hdr_page);
> + usr_hdr_buffer = (struct virtio_vsock_usr_hdr *)(usr_hdr_pref + 1);
> + usr_hdr_pref->hdr_num = 0;
> +
> + /* If the ref counter is 1, the page was only allocated
> + * and is not mapped yet, so insert it into the output
> + * array; it will be mapped by the caller.
> + */
> + if (page_ref_count(usr_hdr_page) == 1) {
> + pages[pages_cnt++] = usr_hdr_page;
> + /* Take one more reference, as the AF_VSOCK layer
> + * calls 'put_page()' on each returned page.
> + */
> + get_page(usr_hdr_page);
> + } else {
> + pages[pages_cnt++] = NULL;
> + }
> +
> + /* Polling page is already mapped. */
> + while (!list_empty(&vvs->rx_queue) &&
> + pages_cnt < *pages_num) {
> + struct virtio_vsock_pkt *pkt;
> + ssize_t rest_data_bytes;
> + size_t moved_data_bytes;
> + unsigned long pg_offs;
> +
> + pkt = list_first_entry(&vvs->rx_queue,
> + struct virtio_vsock_pkt, list);
> +
> + rest_data_bytes = le32_to_cpu(pkt->hdr.len) - pkt->off;
> +
> + /* For packets bigger than one page, split the
> + * high-order buffer into 0-order pages, otherwise
> + * 'vm_insert_pages()' will fail for every page
> + * except the first.
> + */
> + if (rest_data_bytes > PAGE_SIZE) {
> + /* High order buffer not split yet. */
> + if (!pkt->split) {
> + split_page(virt_to_page(pkt->buf),
> + get_order(le32_to_cpu(pkt->hdr.len)));
> + pkt->split = true;
> + }
> + }
> +
> + pg_offs = pkt->off;
> + moved_data_bytes = 0;
> +
> + while (rest_data_bytes &&
> + pages_cnt < *pages_num) {
> + struct page *buf_page;
> +
> + buf_page = virt_to_page(pkt->buf + pg_offs);
> +
> + pages[pages_cnt++] = buf_page;
> + /* Take a reference to prevent this page from being
> + * returned to the page allocator when the packet is
> + * freed. The ref count will then be 2.
> + */
> + get_page(buf_page);
> + pg_offs += PAGE_SIZE;
> +
> + if (rest_data_bytes >= PAGE_SIZE) {
> + moved_data_bytes += PAGE_SIZE;
> + rest_data_bytes -= PAGE_SIZE;
> + } else {
> + moved_data_bytes += rest_data_bytes;
> + rest_data_bytes = 0;
> + }
> + }
> +
> + if (!rest_data_bytes)
> + usr_hdr_buffer->flags = le32_to_cpu(pkt->hdr.flags);
> + else
> + usr_hdr_buffer->flags = 0;
> +
> + usr_hdr_buffer->len = moved_data_bytes;
> +
> + usr_hdr_buffer++;
> + usr_hdr_pref->hdr_num++;
> +
> + pkt->off = pg_offs;
> +
> + if (rest_data_bytes == 0) {
> + list_del(&pkt->list);
> + virtio_transport_dec_rx_pkt(vvs, pkt);
> + virtio_transport_free_pkt(pkt);
> + }
> +
> + /* Now the ref count of every page of the packet is 1. */
> + }
> +
> + if (*pages_num - 1 < max_usr_hdrs)
> + usr_hdr_buffer->len = 0;
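^^^ Since '*pages_num' was clamped to 'min(max_usr_hdrs, *pages_num)'
above, this condition looks always true; probably the intent was to write
the terminating zero-length header only when there is room left for one
more entry.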
> +
> + spin_unlock_bh(&vvs->rx_lock);
> +
> + virtio_transport_send_credit_update(vsk);
> +
> + *pages_num = pages_cnt;
> +
> + return 0;
> +}
> +EXPORT_SYMBOL_GPL(virtio_transport_zerocopy_dequeue);
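
To make the returned layout concrete, here is a sketch of how a consumer
could walk one batch. It assumes (this patch alone does not show it) that
the AF_VSOCK layer inserts the returned 'pages' contiguously behind the
polling page and that packet offsets stay page-aligned:

/* Hypothetical userspace sketch: walk the headers written by
 * 'virtio_transport_zerocopy_dequeue()'. Page 0 of the mapping holds
 * 'usr_hdr_pref' plus 'hdr_num' headers; each header describes 'len'
 * bytes of payload occupying the next len-rounded-up-to-a-page data
 * pages.
 */
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>

#define PAGE_SZ 4096u

struct virtio_vsock_usr_hdr {
	uint32_t flags;
	uint32_t len;
} __attribute__((packed));

struct virtio_vsock_usr_hdr_pref {
	uint32_t poll_value;
	uint32_t hdr_num;
} __attribute__((packed));

static void walk_rx_batch(const void *area)
{
	const struct virtio_vsock_usr_hdr_pref *pref = area;
	const struct virtio_vsock_usr_hdr *hdr =
		(const struct virtio_vsock_usr_hdr *)(pref + 1);
	const uint8_t *data = (const uint8_t *)area + PAGE_SZ;
	uint32_t i;

	for (i = 0; i < pref->hdr_num; i++) {
		/* Payload of record 'i' is rounded up to whole pages. */
		uint32_t pages = (hdr[i].len + PAGE_SZ - 1) / PAGE_SZ;

		printf("record %" PRIu32 ": %" PRIu32 " bytes, "
		       "flags 0x%" PRIx32 ", payload at %p\n",
		       i, hdr[i].len, hdr[i].flags, (void *)data);

		data += (size_t)pages * PAGE_SZ;
	}
}
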
> +
> static ssize_t
> virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
> struct msghdr *msg,
> @@ -969,11 +1171,21 @@ void virtio_transport_release(struct vsock_sock *vsk)
> {
> struct sock *sk = &vsk->sk;
> bool remove_sock = true;
> + struct virtio_vsock_sock *vvs = vsk->trans;
>
> if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
> remove_sock = virtio_transport_close(vsk);
>
> if (remove_sock) {
> + spin_lock_bh(&vvs->rx_lock);
> +
> + if (vvs->usr_poll_page) {
> + __free_page(vvs->usr_poll_page);
> + vvs->usr_poll_page = NULL;
> + }
> +
> + spin_unlock_bh(&vvs->rx_lock);
> +
> sock_set_flag(sk, SOCK_DONE);
> virtio_transport_remove_sock(vsk);
> }
> @@ -1077,6 +1289,7 @@ virtio_transport_recv_connected(struct sock *sk,
> struct virtio_vsock_pkt *pkt)
> {
> struct vsock_sock *vsk = vsock_sk(sk);
> + struct virtio_vsock_sock *vvs = vsk->trans;
> int err = 0;
>
> switch (le16_to_cpu(pkt->hdr.op)) {
> @@ -1095,6 +1308,19 @@ virtio_transport_recv_connected(struct sock *sk,
> vsk->peer_shutdown |= RCV_SHUTDOWN;
> if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
> vsk->peer_shutdown |= SEND_SHUTDOWN;
> +
> + spin_lock_bh(&vvs->rx_lock);
> +
> + if (vvs->usr_poll_page) {
> + struct virtio_vsock_usr_hdr_pref *hdr;
> +
> + hdr = (struct virtio_vsock_usr_hdr_pref *)page_to_virt(vvs->usr_poll_page);
> +
> + hdr->poll_value = VIRTIO_VSOCK_USR_POLL_SHUTDOWN;
> + }
> +
> + spin_unlock_bh(&vvs->rx_lock);
> +
> if (vsk->peer_shutdown == SHUTDOWN_MASK &&
> vsock_stream_has_data(vsk) <= 0 &&
> !sock_flag(sk, SOCK_DONE)) {
> @@ -1343,11 +1569,21 @@ EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
> void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
> {
> if (pkt->buf_len) {
> - if (pkt->slab_buf)
> + if (pkt->slab_buf) {
> kvfree(pkt->buf);
> - else
> - free_pages((unsigned long)pkt->buf,
> - get_order(pkt->buf_len));
> + } else {
> + unsigned int order = get_order(pkt->buf_len);
> + unsigned long buf = (unsigned long)pkt->buf;
> +
> + if (pkt->split) {
> + int i;
> +
> + for (i = 0; i < (1 << order); i++)
> + free_page(buf + i * PAGE_SIZE);
> + } else {
> + free_pages(buf, order);
> + }
> + }
> }
>
> kfree(pkt);