From 49559efc5cb26aadbcf580de03afd6e4ff67cedc Mon Sep 17 00:00:00 2001 From: Jean-Philippe Brucker Date: Wed, 6 Nov 2019 10:10:07 +0000 Subject: [PATCH] uacce: Track mm<->queue bonds The IOMMU core only tracks mm<->device bonds at the moment, because it only needs to handle IOTLB invalidation and PASID table entries. However uacce needs a finer granularity since multiple queues from the same device can be bound to an mm. When the mm exits, all bound queues must be stopped so that the IOMMU can safely clear the PASID table entry and reallocate the PASID. Introduce an intermediate struct uacce_mm that links uacce devices and queues. Note that an mm may be bound to multiple devices but an uacce_mm structure only ever belongs to a single device, because we don't need anything more complex (if multiple devices are bound to one mm, then we'll create one uacce_mm for each bond). uacce_device --+-- uacce_mm --+-- uacce_queue | '-- uacce_queue | '-- uacce_mm --+-- uacce_queue +-- uacce_queue '-- uacce_queue If multiple device drivers need this model, it should be possible to move it to iommu-sva in the future, with some changes to the API, and have mm_exit() be called for multiple contexts per iommu_bond. Signed-off-by: Jean-Philippe Brucker --- drivers/misc/uacce/uacce.c | 174 +++++++++++++++++++++++++++---------- include/linux/uacce.h | 34 ++++++-- 2 files changed, 152 insertions(+), 56 deletions(-) diff --git a/drivers/misc/uacce/uacce.c b/drivers/misc/uacce/uacce.c index 2b6b03855ac6..d8a7fbfe7399 100644 --- a/drivers/misc/uacce/uacce.c +++ b/drivers/misc/uacce/uacce.c @@ -92,15 +92,19 @@ static long uacce_fops_compat_ioctl(struct file *filep, static int uacce_sva_exit(struct device *dev, struct iommu_sva *handle, void *data) { - struct uacce_device *uacce = data; + struct uacce_mm *uacce_mm = data; struct uacce_queue *q; - mutex_lock(&uacce->q_lock); - list_for_each_entry(q, &uacce->qs, list) { - if (q->pid == task_pid_nr(current)) - uacce_put_queue(q); - } - mutex_unlock(&uacce->q_lock); + /* + * No new queue can be added concurrently because no caller can have a + * reference to this mm. But there may be concurrent calls to + * uacce_mm_put(), so we need the lock. + */ + mutex_lock(&uacce_mm->lock); + list_for_each_entry(q, &uacce_mm->queues, list) + uacce_put_queue(q); + uacce_mm->mm = NULL; + mutex_unlock(&uacce_mm->lock); return 0; } @@ -109,13 +113,88 @@ static struct iommu_sva_ops uacce_sva_ops = { .mm_exit = uacce_sva_exit, }; -static int uacce_fops_open(struct inode *inode, struct file *filep) +static struct uacce_mm *uacce_mm_get(struct uacce_device *uacce, + struct uacce_queue *q, + struct mm_struct *mm) { + struct uacce_mm *uacce_mm = NULL; struct iommu_sva *handle = NULL; + int ret; + + lockdep_assert_held(&uacce->mm_lock); + + list_for_each_entry(uacce_mm, &uacce->mm_list, list) { + if (uacce_mm->mm == mm) { + mutex_lock(&uacce_mm->lock); + list_add(&q->list, &uacce_mm->queues); + mutex_unlock(&uacce_mm->lock); + return uacce_mm; + } + } + + uacce_mm = kzalloc(sizeof(*uacce_mm), GFP_KERNEL); + if (!uacce_mm) + return NULL; + + if (uacce->flags & UACCE_DEV_SVA) { + /* + * Safe to pass an incomplete uacce_mm, since mm_exit cannot + * fire while we hold a reference to the mm. + */ + handle = iommu_sva_bind_device(uacce->pdev, mm, uacce_mm); + if (IS_ERR(handle)) + goto err_free; + + ret = iommu_sva_set_ops(handle, &uacce_sva_ops); + if (ret) + goto err_unbind; + + uacce_mm->pasid = iommu_sva_get_pasid(handle); + if (uacce_mm->pasid == IOMMU_PASID_INVALID) + goto err_unbind; + } + + uacce_mm->mm = mm; + uacce_mm->handle = handle; + INIT_LIST_HEAD(&uacce_mm->queues); + mutex_init(&uacce_mm->lock); + list_add(&q->list, &uacce_mm->queues); + list_add(&uacce_mm->list, &uacce->mm_list); + + return uacce_mm; + +err_unbind: + if (handle) + iommu_sva_unbind_device(handle); +err_free: + kfree(uacce_mm); + return NULL; +} + +static void uacce_mm_put(struct uacce_queue *q) +{ + struct uacce_mm *uacce_mm = q->uacce_mm; + + lockdep_assert_held(&q->uacce->mm_lock); + + mutex_lock(&uacce_mm->lock); + list_del(&q->list); + mutex_unlock(&uacce_mm->lock); + + if (list_empty(&uacce_mm->queues)) { + if (uacce_mm->handle) + iommu_sva_unbind_device(uacce_mm->handle); + list_del(&uacce_mm->list); + kfree(uacce_mm); + } +} + +static int uacce_fops_open(struct inode *inode, struct file *filep) +{ + struct uacce_mm *uacce_mm = NULL; struct uacce_device *uacce; struct uacce_queue *q; int ret = 0; - int pasid = 0; uacce = xa_load(&uacce_xa, iminor(inode)); if (!uacce) @@ -130,44 +209,37 @@ static int uacce_fops_open(struct inode *inode, struct file *filep) goto out_with_module; } - if (uacce->flags & UACCE_DEV_SVA) { - handle = iommu_sva_bind_device(uacce->pdev, current->mm, uacce); - if (IS_ERR(handle)) - goto out_with_mem; - - ret = iommu_sva_set_ops(handle, &uacce_sva_ops); - if (ret) - goto out_unbind; + q->state = UACCE_Q_ZOMBIE; - pasid = iommu_sva_get_pasid(handle); - if (pasid == IOMMU_PASID_INVALID) - goto out_unbind; + mutex_lock(&uacce->mm_lock); + uacce_mm = uacce_mm_get(uacce, q, current->mm); + mutex_unlock(&uacce->mm_lock); + if (!uacce_mm) { + ret = -ENOMEM; + goto out_with_mem; } + q->uacce = uacce; + q->uacce_mm = uacce_mm; + if (uacce->ops->get_queue) { - ret = uacce->ops->get_queue(uacce, pasid, q); + ret = uacce->ops->get_queue(uacce, uacce_mm->pasid, q); if (ret < 0) - goto out_unbind; + goto out_with_mm; } q->pid = task_pid_nr(current); - q->pasid = pasid; - q->handle = handle; - q->uacce = uacce; memset(q->qfrs, 0, sizeof(q->qfrs)); init_waitqueue_head(&q->wait); filep->private_data = q; q->state = UACCE_Q_INIT; - mutex_lock(&uacce->q_lock); - list_add(&q->list, &uacce->qs); - mutex_unlock(&uacce->q_lock); - return 0; -out_unbind: - if (uacce->flags & UACCE_DEV_SVA) - iommu_sva_unbind_device(handle); +out_with_mm: + mutex_lock(&uacce->mm_lock); + uacce_mm_put(q); + mutex_unlock(&uacce->mm_lock); out_with_mem: kfree(q); out_with_module: @@ -182,12 +254,10 @@ static int uacce_fops_release(struct inode *inode, struct file *filep) uacce_put_queue(q); - if (uacce->flags & UACCE_DEV_SVA) - iommu_sva_unbind_device(q->handle); + mutex_lock(&uacce->mm_lock); + uacce_mm_put(q); + mutex_unlock(&uacce->mm_lock); - mutex_lock(&uacce->q_lock); - list_del(&q->list); - mutex_unlock(&uacce->q_lock); kfree(q); module_put(uacce->pdev->driver->owner); @@ -484,8 +554,8 @@ struct uacce_device *uacce_register(struct device *parent, goto err_with_xa; } - INIT_LIST_HEAD(&uacce->qs); - mutex_init(&uacce->q_lock); + INIT_LIST_HEAD(&uacce->mm_list); + mutex_init(&uacce->mm_lock); uacce->cdev->ops = &uacce_fops; uacce->cdev->owner = THIS_MODULE; device_initialize(&uacce->dev); @@ -519,20 +589,30 @@ EXPORT_SYMBOL_GPL(uacce_register); */ void uacce_unregister(struct uacce_device *uacce) { + struct uacce_mm *uacce_mm; + struct uacce_queue *q; + if (!uacce) return; - mutex_lock(&uacce->q_lock); - if (!list_empty(&uacce->qs)) { - struct uacce_queue *q; - - list_for_each_entry(q, &uacce->qs, list) { + mutex_lock(&uacce->mm_lock); + list_for_each_entry(uacce_mm, &uacce->mm_list, list) { + /* + * We don't take the uacce_mm->lock here. Since we hold the + * device's mm_lock, no queue can be added to or removed from + * this uacce_mm. We may run concurrently with mm_exit, but + * uacce_put_queue() is serialized and iommu_sva_unbind_device() + * waits for the lock that mm_exit is holding. + */ + list_for_each_entry(q, &uacce_mm->queues, list) uacce_put_queue(q); - if (uacce->flags & UACCE_DEV_SVA) - iommu_sva_unbind_device(q->handle); + + if (uacce->flags & UACCE_DEV_SVA) { + iommu_sva_unbind_device(uacce_mm->handle); + uacce_mm->handle = NULL; } } - mutex_unlock(&uacce->q_lock); + mutex_unlock(&uacce->mm_lock); if (uacce->flags & UACCE_DEV_SVA) iommu_dev_disable_feature(uacce->pdev, IOMMU_DEV_FEAT_SVA); diff --git a/include/linux/uacce.h b/include/linux/uacce.h index 04c8643c130b..8564e078287a 100644 --- a/include/linux/uacce.h +++ b/include/linux/uacce.h @@ -88,10 +88,9 @@ enum uacce_q_state { * @uacce: pointer to uacce * @priv: private pointer * @wait: wait queue head - * @pasid: pasid of the queue * @pid: pid of the process using the queue - * @handle: iommu_sva handle return from iommu_sva_bind_device - * @list: queue list + * @list: index into uacce_mm + * @uacce_mm: the corresponding mm * @qfrs: pointer of qfr regions * @state: queue state machine */ @@ -99,10 +98,9 @@ struct uacce_queue { struct uacce_device *uacce; void *priv; wait_queue_head_t wait; - int pasid; pid_t pid; - struct iommu_sva *handle; struct list_head list; + struct uacce_mm *uacce_mm; struct uacce_qfile_region *qfrs[UACCE_QFRT_MAX]; enum uacce_q_state state; }; @@ -121,8 +119,8 @@ struct uacce_queue { * @cdev: cdev of the uacce * @dev: dev of the uacce * @priv: private pointer of the uacce - * @qs: list head of queue->list - * @q_lock: lock for qs + * @mm_list: list head of uacce_mm->list + * @mm_lock: lock for mm_list */ struct uacce_device { const char *algs; @@ -137,8 +135,26 @@ struct uacce_device { struct cdev *cdev; struct device dev; void *priv; - struct list_head qs; - struct mutex q_lock; + struct list_head mm_list; + struct mutex mm_lock; +}; + +/* + * struct uacce_mm - keep track of queues bound to a process + * @list: index into uacce_device + * @queues: list of queues + * @mm: the mm struct + * @lock: protects the list of queues + * @pasid: pasid of the queue + * @handle: iommu_sva handle return from iommu_sva_bind_device + */ +struct uacce_mm { + struct list_head list; + struct list_head queues; + struct mm_struct *mm; + struct mutex lock; + int pasid; + struct iommu_sva *handle; }; #if IS_ENABLED(CONFIG_UACCE) -- 2.20.1