Re: [PATCH vhost v2 7/8] vdpa/mlx5: Introduce reference counting to mrs

From: Eugenio Perez Martin
Date: Tue Dec 12 2023 - 13:27:25 EST


On Tue, Dec 5, 2023 at 11:47 AM Dragos Tatulea <dtatulea@xxxxxxxxxx> wrote:
>
> Deleting the old mr during mr update (.set_map) and then modifying the
> vqs with the new mr is not a good flow for firmware. The firmware
> expects that mkeys are deleted after there are no more vqs referencing
> them.
>
> Introduce reference counting for mrs to fix this. It is the only way to
> make sure that mkeys are not in use by vqs.
>
> An mr reference is taken when the mr is associated with the mr asid table
> and when the mr is linked to the vq on create/modify. The reference is
> released when the mkey is unlinked from the vq (through modify/destroy)
> and from the mr asid table.
>
> To make things consistent, get rid of mlx5_vdpa_destroy_mr and use
> get/put semantics everywhere.
>
> Signed-off-by: Dragos Tatulea <dtatulea@xxxxxxxxxx>
> Reviewed-by: Gal Pressman <gal@xxxxxxxxxx>

Acked-by: Eugenio Pérez <eperezma@xxxxxxxxxx>

> ---
> drivers/vdpa/mlx5/core/mlx5_vdpa.h | 8 +++--
> drivers/vdpa/mlx5/core/mr.c | 50 ++++++++++++++++++++----------
> drivers/vdpa/mlx5/net/mlx5_vnet.c | 45 ++++++++++++++++++++++-----
> 3 files changed, 78 insertions(+), 25 deletions(-)
>
> diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> index 84547d998bcf..1a0d27b6e09a 100644
> --- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> +++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
> @@ -35,6 +35,8 @@ struct mlx5_vdpa_mr {
> struct vhost_iotlb *iotlb;
>
> bool user_mr;
> +
> + refcount_t refcount;
> };
>
> struct mlx5_vdpa_resources {
> @@ -118,8 +120,10 @@ int mlx5_vdpa_destroy_mkey(struct mlx5_vdpa_dev *mvdev, u32 mkey);
> struct mlx5_vdpa_mr *mlx5_vdpa_create_mr(struct mlx5_vdpa_dev *mvdev,
> struct vhost_iotlb *iotlb);
> void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev);
> -void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
> - struct mlx5_vdpa_mr *mr);
> +void mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr);
> +void mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr);
> void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
> struct mlx5_vdpa_mr *mr,
> unsigned int asid);
> diff --git a/drivers/vdpa/mlx5/core/mr.c b/drivers/vdpa/mlx5/core/mr.c
> index 2197c46e563a..c7dc8914354a 100644
> --- a/drivers/vdpa/mlx5/core/mr.c
> +++ b/drivers/vdpa/mlx5/core/mr.c
> @@ -498,32 +498,52 @@ static void destroy_user_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr
>
> static void _mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev, struct mlx5_vdpa_mr *mr)
> {
> + if (WARN_ON(!mr))
> + return;
> +
> if (mr->user_mr)
> destroy_user_mr(mvdev, mr);
> else
> destroy_dma_mr(mvdev, mr);
>
> vhost_iotlb_free(mr->iotlb);
> +
> + kfree(mr);
> }
>
> -void mlx5_vdpa_destroy_mr(struct mlx5_vdpa_dev *mvdev,
> - struct mlx5_vdpa_mr *mr)
> +static void _mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr)
> {
> if (!mr)
> return;
>
> + if (refcount_dec_and_test(&mr->refcount))
> + _mlx5_vdpa_destroy_mr(mvdev, mr);
> +}
> +
> +void mlx5_vdpa_put_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr)
> +{
> mutex_lock(&mvdev->mr_mtx);
> + _mlx5_vdpa_put_mr(mvdev, mr);
> + mutex_unlock(&mvdev->mr_mtx);
> +}
>
> - _mlx5_vdpa_destroy_mr(mvdev, mr);
> +static void _mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr)
> +{
> + if (!mr)
> + return;
>
> - for (int i = 0; i < MLX5_VDPA_NUM_AS; i++) {
> - if (mvdev->mr[i] == mr)
> - mvdev->mr[i] = NULL;
> - }
> + refcount_inc(&mr->refcount);
> +}
>
> +void mlx5_vdpa_get_mr(struct mlx5_vdpa_dev *mvdev,
> + struct mlx5_vdpa_mr *mr)
> +{
> + mutex_lock(&mvdev->mr_mtx);
> + _mlx5_vdpa_get_mr(mvdev, mr);
> mutex_unlock(&mvdev->mr_mtx);
> -
> - kfree(mr);
> }
>
> void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
> @@ -534,20 +554,16 @@ void mlx5_vdpa_update_mr(struct mlx5_vdpa_dev *mvdev,
>
> mutex_lock(&mvdev->mr_mtx);
>
> + _mlx5_vdpa_put_mr(mvdev, old_mr);
> mvdev->mr[asid] = new_mr;
> - if (old_mr) {
> - _mlx5_vdpa_destroy_mr(mvdev, old_mr);
> - kfree(old_mr);
> - }
>
> mutex_unlock(&mvdev->mr_mtx);
> -
> }
>
> void mlx5_vdpa_destroy_mr_resources(struct mlx5_vdpa_dev *mvdev)
> {
> for (int i = 0; i < MLX5_VDPA_NUM_AS; i++)
> - mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[i]);
> + mlx5_vdpa_update_mr(mvdev, NULL, i);
>
> prune_iotlb(mvdev->cvq.iotlb);
> }
> @@ -607,6 +623,8 @@ struct mlx5_vdpa_mr *mlx5_vdpa_create_mr(struct mlx5_vdpa_dev *mvdev,
> if (err)
> goto out_err;
>
> + refcount_set(&mr->refcount, 1);
> +
> return mr;
>
> out_err:
> @@ -651,7 +669,7 @@ int mlx5_vdpa_reset_mr(struct mlx5_vdpa_dev *mvdev, unsigned int asid)
> if (asid >= MLX5_VDPA_NUM_AS)
> return -EINVAL;
>
> - mlx5_vdpa_destroy_mr(mvdev, mvdev->mr[asid]);
> + mlx5_vdpa_update_mr(mvdev, NULL, asid);
>
> if (asid == 0 && MLX5_CAP_GEN(mvdev->mdev, umem_uid_0)) {
> if (mlx5_vdpa_create_dma_mr(mvdev))
> diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> index 6a21223d97a8..133cbb66dcfe 100644
> --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
> +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
> @@ -123,6 +123,9 @@ struct mlx5_vdpa_virtqueue {
>
> u64 modified_fields;
>
> + struct mlx5_vdpa_mr *vq_mr;
> + struct mlx5_vdpa_mr *desc_mr;
> +
> struct msi_map map;
>
> /* keep last in the struct */
> @@ -946,6 +949,14 @@ static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtque
> kfree(in);
> mvq->virtq_id = MLX5_GET(general_obj_out_cmd_hdr, out, obj_id);
>
> + mlx5_vdpa_get_mr(mvdev, vq_mr);
> + mvq->vq_mr = vq_mr;
> +
> + if (vq_desc_mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported)) {
> + mlx5_vdpa_get_mr(mvdev, vq_desc_mr);
> + mvq->desc_mr = vq_desc_mr;
> + }
> +
> return 0;
>
> err_cmd:
> @@ -972,6 +983,12 @@ static void destroy_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtq
> }
> mvq->fw_state = MLX5_VIRTIO_NET_Q_OBJECT_NONE;
> umems_destroy(ndev, mvq);
> +
> + mlx5_vdpa_put_mr(&ndev->mvdev, mvq->vq_mr);
> + mvq->vq_mr = NULL;
> +
> + mlx5_vdpa_put_mr(&ndev->mvdev, mvq->desc_mr);
> + mvq->desc_mr = NULL;
> }
>
> static u32 get_rqpn(struct mlx5_vdpa_virtqueue *mvq, bool fw)
> @@ -1207,6 +1224,8 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
> int inlen = MLX5_ST_SZ_BYTES(modify_virtio_net_q_in);
> u32 out[MLX5_ST_SZ_DW(modify_virtio_net_q_out)] = {};
> struct mlx5_vdpa_dev *mvdev = &ndev->mvdev;
> + struct mlx5_vdpa_mr *desc_mr = NULL;
> + struct mlx5_vdpa_mr *vq_mr = NULL;
> bool state_change = false;
> void *obj_context;
> void *cmd_hdr;
> @@ -1257,19 +1276,19 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
> MLX5_SET(virtio_net_q_object, obj_context, hw_used_index, mvq->used_idx);
>
> if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY) {
> - struct mlx5_vdpa_mr *mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
> + vq_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_GROUP]];
>
> - if (mr)
> - MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, mr->mkey);
> + if (vq_mr)
> + MLX5_SET(virtio_q, vq_ctx, virtio_q_mkey, vq_mr->mkey);
> else
> mvq->modified_fields &= ~MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY;
> }
>
> if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY) {
> - struct mlx5_vdpa_mr *mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_DESC_GROUP]];
> + desc_mr = mvdev->mr[mvdev->group2asid[MLX5_VDPA_DATAVQ_DESC_GROUP]];
>
> - if (mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported))
> - MLX5_SET(virtio_q, vq_ctx, desc_group_mkey, mr->mkey);
> + if (desc_mr && MLX5_CAP_DEV_VDPA_EMULATION(mvdev->mdev, desc_group_mkey_supported))
> + MLX5_SET(virtio_q, vq_ctx, desc_group_mkey, desc_mr->mkey);
> else
> mvq->modified_fields &= ~MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY;
> }
> @@ -1282,6 +1301,18 @@ static int modify_virtqueue(struct mlx5_vdpa_net *ndev,
> if (state_change)
> mvq->fw_state = state;
>
> + if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_VIRTIO_Q_MKEY) {
> + mlx5_vdpa_put_mr(mvdev, mvq->vq_mr);
> + mlx5_vdpa_get_mr(mvdev, vq_mr);
> + mvq->vq_mr = vq_mr;
> + }
> +
> + if (mvq->modified_fields & MLX5_VIRTQ_MODIFY_MASK_DESC_GROUP_MKEY) {
> + mlx5_vdpa_put_mr(mvdev, mvq->desc_mr);
> + mlx5_vdpa_get_mr(mvdev, desc_mr);
> + mvq->desc_mr = desc_mr;
> + }
> +
> mvq->modified_fields = 0;
>
> done:
> @@ -3095,7 +3126,7 @@ static int set_map_data(struct mlx5_vdpa_dev *mvdev, struct vhost_iotlb *iotlb,
> return mlx5_vdpa_update_cvq_iotlb(mvdev, iotlb, asid);
>
> out_err:
> - mlx5_vdpa_destroy_mr(mvdev, new_mr);
> + mlx5_vdpa_put_mr(mvdev, new_mr);
> return err;
> }
>
> --
> 2.42.0
>