From: Alistair Popple <apopple@nvidia.com>
To: Jason Gunthorpe <jgg@nvidia.com>
Cc: Sean Christopherson <seanjc@google.com>,
	Robin Murphy <robin.murphy@arm.com>,
	Andrew Morton <akpm@linux-foundation.org>,
	will@kernel.org, catalin.marinas@arm.com, linux-mm@kvack.org,
	linux-kernel@vger.kernel.org, nicolinc@nvidia.com,
	linux-arm-kernel@lists.infradead.org, kvm@vger.kernel.org,
	John Hubbard <jhubbard@nvidia.com>,
	zhi.wang.linux@gmail.com
Subject: Re: [PATCH 2/2] arm64: Notify on pte permission upgrades
Date: Fri, 09 Jun 2023 16:05:05 +1000	[thread overview]
Message-ID: <87h6rhw4i0.fsf@nvidia.com> (raw)
In-Reply-To: <87pm65wfmk.fsf@nvidia.com>


Alistair Popple <apopple@nvidia.com> writes:

> Alistair Popple <apopple@nvidia.com> writes:
>
>> Alistair Popple <apopple@nvidia.com> writes:
>>
>>>> On Tue, May 30, 2023 at 02:44:40PM -0700, Sean Christopherson wrote:
>>>>> > KVM already has locking for invalidate_start/end - it has to check
>>>>> > mmu_notifier_retry_cache() with the sequence numbers/etc around when
>>>>> > it does hva_to_pfn()
>>>>> > 
>>>>> > The bug is that the kvm_vcpu_reload_apic_access_page() path is
>>>>> > ignoring this locking so it ignores in-progress range
>>>>> > invalidations. It should spin until the invalidation clears like other
>>>>> > places in KVM.
>>>>> > 
>>>>> > The comment is kind of misleading because drivers shouldn't be abusing
>>>>> > the iommu centric invalidate_range() thing to fix missing locking in
>>>>> > start/end users. :\
>>>>> > 
>>>>> > So if KVM could be fixed up we could make invalidate_range defined to
>>>>> > be an arch specific callback to synchronize the iommu TLB.
>>>>> 
>>>>> And maybe rename invalidate_range() and/or invalidate_range_{start,end}() to make
>>>>> it super obvious that they are intended for two different purposes?  E.g. instead
>>>>> of invalidate_range(), something like invalidate_secondary_tlbs().
>>>>
>>>> Yeah, I think I would call it invalidate_arch_secondary_tlb() and
>>>> document it as being an arch specific set of invalidations that match
>>>> the architected TLB maintenance requirements. And maybe we can check it
>>>> more carefully to make it be called in fewer places. Like I'm not sure
>>>> it is right to call it from invalidate_range_end under this new
>>>> definition..
>>>
>>> I'd be happy to look at that, although it sounds like Sean already is.
>
> Thanks Sean for getting the KVM fix posted so quickly. I'm looking into
> doing the rename now.
>
> Do we want to do more than a simple rename and tidy-up of callers
> though? What I'm thinking is introducing something like an
> IOMMU/TLB-specific variant of notifiers (e.g. struct tlb_notifier)
> which has the invalidate_secondary_tlbs() callback in, say,
> struct tlb_notifier_ops rather than leaving it in mmu_notifier_ops.
>
> Implementation-wise we'd reuse most of the mmu_notifier code, but it
> would help make the two different uses of notifiers clearer. Thoughts?

So, something like the incomplete patch below (against v6.2), which
introduces a new struct tlb_notifier and associated ops. The change isn't
huge, but it does result in some churn and another layer of indirection in
mmu_notifier.c. Otherwise we can just rename the callback.
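
For illustration only (not part of the patch below), a hypothetical driver
using the proposed interface would embed a struct tlb_notifier, supply a
struct tlb_notifier_ops and register it with tlb_notifier_register(). A
rough sketch, assuming only the tlb_notifier names introduced in the patch
(everything prefixed my_dev_ is made up):

struct my_dev_bond {
	struct tlb_notifier	tn;
	/* plus whatever driver state is needed to invalidate the device TLB */
};

static void my_dev_invalidate_secondary_tlbs(struct tlb_notifier *tn,
					     struct mm_struct *mm,
					     unsigned long start,
					     unsigned long end)
{
	struct my_dev_bond *bond = container_of(tn, struct my_dev_bond, tn);

	/* issue the device TLB invalidation for [start, end) using bond */
}

static void my_dev_free_notifier(struct tlb_notifier *tn)
{
	kfree(container_of(tn, struct my_dev_bond, tn));
}

static const struct tlb_notifier_ops my_dev_tlb_notifier_ops = {
	.invalidate_secondary_tlbs	= my_dev_invalidate_secondary_tlbs,
	.free_notifier			= my_dev_free_notifier,
};

static int my_dev_bind_mm(struct mm_struct *mm)
{
	struct my_dev_bond *bond = kzalloc(sizeof(*bond), GFP_KERNEL);
	int ret;

	if (!bond)
		return -ENOMEM;

	bond->tn.ops = &my_dev_tlb_notifier_ops;
	ret = tlb_notifier_register(&bond->tn, mm);
	if (ret)
		kfree(bond);
	return ret;
}

Teardown would go through tlb_notifier_put(&bond->tn), which drops the
reference and eventually calls ->free_notifier() via SRCU, the same way
mmu_notifier_put() works today. Both notifier flavours share the embedded
struct mm_notifier_chain, so the core still walks a single hlist and uses
container_of() to get back to whichever outer structure was registered.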

--- 

diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index a5a63b1c947e..c300cd435609 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -14,7 +14,7 @@
 #include "../../io-pgtable-arm.h"
 
 struct arm_smmu_mmu_notifier {
-	struct mmu_notifier		mn;
+	struct tlb_notifier		mn;
 	struct arm_smmu_ctx_desc	*cd;
 	bool				cleared;
 	refcount_t			refs;
@@ -186,7 +186,7 @@ static void arm_smmu_free_shared_cd(struct arm_smmu_ctx_desc *cd)
 	}
 }
 
-static void arm_smmu_mm_invalidate_range(struct mmu_notifier *mn,
+static void arm_smmu_mm_invalidate_range(struct tlb_notifier *mn,
 					 struct mm_struct *mm,
 					 unsigned long start, unsigned long end)
 {
@@ -207,7 +207,7 @@ static void arm_smmu_mm_invalidate_range(struct mmu_notifier *mn,
 	arm_smmu_atc_inv_domain(smmu_domain, mm->pasid, start, size);
 }
 
-static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
+static void arm_smmu_mm_release(struct tlb_notifier *mn, struct mm_struct *mm)
 {
 	struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
 	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
@@ -231,15 +231,15 @@ static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 	mutex_unlock(&sva_lock);
 }
 
-static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
+static void arm_smmu_mmu_notifier_free(struct tlb_notifier *mn)
 {
 	kfree(mn_to_smmu(mn));
 }
 
-static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
-	.invalidate_range	= arm_smmu_mm_invalidate_range,
-	.release		= arm_smmu_mm_release,
-	.free_notifier		= arm_smmu_mmu_notifier_free,
+static const struct tlb_notifier_ops arm_smmu_tlb_notifier_ops = {
+	.invalidate_secondary_tlbs = arm_smmu_mm_invalidate_range,
+	.release		   = arm_smmu_mm_release,
+	.free_notifier		   = arm_smmu_mmu_notifier_free,
 };
 
 /* Allocate or get existing MMU notifier for this {domain, mm} pair */
@@ -252,7 +252,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 	struct arm_smmu_mmu_notifier *smmu_mn;
 
 	list_for_each_entry(smmu_mn, &smmu_domain->mmu_notifiers, list) {
-		if (smmu_mn->mn.mm == mm) {
+		if (smmu_mn->mn.mm_notifier_chain.mm == mm) {
 			refcount_inc(&smmu_mn->refs);
 			return smmu_mn;
 		}
@@ -271,9 +271,9 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 	refcount_set(&smmu_mn->refs, 1);
 	smmu_mn->cd = cd;
 	smmu_mn->domain = smmu_domain;
-	smmu_mn->mn.ops = &arm_smmu_mmu_notifier_ops;
+	smmu_mn->mn.ops = &arm_smmu_tlb_notifier_ops;
 
-	ret = mmu_notifier_register(&smmu_mn->mn, mm);
+	ret = tlb_notifier_register(&smmu_mn->mn, mm);
 	if (ret) {
 		kfree(smmu_mn);
 		goto err_free_cd;
@@ -288,7 +288,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 
 err_put_notifier:
 	/* Frees smmu_mn */
-	mmu_notifier_put(&smmu_mn->mn);
+	tlb_notifier_put(&smmu_mn->mn);
 err_free_cd:
 	arm_smmu_free_shared_cd(cd);
 	return ERR_PTR(ret);
@@ -296,7 +296,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 
 static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
 {
-	struct mm_struct *mm = smmu_mn->mn.mm;
+	struct mm_struct *mm = smmu_mn->mn.mm_notifier_chain.mm;
 	struct arm_smmu_ctx_desc *cd = smmu_mn->cd;
 	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
 
@@ -316,7 +316,7 @@ static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
 	}
 
 	/* Frees smmu_mn */
-	mmu_notifier_put(&smmu_mn->mn);
+	tlb_notifier_put(&smmu_mn->mn);
 	arm_smmu_free_shared_cd(cd);
 }
 
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index d6c06e140277..157571497b28 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -13,6 +13,7 @@ struct mmu_notifier_subscriptions;
 struct mmu_notifier;
 struct mmu_notifier_range;
 struct mmu_interval_notifier;
+struct mm_notifier_chain;
 
 /**
  * enum mmu_notifier_event - reason for the mmu notifier callback
@@ -61,6 +62,18 @@ enum mmu_notifier_event {
 
 #define MMU_NOTIFIER_RANGE_BLOCKABLE (1 << 0)
 
+struct tlb_notifier;
+struct tlb_notifier_ops {
+	void (*invalidate_secondary_tlbs)(struct tlb_notifier *subscription,
+					struct mm_struct *mm,
+					unsigned long start,
+					unsigned long end);
+
+	void (*free_notifier)(struct tlb_notifier *subscription);
+	void (*release)(struct tlb_notifier *subscription,
+			struct mm_struct *mm);
+};
+
 struct mmu_notifier_ops {
 	/*
 	 * Called either by mmu_notifier_unregister or when the mm is
@@ -186,29 +199,6 @@ struct mmu_notifier_ops {
 	void (*invalidate_range_end)(struct mmu_notifier *subscription,
 				     const struct mmu_notifier_range *range);
 
-	/*
-	 * invalidate_range() is either called between
-	 * invalidate_range_start() and invalidate_range_end() when the
-	 * VM has to free pages that where unmapped, but before the
-	 * pages are actually freed, or outside of _start()/_end() when
-	 * a (remote) TLB is necessary.
-	 *
-	 * If invalidate_range() is used to manage a non-CPU TLB with
-	 * shared page-tables, it not necessary to implement the
-	 * invalidate_range_start()/end() notifiers, as
-	 * invalidate_range() already catches the points in time when an
-	 * external TLB range needs to be flushed. For more in depth
-	 * discussion on this see Documentation/mm/mmu_notifier.rst
-	 *
-	 * Note that this function might be called with just a sub-range
-	 * of what was passed to invalidate_range_start()/end(), if
-	 * called between those functions.
-	 */
-	void (*invalidate_range)(struct mmu_notifier *subscription,
-				 struct mm_struct *mm,
-				 unsigned long start,
-				 unsigned long end);
-
 	/*
 	 * These callbacks are used with the get/put interface to manage the
 	 * lifetime of the mmu_notifier memory. alloc_notifier() returns a new
@@ -234,14 +224,23 @@ struct mmu_notifier_ops {
  * 2. One of the reverse map locks is held (i_mmap_rwsem or anon_vma->rwsem).
  * 3. No other concurrent thread can access the list (release)
  */
-struct mmu_notifier {
+struct mm_notifier_chain {
 	struct hlist_node hlist;
-	const struct mmu_notifier_ops *ops;
 	struct mm_struct *mm;
 	struct rcu_head rcu;
 	unsigned int users;
 };
 
+struct mmu_notifier {
+	const struct mmu_notifier_ops *ops;
+	struct mm_notifier_chain mm_notifier_chain;
+};
+
+struct tlb_notifier {
+	const struct tlb_notifier_ops *ops;
+	struct mm_notifier_chain mm_notifier_chain;
+};
+
 /**
  * struct mmu_interval_notifier_ops
  * @invalidate: Upon return the caller must stop using any SPTEs within this
@@ -283,6 +282,10 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
 	return unlikely(mm->notifier_subscriptions);
 }
 
+int tlb_notifier_register(struct tlb_notifier *subscription,
+			struct mm_struct *mm);
+void tlb_notifier_put(struct tlb_notifier *subscription);
+
 struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
 					     struct mm_struct *mm);
 static inline struct mmu_notifier *
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index f45ff1b7626a..cdc3a373a225 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -47,6 +47,16 @@ struct mmu_notifier_subscriptions {
 	struct hlist_head deferred_list;
 };
 
+struct mmu_notifier *mmu_notifier_from_chain(struct mm_notifier_chain *chain)
+{
+	return container_of(chain, struct mmu_notifier, mm_notifier_chain);
+}
+
+struct tlb_notifier *tlb_notifier_from_chain(struct mm_notifier_chain *chain)
+{
+	return container_of(chain, struct tlb_notifier, mm_notifier_chain);
+}
+
 /*
  * This is a collision-retry read-side/write-side 'lock', a lot like a
  * seqcount, however this allows multiple write-sides to hold it at
@@ -299,7 +309,7 @@ static void mn_itree_release(struct mmu_notifier_subscriptions *subscriptions,
 static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 			     struct mm_struct *mm)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	/*
@@ -307,8 +317,10 @@ static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 	 * ->release returns.
 	 */
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
-				 srcu_read_lock_held(&srcu))
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
+				srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		/*
 		 * If ->release runs before mmu_notifier_unregister it must be
 		 * handled, as it's the only way for the driver to flush all
@@ -317,18 +329,19 @@ static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 		 */
 		if (subscription->ops->release)
 			subscription->ops->release(subscription, mm);
+	}
 
 	spin_lock(&subscriptions->lock);
 	while (unlikely(!hlist_empty(&subscriptions->list))) {
-		subscription = hlist_entry(subscriptions->list.first,
-					   struct mmu_notifier, hlist);
+		chain = hlist_entry(subscriptions->list.first,
+					   struct mm_notifier_chain, hlist);
 		/*
 		 * We arrived before mmu_notifier_unregister so
 		 * mmu_notifier_unregister will do nothing other than to wait
 		 * for ->release to finish and for mmu_notifier_unregister to
 		 * return.
 		 */
-		hlist_del_init_rcu(&subscription->hlist);
+		hlist_del_init_rcu(&chain->hlist);
 	}
 	spin_unlock(&subscriptions->lock);
 	srcu_read_unlock(&srcu, id);
@@ -366,13 +379,15 @@ int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
 					unsigned long start,
 					unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->clear_flush_young)
 			young |= subscription->ops->clear_flush_young(
 				subscription, mm, start, end);
@@ -386,13 +401,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
 			       unsigned long start,
 			       unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->clear_young)
 			young |= subscription->ops->clear_young(subscription,
 								mm, start, end);
@@ -405,13 +422,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
 int __mmu_notifier_test_young(struct mm_struct *mm,
 			      unsigned long address)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->test_young) {
 			young = subscription->ops->test_young(subscription, mm,
 							      address);
@@ -427,13 +446,15 @@ int __mmu_notifier_test_young(struct mm_struct *mm,
 void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address,
 			       pte_t pte)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->change_pte)
 			subscription->ops->change_pte(subscription, mm, address,
 						      pte);
@@ -476,13 +497,15 @@ static int mn_hlist_invalidate_range_start(
 	struct mmu_notifier_subscriptions *subscriptions,
 	struct mmu_notifier_range *range)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int ret = 0;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		const struct mmu_notifier_ops *ops = subscription->ops;
 
 		if (ops->invalidate_range_start) {
@@ -519,8 +542,10 @@ static int mn_hlist_invalidate_range_start(
 		 * notifiers and one or more failed start, any that succeeded
 		 * start are expecting their end to be called.  Do so now.
 		 */
-		hlist_for_each_entry_rcu(subscription, &subscriptions->list,
+		hlist_for_each_entry_rcu(chain, &subscriptions->list,
 					 hlist, srcu_read_lock_held(&srcu)) {
+			struct mmu_notifier *subscription =
+				mmu_notifier_from_chain(chain);
 			if (!subscription->ops->invalidate_range_end)
 				continue;
 
@@ -553,35 +578,20 @@ static void
 mn_hlist_invalidate_end(struct mmu_notifier_subscriptions *subscriptions,
 			struct mmu_notifier_range *range, bool only_end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
-		/*
-		 * Call invalidate_range here too to avoid the need for the
-		 * subsystem of having to register an invalidate_range_end
-		 * call-back when there is invalidate_range already. Usually a
-		 * subsystem registers either invalidate_range_start()/end() or
-		 * invalidate_range(), so this will be no additional overhead
-		 * (besides the pointer check).
-		 *
-		 * We skip call to invalidate_range() if we know it is safe ie
-		 * call site use mmu_notifier_invalidate_range_only_end() which
-		 * is safe to do when we know that a call to invalidate_range()
-		 * already happen under page table lock.
-		 */
-		if (!only_end && subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription,
-							    range->mm,
-							    range->start,
-							    range->end);
-		if (subscription->ops->invalidate_range_end) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
+		const struct mmu_notifier_ops *ops = subscription->ops;
+
+		if (ops->invalidate_range_end) {
 			if (!mmu_notifier_range_blockable(range))
 				non_block_start();
-			subscription->ops->invalidate_range_end(subscription,
-								range);
+			ops->invalidate_range_end(subscription, range);
 			if (!mmu_notifier_range_blockable(range))
 				non_block_end();
 		}
@@ -607,27 +617,24 @@ void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
-		if (subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription, mm,
-							    start, end);
+		struct tlb_notifier *subscription =
+			tlb_notifier_from_chain(chain);
+		if (subscription->ops->invalidate_secondary_tlbs)
+			subscription->ops->invalidate_secondary_tlbs(subscription,
+								mm, start, end);
 	}
 	srcu_read_unlock(&srcu, id);
 }
 
-/*
- * Same as mmu_notifier_register but here the caller must hold the mmap_lock in
- * write mode. A NULL mn signals the notifier is being registered for itree
- * mode.
- */
-int __mmu_notifier_register(struct mmu_notifier *subscription,
-			    struct mm_struct *mm)
+int __mm_notifier_chain_register(struct mm_notifier_chain *chain,
+				struct mm_struct *mm)
 {
 	struct mmu_notifier_subscriptions *subscriptions = NULL;
 	int ret;
@@ -677,14 +684,14 @@ int __mmu_notifier_register(struct mmu_notifier *subscription,
 	if (subscriptions)
 		smp_store_release(&mm->notifier_subscriptions, subscriptions);
 
-	if (subscription) {
+	if (chain) {
 		/* Pairs with the mmdrop in mmu_notifier_unregister_* */
 		mmgrab(mm);
-		subscription->mm = mm;
-		subscription->users = 1;
+		chain->mm = mm;
+		chain->users = 1;
 
 		spin_lock(&mm->notifier_subscriptions->lock);
-		hlist_add_head_rcu(&subscription->hlist,
+		hlist_add_head_rcu(&chain->hlist,
 				   &mm->notifier_subscriptions->list);
 		spin_unlock(&mm->notifier_subscriptions->lock);
 	} else
@@ -698,6 +705,18 @@ int __mmu_notifier_register(struct mmu_notifier *subscription,
 	kfree(subscriptions);
 	return ret;
 }
+
+/*
+ * Same as mmu_notifier_register but here the caller must hold the mmap_lock in
+ * write mode. A NULL mn signals the notifier is being registered for itree
+ * mode.
+ */
+int __mmu_notifier_register(struct mmu_notifier *subscription,
+			    struct mm_struct *mm)
+{
+	return __mm_notifier_chain_register(&subscription->mm_notifier_chain,
+					mm);
+}
 EXPORT_SYMBOL_GPL(__mmu_notifier_register);
 
 /**
@@ -731,20 +750,34 @@ int mmu_notifier_register(struct mmu_notifier *subscription,
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_register);
 
+int tlb_notifier_register(struct tlb_notifier *subscription,
+			struct mm_struct *mm)
+{
+	int ret;
+
+	mmap_write_lock(mm);
+	ret = __mm_notifier_chain_register(&subscription->mm_notifier_chain, mm);
+	mmap_write_unlock(mm);
+	return ret;
+}
+EXPORT_SYMBOL_GPL(tlb_notifier_register);
+
 static struct mmu_notifier *
 find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 
 	spin_lock(&mm->notifier_subscriptions->lock);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 lockdep_is_held(&mm->notifier_subscriptions->lock)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops != ops)
 			continue;
 
-		if (likely(subscription->users != UINT_MAX))
-			subscription->users++;
+		if (likely(chain->users != UINT_MAX))
+			chain->users++;
 		else
 			subscription = ERR_PTR(-EOVERFLOW);
 		spin_unlock(&mm->notifier_subscriptions->lock);
@@ -822,7 +855,7 @@ void mmu_notifier_unregister(struct mmu_notifier *subscription,
 {
 	BUG_ON(atomic_read(&mm->mm_count) <= 0);
 
-	if (!hlist_unhashed(&subscription->hlist)) {
+	if (!hlist_unhashed(&subscription->mm_notifier_chain.hlist)) {
 		/*
 		 * SRCU here will force exit_mmap to wait for ->release to
 		 * finish before freeing the pages.
@@ -843,7 +876,7 @@ void mmu_notifier_unregister(struct mmu_notifier *subscription,
 		 * Can not use list_del_rcu() since __mmu_notifier_release
 		 * can delete it before we hold the lock.
 		 */
-		hlist_del_init_rcu(&subscription->hlist);
+		hlist_del_init_rcu(&subscription->mm_notifier_chain.hlist);
 		spin_unlock(&mm->notifier_subscriptions->lock);
 	}
 
@@ -861,15 +894,34 @@ EXPORT_SYMBOL_GPL(mmu_notifier_unregister);
 
 static void mmu_notifier_free_rcu(struct rcu_head *rcu)
 {
-	struct mmu_notifier *subscription =
-		container_of(rcu, struct mmu_notifier, rcu);
-	struct mm_struct *mm = subscription->mm;
+	struct mm_notifier_chain *chain =
+		container_of(rcu, struct mm_notifier_chain, rcu);
+	struct mm_struct *mm = chain->mm;
+	struct mmu_notifier *subscription = mmu_notifier_from_chain(chain);
 
 	subscription->ops->free_notifier(subscription);
 	/* Pairs with the get in __mmu_notifier_register() */
 	mmdrop(mm);
 }
 
+void mm_notifier_chain_put(struct mm_notifier_chain *chain)
+{
+	struct mm_struct *mm = chain->mm;
+
+	spin_lock(&mm->notifier_subscriptions->lock);
+	if (WARN_ON(!chain->users) ||
+		--chain->users)
+		goto out_unlock;
+	hlist_del_init_rcu(&chain->hlist);
+	spin_unlock(&mm->notifier_subscriptions->lock);
+
+	call_srcu(&srcu, &chain->rcu, mmu_notifier_free_rcu);
+	return;
+
+out_unlock:
+	spin_unlock(&mm->notifier_subscriptions->lock);
+}
+
 /**
  * mmu_notifier_put - Release the reference on the notifier
  * @subscription: The notifier to act on
@@ -894,22 +946,16 @@ static void mmu_notifier_free_rcu(struct rcu_head *rcu)
  */
 void mmu_notifier_put(struct mmu_notifier *subscription)
 {
-	struct mm_struct *mm = subscription->mm;
-
-	spin_lock(&mm->notifier_subscriptions->lock);
-	if (WARN_ON(!subscription->users) || --subscription->users)
-		goto out_unlock;
-	hlist_del_init_rcu(&subscription->hlist);
-	spin_unlock(&mm->notifier_subscriptions->lock);
-
-	call_srcu(&srcu, &subscription->rcu, mmu_notifier_free_rcu);
-	return;
-
-out_unlock:
-	spin_unlock(&mm->notifier_subscriptions->lock);
+	mm_notifier_chain_put(&subscription->mm_notifier_chain);
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_put);
 
+void tlb_notifier_put(struct tlb_notifier *subscription)
+{
+	mm_notifier_chain_put(&subscription->mm_notifier_chain);
+}
+EXPORT_SYMBOL_GPL(tlb_notifier_put);
+
 static int __mmu_interval_notifier_insert(
 	struct mmu_interval_notifier *interval_sub, struct mm_struct *mm,
 	struct mmu_notifier_subscriptions *subscriptions, unsigned long start,
