Re: [PATCH] tools/virtio: Test virtual address range detection

From: Jason Wang
Date: Mon Feb 21 2022 - 22:25:37 EST


On Tue, Feb 22, 2022 at 12:17 AM David Woodhouse <dwmw2@xxxxxxxxxxxxx> wrote:
>
> As things stand, an application which wants to use vhost with a trivial
> 1:1 mapping of its virtual address space is forced to jump through hoops
> to detect what the address range might be. The VHOST_SET_MEM_TABLE ioctl
> helpfully doesn't fail immediately; you only get a failure *later* when
> you attempt to set the backend, if the table *could* map to an address
> which is out of range, even if no out-of-range address is actually
> being referenced.
>
> Since userspace is growing workarounds for this lovely kernel API, let's
> ensure that we have a regression test that does things basically the same
> way as https://gitlab.com/openconnect/openconnect/-/commit/443edd9d8826
> does.

I wonder if it would be useful to have a small library somewhere that
wraps the vhost kernel uAPI.
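
Roughly what I have in mind, as a sketch only. The helper names
vhost_identity_table() and vhost_set_identity_mem_table() are made up.
The only real uAPI used is struct vhost_memory and VHOST_SET_MEM_TABLE
from <linux/vhost.h>. As the patch points out, the ioctl accepting the
table says nothing about whether setting the backend will succeed later.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

/* Hypothetical helper: build a one-region table that maps guest
 * "physical" addresses 1:1 onto our own virtual addresses over
 * [start, start + size). The caller frees the result. */
static struct vhost_memory *vhost_identity_table(unsigned long start,
						 unsigned long size)
{
	size_t len = offsetof(struct vhost_memory, regions) +
		     sizeof(struct vhost_memory_region);
	struct vhost_memory *mem = calloc(1, len);

	if (!mem)
		return NULL;
	mem->nregions = 1;
	mem->regions[0].guest_phys_addr = start;
	mem->regions[0].userspace_addr = start;	/* identity mapping */
	mem->regions[0].memory_size = size;
	return mem;
}

/* Hypothetical wrapper: install the 1:1 table on an open vhost fd.
 * Returns the ioctl() result, or -1 if allocation failed. */
static int vhost_set_identity_mem_table(int vhost_fd, unsigned long start,
					unsigned long size)
{
	struct vhost_memory *mem = vhost_identity_table(start, size);
	int r;

	if (!mem)
		return -1;
	r = ioctl(vhost_fd, VHOST_SET_MEM_TABLE, mem);
	free(mem);
	return r;
}
```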

(In the future, we may want to let the kernel accept a 1:1 mapping by
working out the illegal range itself?)

Thanks

>
> This is untested as I can't actually get virtio_test to work at all; it
> just seems to deadlock on a spinlock. But it's getting the right answer
> for the virtio range on x86_64 at least.
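
For reference, here is the answer I'd expect on x86_64. This assumes
4-level paging and an LP64 build. If I'm reading arch/x86 correctly,
TASK_SIZE_MAX there is (1UL << 47) - PAGE_SIZE; 5-level paging raises
it. testaddr() probes a single page, so the highest address it accepts
is TASK_SIZE minus one page. Because the region starts at page_size,
using that address as memory_size makes the region end exactly at
TASK_SIZE. The macro and function names below are illustrative only.

```c
#include <assert.h>

#define PAGE_SZ      0x1000UL
/* Assumed x86_64 4-level-paging TASK_SIZE: the last page below
 * 1 << 47 is left unusable. Needs a 64-bit unsigned long. */
#define TASK_SIZE_47 ((1UL << 47) - PAGE_SZ)

/* testaddr() maps one page at addr, and the kernel rejects a MAP_FIXED
 * request only when addr > TASK_SIZE - len, so the highest accepted
 * page starts one page below TASK_SIZE. */
static unsigned long highest_mappable_page(unsigned long task_size,
					   unsigned long page_size)
{
	return task_size - page_size;
}
```

This would print "Detected virtual address range 0x1000-0x7ffffffff000".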
>
> Signed-off-by: David Woodhouse <dwmw2@xxxxxxxxxxxxx>
> ---
>
> Please, tell me I don't need to do this. But if I *do*, it needs a
> regression test in-kernel.
>
> tools/virtio/virtio_test.c | 109 ++++++++++++++++++++++++++++++++++++-
> 1 file changed, 106 insertions(+), 3 deletions(-)
>
> diff --git a/tools/virtio/virtio_test.c b/tools/virtio/virtio_test.c
> index cb3f29c09aff..e40eeeb05b71 100644
> --- a/tools/virtio/virtio_test.c
> +++ b/tools/virtio/virtio_test.c
> @@ -11,6 +11,7 @@
> #include <sys/ioctl.h>
> #include <sys/stat.h>
> #include <sys/types.h>
> +#include <sys/mman.h>
> #include <fcntl.h>
> #include <stdbool.h>
> #include <linux/virtio_types.h>
> @@ -124,6 +125,109 @@ static void vq_info_add(struct vdev_info *dev, int num)
> dev->nvqs++;
> }
>
> +/*
> + * This is awful. The kernel doesn't let us just ask for a 1:1 mapping of
> + * our virtual address space; we have to *know* the minimum and maximum
> + * addresses. We can't test it directly with VHOST_SET_MEM_TABLE because
> + * that actually succeeds, and the failure only occurs later when we try
> + * to use a buffer at an address that *is* valid, but our memory table
> + * *could* point to addresses that aren't. Ewww.
> + *
> + * So... attempt to work out what TASK_SIZE is for the kernel we happen
> + * to be running on right now...
> + */
> +
> +static int testaddr(unsigned long addr)
> +{
> + void *res = mmap((void *)addr, getpagesize(), PROT_NONE,
> + MAP_FIXED|MAP_ANONYMOUS, -1, 0);
> + if (res == MAP_FAILED) {
> + if (errno == EEXIST || errno == EINVAL)
> + return 1;
> +
> + /* We get ENOMEM for a bad virtual address */
> + return 0;
> + }
> + /* It shouldn't actually succeed without either MAP_SHARED or
> + * MAP_PRIVATE in the flags, but just in case... */
> + munmap((void *)addr, getpagesize());
> + return 1;
> +}
> +
> +static int find_vmem_range(struct vhost_memory *vmem)
> +{
> + const unsigned long page_size = getpagesize();
> + unsigned long top;
> + unsigned long bottom;
> +
> + top = -page_size;
> +
> + if (testaddr(top)) {
> + vmem->regions[0].memory_size = top;
> + goto out;
> + }
> +
> + /* 'top' is the lowest address known *not* to work */
> + bottom = top;
> + while (1) {
> + bottom >>= 1;
> + bottom &= ~(page_size - 1);
> + assert(bottom);
> +
> + if (testaddr(bottom))
> + break;
> + top = bottom;
> + }
> +
> + /* It's often a page or two below the boundary */
> + top -= page_size;
> + if (testaddr(top)) {
> + vmem->regions[0].memory_size = top;
> + goto out;
> + }
> + top -= page_size;
> + if (testaddr(top)) {
> + vmem->regions[0].memory_size = top;
> + goto out;
> + }
> +
> + /* Now, bottom is the highest address known to work,
> + and we must search between it and 'top' which is
> + the lowest address known not to. */
> + while (bottom + page_size != top) {
> + unsigned long test = bottom + (top - bottom) / 2;
> + test &= ~(page_size - 1);
> +
> + if (testaddr(test)) {
> + bottom = test;
> + continue;
> + }
> + test -= page_size;
> + if (testaddr(test)) {
> + vmem->regions[0].memory_size = test;
> + goto out;
> + }
> +
> + test -= page_size;
> + if (testaddr(test)) {
> + vmem->regions[0].memory_size = test;
> + goto out;
> + }
> + top = test;
> + }
> + vmem->regions[0].memory_size = bottom;
> +
> + out:
> + vmem->regions[0].guest_phys_addr = page_size;
> + vmem->regions[0].userspace_addr = page_size;
> + printf("Detected virtual address range 0x%lx-0x%lx\n",
> + page_size,
> + (unsigned long)(page_size + vmem->regions[0].memory_size));
> +
> + return 0;
> +}
> +
> +
> static void vdev_info_init(struct vdev_info* dev, unsigned long long features)
> {
> int r;
> @@ -143,9 +247,8 @@ static void vdev_info_init(struct vdev_info* dev, unsigned long long features)
> memset(dev->mem, 0, offsetof(struct vhost_memory, regions) +
> sizeof dev->mem->regions[0]);
> dev->mem->nregions = 1;
> - dev->mem->regions[0].guest_phys_addr = (long)dev->buf;
> - dev->mem->regions[0].userspace_addr = (long)dev->buf;
> - dev->mem->regions[0].memory_size = dev->buf_size;
> + r = find_vmem_range(dev->mem);
> + assert(r >= 0);
> r = ioctl(dev->control, VHOST_SET_MEM_TABLE, dev->mem);
> assert(r >= 0);
> }
>
>
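
Without the "a page or two below the boundary" special cases,
find_vmem_range() reduces to a bracket-then-bisect search over
page-aligned addresses. Below is a sketch that takes the probe as a
function pointer, so the search logic can be tested without touching the
real address space. page_probe_fn, find_limit() and fake_probe() are
hypothetical names. testaddr() from the patch would be the real probe.
The sketch assumes the probe is monotone over the range it searches
(succeeds up to TASK_SIZE minus one page, fails above) and that
page_size is a power of two.

```c
#include <assert.h>
#include <stdbool.h>

typedef bool (*page_probe_fn)(unsigned long addr, void *ctx);

/* Bisect between a page known to work (lo) and one known not to (hi). */
static unsigned long find_highest_page(page_probe_fn probe, void *ctx,
				       unsigned long lo, unsigned long hi,
				       unsigned long page_size)
{
	/* Invariant: probe(lo) succeeds, probe(hi) fails. */
	while (hi - lo > page_size) {
		unsigned long mid = (lo + (hi - lo) / 2) & ~(page_size - 1);

		if (probe(mid, ctx))
			lo = mid;
		else
			hi = mid;
	}
	return lo;
}

/* Try the top page first, then halve (like the patch) to find a working
 * lower bound, then bisect. Returns 0 if nothing usable was found. */
static unsigned long find_limit(page_probe_fn probe, void *ctx,
				unsigned long page_size)
{
	unsigned long hi = -page_size;	/* highest page-aligned address */
	unsigned long lo = hi;

	if (probe(hi, ctx))
		return hi;
	do {
		hi = lo;
		lo = (lo >> 1) & ~(page_size - 1);
		if (!lo)
			return 0;
	} while (!probe(lo, ctx));
	return find_highest_page(probe, ctx, lo, hi, page_size);
}

/* Hypothetical stand-in for testaddr(): a one-page (4 KiB) mapping at
 * addr succeeds iff it fits below the TASK_SIZE pointed to by ctx. */
static bool fake_probe(unsigned long addr, void *ctx)
{
	unsigned long task_size = *(unsigned long *)ctx;

	return addr < task_size && task_size - addr >= 4096;
}
```

The result plays the same role as vmem->regions[0].memory_size in the
patch: the start of the highest page that can be mapped.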