[v6,13/17] static_call: Add static_call_cond()
diff mbox series

Message ID 20200710134336.918547865@infradead.org
State New
Headers show
Series
  • Add static_call
Related show

Commit Message

Peter Zijlstra July 10, 2020, 1:38 p.m. UTC
Extend the static_call infrastructure to optimize the following common
pattern:

	if (func_ptr)
		func_ptr(args...)

For the trampoline (which is in effect a tail-call), we patch the
JMP.d32 into a RET, which then directly consumes the trampoline call.

For the in-line sites we replace the CALL with a NOP5.

NOTE: this is 'obviously' limited to functions with a 'void' return type.

NOTE: DEFINE_STATIC_CALL_NULL() only requires a typename, as opposed
      to a full function.

Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
---
 arch/x86/include/asm/static_call.h |   12 +++--
 arch/x86/kernel/static_call.c      |   42 +++++++++++++-----
 include/linux/static_call.h        |   86 +++++++++++++++++++++++++++++++++++++
 3 files changed, 127 insertions(+), 13 deletions(-)

Comments

Steven Rostedt July 10, 2020, 11:08 p.m. UTC | #1
On Fri, 10 Jul 2020 15:38:44 +0200
Peter Zijlstra <peterz@infradead.org> wrote:

> +static void __static_call_transform(void *insn, enum insn_type type, void *func)
>  {
> -	const void *code = text_gen_insn(opcode, insn, func);
> +	int size = CALL_INSN_SIZE;
> +	const void *code;
>  
> -	if (WARN_ONCE(*(u8 *)insn != opcode,
> -		      "unexpected static call insn opcode 0x%x at %pS\n",
> -		      opcode, insn))

I would still feel better if we did some sort of sanity check before
just writing to the text. Confirm this is a jmp, call, ret or nop?

-- Steve


> -		return;
> +	switch (type) {
> +	case CALL:
> +		code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
> +		break;
> +
> +	case NOP:
> +		code = ideal_nops[NOP_ATOMIC5];
> +		break;
> +
> +	case JMP:
> +		code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
> +		break;
> +
> +	case RET:
> +		code = text_gen_insn(RET_INSN_OPCODE, insn, func);
> +		size = RET_INSN_SIZE;
> +		break;
> +	}
>  
> -	if (memcmp(insn, code, CALL_INSN_SIZE) == 0)
> +	if (memcmp(insn, code, size) == 0)
>  		return;
>  
> -	text_poke_bp(insn, code, CALL_INSN_SIZE, NULL);
> +	text_poke_bp(insn, code, size, NULL);
>  }
>
Peter Zijlstra July 11, 2020, 5:09 a.m. UTC | #2
On Fri, Jul 10, 2020 at 07:08:25PM -0400, Steven Rostedt wrote:
> On Fri, 10 Jul 2020 15:38:44 +0200
> Peter Zijlstra <peterz@infradead.org> wrote:
> 
> > +static void __static_call_transform(void *insn, enum insn_type type, void *func)
> >  {
> > -	const void *code = text_gen_insn(opcode, insn, func);
> > +	int size = CALL_INSN_SIZE;
> > +	const void *code;
> >  
> > -	if (WARN_ONCE(*(u8 *)insn != opcode,
> > -		      "unexpected static call insn opcode 0x%x at %pS\n",
> > -		      opcode, insn))
> 
> I would still feel better if we did some sort of sanity check before
> just writing to the text. Confirm this is a jmp, call, ret or nop?

I'll see if I can come up with something, but I'm not sure we keep
enough state to be able to reconstruct what should be there.
Peter Zijlstra July 11, 2020, 10:49 a.m. UTC | #3
On Fri, Jul 10, 2020 at 07:08:25PM -0400, Steven Rostedt wrote:
> On Fri, 10 Jul 2020 15:38:44 +0200
> Peter Zijlstra <peterz@infradead.org> wrote:
> 
> > +static void __static_call_transform(void *insn, enum insn_type type, void *func)
> >  {
> > -	const void *code = text_gen_insn(opcode, insn, func);
> > +	int size = CALL_INSN_SIZE;
> > +	const void *code;
> >  
> > -	if (WARN_ONCE(*(u8 *)insn != opcode,
> > -		      "unexpected static call insn opcode 0x%x at %pS\n",
> > -		      opcode, insn))
> 
> I would still feel better if we did some sort of sanity check before
> just writing to the text. Confirm this is a jmp, call, ret or nop?

Something like so (on top of the next patch) ?

I'm not convinced it actually helps much, but if it makes you feel
better :-)


--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -56,15 +56,36 @@ static inline enum insn_type __sc_insn(b
 	return 2*tail + null;
 }
 
+static void __static_call_validate(void *insn, bool tail)
+{
+	u8 opcode = *(u8 *)insn;
+
+	if (tail) {
+		if (opcode == JMP32_INSN_OPCODE ||
+		    opcode == RET_INSN_OPCODE)
+			return;
+	} else {
+		if (opcode == CALL_INSN_OPCODE ||
+		    !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
+			return;
+	}
+
+	WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
+}
+
 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
 {
 	mutex_lock(&text_mutex);
 
-	if (tramp)
+	if (tramp) {
+		__static_call_validate(tramp, true);
 		__static_call_transform(tramp, __sc_insn(!func, true), func);
+	}
 
-	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
+	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
+		__static_call_validate(site, tail);
 		__static_call_transform(site, __sc_insn(!func, tail), func);
+	}
 
 	mutex_unlock(&text_mutex);
 }
Steven Rostedt July 13, 2020, 8:32 p.m. UTC | #4
On Sat, 11 Jul 2020 12:49:30 +0200
Peter Zijlstra <peterz@infradead.org> wrote:
> 
> Something like so (on top of the next patch) ?
> 
> I'm not convinced it actually helps much, but if it makes you feel
> better :-)

After you have bricked a bunch of people's NICs, you would be paranoid
about this too!

You work for Intel, next time you go to an office, see if you can find
my picture on any dartboards in there ;-)


> 
> 
> --- a/arch/x86/kernel/static_call.c
> +++ b/arch/x86/kernel/static_call.c
> @@ -56,15 +56,36 @@ static inline enum insn_type __sc_insn(b
>  	return 2*tail + null;
>  }
>  
> +static void __static_call_validate(void *insn, bool tail)
> +{
> +	u8 opcode = *(u8 *)insn;
> +
> +	if (tail) {
> +		if (opcode == JMP32_INSN_OPCODE ||
> +		    opcode == RET_INSN_OPCODE)
> +			return;
> +	} else {
> +		if (opcode == CALL_INSN_OPCODE ||
> +		    !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
> +			return;
> +	}
> +
> +	WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
> +}
> +
>  void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
>  {
>  	mutex_lock(&text_mutex);
>  
> -	if (tramp)
> +	if (tramp) {
> +		__static_call_validate(tramp, true);
>  		__static_call_transform(tramp, __sc_insn(!func, true), func);
> +	}
>  
> -	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
> +	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
> +		__static_call_validate(site, tail);

I'd feel even more better if the validate failed, we just don't do the
update.

-- Steve


>  		__static_call_transform(site, __sc_insn(!func, tail), func);
> +	}
>  
>  	mutex_unlock(&text_mutex);
>  }
Peter Zijlstra July 14, 2020, 9:53 a.m. UTC | #5
On Mon, Jul 13, 2020 at 04:32:39PM -0400, Steven Rostedt wrote:
> On Sat, 11 Jul 2020 12:49:30 +0200
> Peter Zijlstra <peterz@infradead.org> wrote:
> > 
> > Something like so (on top of the next patch) ?
> > 
> > I'm not convinced it actually helps much, but if it makes you feel
> > better :-)
> 
> After you have bricked a bunch of people's NICs, you would be paranoid
> about this too!
> 
> You work for Intel, next time you go to an office, see if you can find
> my picture on any dartboards in there ;-)

I'll tell em it was bad hardware for being able to get bricked like that
instead, how's that? ;-)

> > --- a/arch/x86/kernel/static_call.c
> > +++ b/arch/x86/kernel/static_call.c
> > @@ -56,15 +56,36 @@ static inline enum insn_type __sc_insn(b
> >  	return 2*tail + null;
> >  }
> >  
> > +static void __static_call_validate(void *insn, bool tail)
> > +{
> > +	u8 opcode = *(u8 *)insn;
> > +
> > +	if (tail) {
> > +		if (opcode == JMP32_INSN_OPCODE ||
> > +		    opcode == RET_INSN_OPCODE)
> > +			return;
> > +	} else {
> > +		if (opcode == CALL_INSN_OPCODE ||
> > +		    !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
> > +			return;
> > +	}
> > +
> > +	WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
> > +}
> > +
> >  void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
> >  {
> >  	mutex_lock(&text_mutex);
> >  
> > -	if (tramp)
> > +	if (tramp) {
> > +		__static_call_validate(tramp, true);
> >  		__static_call_transform(tramp, __sc_insn(!func, true), func);
> > +	}
> >  
> > -	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
> > +	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
> > +		__static_call_validate(site, tail);
> 
> I'd feel even more better if the validate failed, we just don't do the
> update.

It is either this or BUG()/panic(), you pick ;-)

Patch
diff mbox series

--- a/arch/x86/include/asm/static_call.h
+++ b/arch/x86/include/asm/static_call.h
@@ -20,15 +20,21 @@ 
  * it does tail-call optimization on the call; since you cannot compute the
  * relative displacement across sections.
  */
-#define ARCH_DEFINE_STATIC_CALL_TRAMP(name, func)			\
+
+#define __ARCH_DEFINE_STATIC_CALL_TRAMP(name, insns)			\
 	asm(".pushsection .static_call.text, \"ax\"		\n"	\
 	    ".align 4						\n"	\
 	    ".globl " STATIC_CALL_TRAMP_STR(name) "		\n"	\
 	    STATIC_CALL_TRAMP_STR(name) ":			\n"	\
-	    "	.byte 0xe9 # jmp.d32				\n"	\
-	    "	.long " #func " - (. + 4)			\n"	\
+	    insns "						\n"	\
 	    ".type " STATIC_CALL_TRAMP_STR(name) ", @function	\n"	\
 	    ".size " STATIC_CALL_TRAMP_STR(name) ", . - " STATIC_CALL_TRAMP_STR(name) " \n" \
 	    ".popsection					\n")
 
+#define ARCH_DEFINE_STATIC_CALL_TRAMP(name, func)			\
+	__ARCH_DEFINE_STATIC_CALL_TRAMP(name, ".byte 0xe9; .long " #func " - (. + 4)")
+
+#define ARCH_DEFINE_STATIC_CALL_NULL_TRAMP(name)			\
+	__ARCH_DEFINE_STATIC_CALL_TRAMP(name, "ret; nop; nop; nop; nop")
+
 #endif /* _ASM_STATIC_CALL_H */
--- a/arch/x86/kernel/static_call.c
+++ b/arch/x86/kernel/static_call.c
@@ -4,19 +4,41 @@ 
 #include <linux/bug.h>
 #include <asm/text-patching.h>
 
-static void __static_call_transform(void *insn, u8 opcode, void *func)
+enum insn_type {
+	CALL = 0, /* site call */
+	NOP = 1,  /* site cond-call */
+	JMP = 2,  /* tramp / site tail-call */
+	RET = 3,  /* tramp / site cond-tail-call */
+};
+
+static void __static_call_transform(void *insn, enum insn_type type, void *func)
 {
-	const void *code = text_gen_insn(opcode, insn, func);
+	int size = CALL_INSN_SIZE;
+	const void *code;
 
-	if (WARN_ONCE(*(u8 *)insn != opcode,
-		      "unexpected static call insn opcode 0x%x at %pS\n",
-		      opcode, insn))
-		return;
+	switch (type) {
+	case CALL:
+		code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
+		break;
+
+	case NOP:
+		code = ideal_nops[NOP_ATOMIC5];
+		break;
+
+	case JMP:
+		code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
+		break;
+
+	case RET:
+		code = text_gen_insn(RET_INSN_OPCODE, insn, func);
+		size = RET_INSN_SIZE;
+		break;
+	}
 
-	if (memcmp(insn, code, CALL_INSN_SIZE) == 0)
+	if (memcmp(insn, code, size) == 0)
 		return;
 
-	text_poke_bp(insn, code, CALL_INSN_SIZE, NULL);
+	text_poke_bp(insn, code, size, NULL);
 }
 
 void arch_static_call_transform(void *site, void *tramp, void *func)
@@ -24,10 +46,10 @@  void arch_static_call_transform(void *si
 	mutex_lock(&text_mutex);
 
 	if (tramp)
-		__static_call_transform(tramp, JMP32_INSN_OPCODE, func);
+		__static_call_transform(tramp, func ? JMP : RET, func);
 
 	if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site)
-		__static_call_transform(site, CALL_INSN_OPCODE, func);
+		__static_call_transform(site, func ? CALL : NOP, func);
 
 	mutex_unlock(&text_mutex);
 }
--- a/include/linux/static_call.h
+++ b/include/linux/static_call.h
@@ -16,7 +16,9 @@ 
  *
  *   DECLARE_STATIC_CALL(name, func);
  *   DEFINE_STATIC_CALL(name, func);
+ *   DEFINE_STATIC_CALL_NULL(name, typename);
  *   static_call(name)(args...);
+ *   static_call_cond(name)(args...);
  *   static_call_update(name, func);
  *
  * Usage example:
@@ -52,6 +54,43 @@ 
  *   rather than calling through the trampoline.  This requires objtool or a
  *   compiler plugin to detect all the static_call() sites and annotate them
  *   in the .static_call_sites section.
+ *
+ *
+ * Notes on NULL function pointers:
+ *
+ *   Static_call()s support NULL functions, with many of the caveats that
+ *   regular function pointers have.
+ *
+ *   Clearly calling a NULL function pointer is 'BAD', so too for
+ *   static_call()s (although when HAVE_STATIC_CALL it might not be immediately
+ *   fatal). A NULL static_call can be the result of:
+ *
+ *     DEFINE_STATIC_CALL_NULL(my_static_call, void (*)(int));
+ *
+ *   which is equivalent to declaring a NULL function pointer with just a
+ *   typename:
+ *
+ *     void (*my_func_ptr)(int arg1) = NULL;
+ *
+ *   or using static_call_update() with a NULL function. In both cases the
+ *   HAVE_STATIC_CALL implementation will patch the trampoline with a RET
+ *   instruction, instead of an immediate tail-call JMP. HAVE_STATIC_CALL_INLINE
+ *   architectures can patch the trampoline call to a NOP.
+ *
+ *   In all cases, any argument evaluation is unconditional. Unlike a regular
+ *   conditional function pointer call:
+ *
+ *     if (my_func_ptr)
+ *         my_func_ptr(arg1)
+ *
+ *   where the argument evaluation also depends on the pointer value.
+ *
+ *   When calling a static_call that can be NULL, use:
+ *
+ *     static_call_cond(name)(arg1);
+ *
+ *   which will include the required value tests to avoid NULL-pointer
+ *   dereferences.
  */
 
 #include <linux/types.h>
@@ -120,7 +159,16 @@  extern int static_call_text_reserved(voi
 	};								\
 	ARCH_DEFINE_STATIC_CALL_TRAMP(name, _func)
 
+#define DEFINE_STATIC_CALL_NULL(name, _func)				\
+	DECLARE_STATIC_CALL(name, _func);				\
+	struct static_call_key STATIC_CALL_KEY(name) = {		\
+		.func = NULL,						\
+		.type = 1,						\
+	};								\
+	ARCH_DEFINE_STATIC_CALL_NULL_TRAMP(name)
+
 #define static_call(name)	__static_call(name)
+#define static_call_cond(name)	(void)__static_call(name)
 
 #define EXPORT_STATIC_CALL(name)					\
 	EXPORT_SYMBOL(STATIC_CALL_KEY(name));				\
@@ -143,7 +191,15 @@  struct static_call_key {
 	};								\
 	ARCH_DEFINE_STATIC_CALL_TRAMP(name, _func)
 
+#define DEFINE_STATIC_CALL_NULL(name, _func)				\
+	DECLARE_STATIC_CALL(name, _func);				\
+	struct static_call_key STATIC_CALL_KEY(name) = {		\
+		.func = NULL,						\
+	};								\
+	ARCH_DEFINE_STATIC_CALL_NULL_TRAMP(name)
+
 #define static_call(name)	__static_call(name)
+#define static_call_cond(name)	(void)__static_call(name)
 
 static inline
 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
@@ -179,9 +235,39 @@  struct static_call_key {
 		.func = _func,						\
 	}
 
+#define DEFINE_STATIC_CALL_NULL(name, _func)				\
+	DECLARE_STATIC_CALL(name, _func);				\
+	struct static_call_key STATIC_CALL_KEY(name) = {		\
+		.func = NULL,						\
+	}
+
 #define static_call(name)						\
 	((typeof(STATIC_CALL_TRAMP(name))*)(STATIC_CALL_KEY(name).func))
 
+static inline void __static_call_nop(void) { }
+
+/*
+ * This horrific hack takes care of two things:
+ *
+ *  - it ensures the compiler will only load the function pointer ONCE,
+ *    which avoids a reload race.
+ *
+ *  - it ensures the argument evaluation is unconditional, similar
+ *    to the HAVE_STATIC_CALL variant.
+ *
+ * Sadly current GCC/Clang (10 for both) do not optimize this properly
+ * and will emit an indirect call for the NULL case :-(
+ */
+#define __static_call_cond(name)					\
+({									\
+	void *func = READ_ONCE(STATIC_CALL_KEY(name).func);		\
+	if (!func)							\
+		func = &__static_call_nop;				\
+	(typeof(STATIC_CALL_TRAMP(name))*)func;				\
+})
+
+#define static_call_cond(name)	(void)__static_call_cond(name)
+
 static inline
 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
 {