[v1,2/3] x86/mm: Prepare sme_encrypt_kernel() for PAGE aligned encryption

Message ID 20171207233402.29646.77306.stgit@tlendack-t1.amdoffice.net
State New, archived
Series
  • x86: SME: BSP/SME microcode update fix

Commit Message

Tom Lendacky Dec. 7, 2017, 11:34 p.m. UTC
In preparation for encrypting more than just the kernel, the encryption
support in sme_encrypt_kernel() needs to support 4KB page aligned
encryption instead of just 2MB large page aligned encryption.

Update the routines that populate the PGD to support non-2MB-aligned
addresses.  This is done by creating PTE page tables for the start
and end portions of the address range that fall outside of 2MB
alignment.  This results in, at most, two extra pages to hold the
PTE entries for each mapped range.

Signed-off-by: Tom Lendacky <thomas.lendacky@amd.com>
---
 arch/x86/mm/mem_encrypt.c      |  115 ++++++++++++++++++++++++++++++++++------
 arch/x86/mm/mem_encrypt_boot.S |   20 +++++--
 2 files changed, 114 insertions(+), 21 deletions(-)
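
As an editorial illustration (not part of the patch), here is a minimal,
self-contained C sketch of the head/middle/tail split described above:
the unaligned head and tail of a range are mapped with 4KB pages, the
2MB-aligned middle with large pages.  populate_pte()/populate_pmd() are
hypothetical stand-ins for the patch's sme_populate_pgd() and
sme_populate_pgd_large() helpers, and the constants are redefined
locally so the sketch builds as a plain userspace program.

	#include <stdio.h>

	#define PAGE_SIZE	0x1000UL
	#define PMD_PAGE_SIZE	0x200000UL
	#define PMD_PAGE_MASK	(~(PMD_PAGE_SIZE - 1))
	#define ALIGN(x, a)	(((x) + (a) - 1) & ~((a) - 1))

	/* Stand-ins for sme_populate_pgd() / sme_populate_pgd_large() */
	static void populate_pte(unsigned long vaddr) { printf("4KB PTE  %#lx\n", vaddr); }
	static void populate_pmd(unsigned long vaddr) { printf("2MB PMD  %#lx\n", vaddr); }

	static void map_range(unsigned long vaddr, unsigned long vaddr_end)
	{
		unsigned long head_end   = ALIGN(vaddr, PMD_PAGE_SIZE);
		unsigned long tail_start = vaddr_end & PMD_PAGE_MASK;

		/* Head: start is not 2MB aligned, needs at most one PTE page */
		for (; vaddr < head_end && vaddr < vaddr_end; vaddr += PAGE_SIZE)
			populate_pte(vaddr);

		/* Middle: covered by 2MB PMD entries */
		for (; vaddr < tail_start; vaddr += PMD_PAGE_SIZE)
			populate_pmd(vaddr);

		/* Tail: end is not 2MB aligned, needs at most one PTE page */
		for (; vaddr < vaddr_end; vaddr += PAGE_SIZE)
			populate_pte(vaddr);
	}

	int main(void)
	{
		/* range starting 8KB before a 2MB boundary, ending 16KB past one */
		map_range(0x1ffe000UL, 0x2204000UL);
		return 0;
	}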

Comments

Borislav Petkov Dec. 21, 2017, 12:58 p.m. UTC | #1
On Thu, Dec 07, 2017 at 05:34:02PM -0600, Tom Lendacky wrote:
> In preparation for encrypting more than just the kernel, the encryption
> support in sme_encrypt_kernel() needs to support 4KB page aligned
> encryption instead of just 2MB large page aligned encryption.

...

>  static void __init __sme_map_range(pgd_t *pgd, unsigned long vaddr,
>  				   unsigned long vaddr_end,
> -				   unsigned long paddr, pmdval_t pmd_flags)
> +				   unsigned long paddr,
> +				   pmdval_t pmd_flags, pteval_t pte_flags)
>  {
> -	while (vaddr < vaddr_end) {
> -		sme_populate_pgd(pgd, vaddr, paddr, pmd_flags);
> +	if (vaddr & ~PMD_PAGE_MASK) {
> +		/* Start is not 2MB aligned, create PTE entries */
> +		unsigned long pmd_start = ALIGN(vaddr, PMD_PAGE_SIZE);
> +
> +		while (vaddr < pmd_start) {
> +			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
> +
> +			vaddr += PAGE_SIZE;
> +			paddr += PAGE_SIZE;
> +		}
> +	}

This chunk...

> +
> +	while (vaddr < (vaddr_end & PMD_PAGE_MASK)) {
> +		sme_populate_pgd_large(pgd, vaddr, paddr, pmd_flags);
>  
>  		vaddr += PMD_PAGE_SIZE;
>  		paddr += PMD_PAGE_SIZE;
>  	}
> +
> +	if (vaddr_end & ~PMD_PAGE_MASK) {
> +		/* End is not 2MB aligned, create PTE entries */
> +		while (vaddr < vaddr_end) {
> +			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
> +
> +			vaddr += PAGE_SIZE;
> +			paddr += PAGE_SIZE;
> +		}
> +	}

... and this chunk look like they could be extracted into a wrapper, as
they're pretty similar.
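
One possible shape for such a wrapper, sketched here for illustration
(the helper name and its exact form are assumptions, not taken from the
patch; it reuses sme_populate_pgd() and the types already in the patch):

	/* Hypothetical wrapper for the two 4KB-mapping loops above */
	static void __init __sme_map_range_pte(pgd_t *pgd, unsigned long vaddr,
					       unsigned long vaddr_end,
					       unsigned long paddr,
					       pteval_t pte_flags)
	{
		while (vaddr < vaddr_end) {
			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);

			vaddr += PAGE_SIZE;
			paddr += PAGE_SIZE;
		}
	}

__sme_map_range() could then call it once for the unaligned head, i.e.
[vaddr, ALIGN(vaddr, PMD_PAGE_SIZE)), and once for the unaligned tail,
i.e. [vaddr_end & PMD_PAGE_MASK, vaddr_end), keeping paddr in step.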

>  static void __init sme_map_range_encrypted(pgd_t *pgd,
> @@ -582,7 +658,8 @@ static void __init sme_map_range_encrypted(pgd_t *pgd,
>  					   unsigned long vaddr_end,
>  					   unsigned long paddr)
>  {
> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_ENC);
> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
> +			PMD_FLAGS_ENC, PTE_FLAGS_ENC);
>  }
>  
>  static void __init sme_map_range_decrypted(pgd_t *pgd,
> @@ -590,7 +667,8 @@ static void __init sme_map_range_decrypted(pgd_t *pgd,
>  					   unsigned long vaddr_end,
>  					   unsigned long paddr)
>  {
> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC);
> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
> +			PMD_FLAGS_DEC, PTE_FLAGS_DEC);
>  }
>  
>  static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
> @@ -598,19 +676,20 @@ static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
>  					      unsigned long vaddr_end,
>  					      unsigned long paddr)
>  {
> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC_WP);
> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
> +			PMD_FLAGS_DEC_WP, PTE_FLAGS_DEC_WP);
>  }
>  
>  static unsigned long __init sme_pgtable_calc(unsigned long len)
>  {
> -	unsigned long p4d_size, pud_size, pmd_size;
> +	unsigned long p4d_size, pud_size, pmd_size, pte_size;
>  	unsigned long total;
>  
>  	/*
>  	 * Perform a relatively simplistic calculation of the pagetable
> -	 * entries that are needed. That mappings will be covered by 2MB
> -	 * PMD entries so we can conservatively calculate the required
> -	 * number of P4D, PUD and PMD structures needed to perform the
> +	 * entries that are needed. That mappings will be covered mostly

				    Those

> +	 * by 2MB PMD entries so we can conservatively calculate the required
> +	 * number of P4D, PUD, PMD and PTE structures needed to perform the
>  	 * mappings. Incrementing the count for each covers the case where
>  	 * the addresses cross entries.
>  	 */
> @@ -626,8 +705,9 @@ static unsigned long __init sme_pgtable_calc(unsigned long len)
>  	}
>  	pmd_size = (ALIGN(len, PUD_SIZE) / PUD_SIZE) + 1;
>  	pmd_size *= sizeof(pmd_t) * PTRS_PER_PMD;
> +	pte_size = 2 * sizeof(pte_t) * PTRS_PER_PTE;

This needs the explanation from the commit message about the 2 extra
pages, as a comment above it.
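
A hedged example of what that comment might look like (wording is
illustrative only, not taken from a later revision of the patch):

	/*
	 * Only the unaligned start and end of a mapped range need PTE
	 * page tables, so account for at most two extra PTE pages per
	 * mapped range.
	 */
	pte_size = 2 * sizeof(pte_t) * PTRS_PER_PTE;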

> -	total = p4d_size + pud_size + pmd_size;
> +	total = p4d_size + pud_size + pmd_size + pte_size;
>  
>  	/*
>  	 * Now calculate the added pagetable structures needed to populate
> @@ -710,10 +790,13 @@ void __init sme_encrypt_kernel(void)
>  
>  	/*
>  	 * The total workarea includes the executable encryption area and
> -	 * the pagetable area.
> +	 * the pagetable area. The start of the workarea is already 2MB
> +	 * aligned, align the end of the workarea on a 2MB boundary so that
> +	 * we don't try to create/allocate PTE entries from the workarea
> +	 * before it is mapped.
>  	 */
>  	workarea_len = execute_len + pgtable_area_len;
> -	workarea_end = workarea_start + workarea_len;
> +	workarea_end = ALIGN(workarea_start + workarea_len, PMD_PAGE_SIZE);
>  
>  	/*
>  	 * Set the address to the start of where newly created pagetable
> diff --git a/arch/x86/mm/mem_encrypt_boot.S b/arch/x86/mm/mem_encrypt_boot.S
> index 730e6d5..20cca86 100644
> --- a/arch/x86/mm/mem_encrypt_boot.S
> +++ b/arch/x86/mm/mem_encrypt_boot.S
> @@ -120,23 +120,33 @@ ENTRY(__enc_copy)
>  
>  	wbinvd				/* Invalidate any cache entries */
>  
> +	push	%r12
> +
>  	/* Copy/encrypt 2MB at a time */
> +	movq	$PMD_PAGE_SIZE, %r12
>  1:
> +	cmpq	%r12, %r9
> +	jnb	2f
> +	movq	%r9, %r12
> +
> +2:
>  	movq	%r11, %rsi		/* Source - decrypted kernel */
>  	movq	%r8, %rdi		/* Dest   - intermediate copy buffer */
> -	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
> +	movq	%r12, %rcx
>  	rep	movsb
>  
>  	movq	%r8, %rsi		/* Source - intermediate copy buffer */
>  	movq	%r10, %rdi		/* Dest   - encrypted kernel */
> -	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
> +	movq	%r12, %rcx
>  	rep	movsb
>  
> -	addq	$PMD_PAGE_SIZE, %r11
> -	addq	$PMD_PAGE_SIZE, %r10
> -	subq	$PMD_PAGE_SIZE, %r9	/* Kernel length decrement */
> +	addq	%r12, %r11
> +	addq	%r12, %r10
> +	subq	%r12, %r9		/* Kernel length decrement */
>  	jnz	1b			/* Kernel length not zero? */
>  
> +	pop	%r12
> +
>  	/* Restore PAT register */
>  	push	%rdx			/* Save original PAT value */
>  	movl	$MSR_IA32_CR_PAT, %ecx

Right, for this here can you pls do a cleanup pre-patch which does

	push
	push
	push

	/* meat of the function */

	pop
	pop
	pop

as now those pushes and pops are interspersed with the rest of the insns
and that makes following through it harder. F17h has a stack engine so
those streams of pushes and pops won't hurt perf and besides, this is
not even perf-sensitive.

Thx.
Tom Lendacky Dec. 21, 2017, 4:35 p.m. UTC | #2
On 12/21/2017 6:58 AM, Borislav Petkov wrote:
> On Thu, Dec 07, 2017 at 05:34:02PM -0600, Tom Lendacky wrote:
>> In preparation for encrypting more than just the kernel, the encryption
>> support in sme_encrypt_kernel() needs to support 4KB page aligned
>> encryption instead of just 2MB large page aligned encryption.
> 
> ...
> 
>>  static void __init __sme_map_range(pgd_t *pgd, unsigned long vaddr,
>>  				   unsigned long vaddr_end,
>> -				   unsigned long paddr, pmdval_t pmd_flags)
>> +				   unsigned long paddr,
>> +				   pmdval_t pmd_flags, pteval_t pte_flags)
>>  {
>> -	while (vaddr < vaddr_end) {
>> -		sme_populate_pgd(pgd, vaddr, paddr, pmd_flags);
>> +	if (vaddr & ~PMD_PAGE_MASK) {
>> +		/* Start is not 2MB aligned, create PTE entries */
>> +		unsigned long pmd_start = ALIGN(vaddr, PMD_PAGE_SIZE);
>> +
>> +		while (vaddr < pmd_start) {
>> +			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
>> +
>> +			vaddr += PAGE_SIZE;
>> +			paddr += PAGE_SIZE;
>> +		}
>> +	}
> 
> This chunk...
> 
>> +
>> +	while (vaddr < (vaddr_end & PMD_PAGE_MASK)) {
>> +		sme_populate_pgd_large(pgd, vaddr, paddr, pmd_flags);
>>  
>>  		vaddr += PMD_PAGE_SIZE;
>>  		paddr += PMD_PAGE_SIZE;
>>  	}
>> +
>> +	if (vaddr_end & ~PMD_PAGE_MASK) {
>> +		/* End is not 2MB aligned, create PTE entries */
>> +		while (vaddr < vaddr_end) {
>> +			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
>> +
>> +			vaddr += PAGE_SIZE;
>> +			paddr += PAGE_SIZE;
>> +		}
>> +	}
> 
> ... and this chunk look like they could be extracted into a wrapper, as
> they're pretty similar.
> 

Should be doable, I'll take care of it.

>>  static void __init sme_map_range_encrypted(pgd_t *pgd,
>> @@ -582,7 +658,8 @@ static void __init sme_map_range_encrypted(pgd_t *pgd,
>>  					   unsigned long vaddr_end,
>>  					   unsigned long paddr)
>>  {
>> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_ENC);
>> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
>> +			PMD_FLAGS_ENC, PTE_FLAGS_ENC);
>>  }
>>  
>>  static void __init sme_map_range_decrypted(pgd_t *pgd,
>> @@ -590,7 +667,8 @@ static void __init sme_map_range_decrypted(pgd_t *pgd,
>>  					   unsigned long vaddr_end,
>>  					   unsigned long paddr)
>>  {
>> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC);
>> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
>> +			PMD_FLAGS_DEC, PTE_FLAGS_DEC);
>>  }
>>  
>>  static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
>> @@ -598,19 +676,20 @@ static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
>>  					      unsigned long vaddr_end,
>>  					      unsigned long paddr)
>>  {
>> -	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC_WP);
>> +	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
>> +			PMD_FLAGS_DEC_WP, PTE_FLAGS_DEC_WP);
>>  }
>>  
>>  static unsigned long __init sme_pgtable_calc(unsigned long len)
>>  {
>> -	unsigned long p4d_size, pud_size, pmd_size;
>> +	unsigned long p4d_size, pud_size, pmd_size, pte_size;
>>  	unsigned long total;
>>  
>>  	/*
>>  	 * Perform a relatively simplistic calculation of the pagetable
>> -	 * entries that are needed. That mappings will be covered by 2MB
>> -	 * PMD entries so we can conservatively calculate the required
>> -	 * number of P4D, PUD and PMD structures needed to perform the
>> +	 * entries that are needed. That mappings will be covered mostly
> 
> 				    Those
> 

Got it.

>> +	 * by 2MB PMD entries so we can conservatively calculate the required
>> +	 * number of P4D, PUD, PMD and PTE structures needed to perform the
>>  	 * mappings. Incrementing the count for each covers the case where
>>  	 * the addresses cross entries.
>>  	 */
>> @@ -626,8 +705,9 @@ static unsigned long __init sme_pgtable_calc(unsigned long len)
>>  	}
>>  	pmd_size = (ALIGN(len, PUD_SIZE) / PUD_SIZE) + 1;
>>  	pmd_size *= sizeof(pmd_t) * PTRS_PER_PMD;
>> +	pte_size = 2 * sizeof(pte_t) * PTRS_PER_PTE;
> 
> This needs the explanation from the commit message about the 2 extra
> pages, as a comment above it.

Yup, I'll include the info about the (possible) extra PTE entries.

> 
>> -	total = p4d_size + pud_size + pmd_size;
>> +	total = p4d_size + pud_size + pmd_size + pte_size;
>>  
>>  	/*
>>  	 * Now calculate the added pagetable structures needed to populate
>> @@ -710,10 +790,13 @@ void __init sme_encrypt_kernel(void)
>>  
>>  	/*
>>  	 * The total workarea includes the executable encryption area and
>> -	 * the pagetable area.
>> +	 * the pagetable area. The start of the workarea is already 2MB
>> +	 * aligned, align the end of the workarea on a 2MB boundary so that
>> +	 * we don't try to create/allocate PTE entries from the workarea
>> +	 * before it is mapped.
>>  	 */
>>  	workarea_len = execute_len + pgtable_area_len;
>> -	workarea_end = workarea_start + workarea_len;
>> +	workarea_end = ALIGN(workarea_start + workarea_len, PMD_PAGE_SIZE);
>>  
>>  	/*
>>  	 * Set the address to the start of where newly created pagetable
>> diff --git a/arch/x86/mm/mem_encrypt_boot.S b/arch/x86/mm/mem_encrypt_boot.S
>> index 730e6d5..20cca86 100644
>> --- a/arch/x86/mm/mem_encrypt_boot.S
>> +++ b/arch/x86/mm/mem_encrypt_boot.S
>> @@ -120,23 +120,33 @@ ENTRY(__enc_copy)
>>  
>>  	wbinvd				/* Invalidate any cache entries */
>>  
>> +	push	%r12
>> +
>>  	/* Copy/encrypt 2MB at a time */
>> +	movq	$PMD_PAGE_SIZE, %r12
>>  1:
>> +	cmpq	%r12, %r9
>> +	jnb	2f
>> +	movq	%r9, %r12
>> +
>> +2:
>>  	movq	%r11, %rsi		/* Source - decrypted kernel */
>>  	movq	%r8, %rdi		/* Dest   - intermediate copy buffer */
>> -	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
>> +	movq	%r12, %rcx
>>  	rep	movsb
>>  
>>  	movq	%r8, %rsi		/* Source - intermediate copy buffer */
>>  	movq	%r10, %rdi		/* Dest   - encrypted kernel */
>> -	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
>> +	movq	%r12, %rcx
>>  	rep	movsb
>>  
>> -	addq	$PMD_PAGE_SIZE, %r11
>> -	addq	$PMD_PAGE_SIZE, %r10
>> -	subq	$PMD_PAGE_SIZE, %r9	/* Kernel length decrement */
>> +	addq	%r12, %r11
>> +	addq	%r12, %r10
>> +	subq	%r12, %r9		/* Kernel length decrement */
>>  	jnz	1b			/* Kernel length not zero? */
>>  
>> +	pop	%r12
>> +
>>  	/* Restore PAT register */
>>  	push	%rdx			/* Save original PAT value */
>>  	movl	$MSR_IA32_CR_PAT, %ecx
> 
> Right, for this here can you pls do a cleanup pre-patch which does
> 
> 	push
> 	push
> 	push
> 
> 	/* meat of the function */
> 
> 	pop
> 	pop
> 	pop
> 
> as now those pushes and pops are interspersed with the rest of the insns
> and that makes following through it harder. F17h has a stack engine so
> those streams of pushes and pops won't hurt perf and besides, this is
> not even perf-sensitive.
> 

Ok, I'll clean that up to move any pushes/pops to the beginning/end of
the function, as well as moving some register saving earlier which may
eliminate some push/pop instructions.

Thanks,
Tom

> Thx.
>

Patch

diff --git a/arch/x86/mm/mem_encrypt.c b/arch/x86/mm/mem_encrypt.c
index 2d8404b..1f0efb8 100644
--- a/arch/x86/mm/mem_encrypt.c
+++ b/arch/x86/mm/mem_encrypt.c
@@ -486,6 +486,7 @@  static void __init sme_clear_pgd(pgd_t *pgd_base, unsigned long start,
 #define PGD_FLAGS	_KERNPG_TABLE_NOENC
 #define P4D_FLAGS	_KERNPG_TABLE_NOENC
 #define PUD_FLAGS	_KERNPG_TABLE_NOENC
+#define PMD_FLAGS	_KERNPG_TABLE_NOENC
 
 #define PMD_FLAGS_LARGE		(__PAGE_KERNEL_LARGE_EXEC & ~_PAGE_GLOBAL)
 
@@ -494,8 +495,14 @@  static void __init sme_clear_pgd(pgd_t *pgd_base, unsigned long start,
 				 (_PAGE_PAT | _PAGE_PWT))
 #define PMD_FLAGS_ENC		(PMD_FLAGS_LARGE | _PAGE_ENC)
 
-static void __init sme_populate_pgd(pgd_t *pgd_base, unsigned long vaddr,
-				    unsigned long paddr, pmdval_t pmd_flags)
+#define PTE_FLAGS		(__PAGE_KERNEL_EXEC & ~_PAGE_GLOBAL)
+
+#define PTE_FLAGS_DEC		PTE_FLAGS
+#define PTE_FLAGS_DEC_WP	((PTE_FLAGS_DEC & ~_PAGE_CACHE_MASK) | \
+				 (_PAGE_PAT | _PAGE_PWT))
+#define PTE_FLAGS_ENC		(PTE_FLAGS | _PAGE_ENC)
+
+static pmd_t __init *sme_prepare_pgd(pgd_t *pgd_base, unsigned long vaddr)
 {
 	pgd_t *pgd_p;
 	p4d_t *p4d_p;
@@ -546,7 +553,7 @@  static void __init sme_populate_pgd(pgd_t *pgd_base, unsigned long vaddr,
 	pud_p += pud_index(vaddr);
 	if (native_pud_val(*pud_p)) {
 		if (native_pud_val(*pud_p) & _PAGE_PSE)
-			return;
+			return NULL;
 
 		pmd_p = (pmd_t *)(native_pud_val(*pud_p) & ~PTE_FLAGS_MASK);
 	} else {
@@ -560,21 +567,90 @@  static void __init sme_populate_pgd(pgd_t *pgd_base, unsigned long vaddr,
 		native_set_pud(pud_p, pud);
 	}
 
+	return pmd_p;
+}
+
+static void __init sme_populate_pgd_large(pgd_t *pgd, unsigned long vaddr,
+					  unsigned long paddr,
+					  pmdval_t pmd_flags)
+{
+	pmd_t *pmd_p;
+
+	pmd_p = sme_prepare_pgd(pgd, vaddr);
+	if (!pmd_p)
+		return;
+
 	pmd_p += pmd_index(vaddr);
 	if (!native_pmd_val(*pmd_p) || !(native_pmd_val(*pmd_p) & _PAGE_PSE))
 		native_set_pmd(pmd_p, native_make_pmd(paddr | pmd_flags));
 }
 
+static void __init sme_populate_pgd(pgd_t *pgd, unsigned long vaddr,
+				    unsigned long paddr,
+				    pteval_t pte_flags)
+{
+	pmd_t *pmd_p;
+	pte_t *pte_p;
+
+	pmd_p = sme_prepare_pgd(pgd, vaddr);
+	if (!pmd_p)
+		return;
+
+	pmd_p += pmd_index(vaddr);
+	if (native_pmd_val(*pmd_p)) {
+		if (native_pmd_val(*pmd_p) & _PAGE_PSE)
+			return;
+
+		pte_p = (pte_t *)(native_pmd_val(*pmd_p) & ~PTE_FLAGS_MASK);
+	} else {
+		pmd_t pmd;
+
+		pte_p = pgtable_area;
+		memset(pte_p, 0, sizeof(*pte_p) * PTRS_PER_PTE);
+		pgtable_area += sizeof(*pte_p) * PTRS_PER_PTE;
+
+		pmd = native_make_pmd((pteval_t)pte_p + PMD_FLAGS);
+		native_set_pmd(pmd_p, pmd);
+	}
+
+	pte_p += pte_index(vaddr);
+	if (!native_pte_val(*pte_p))
+		native_set_pte(pte_p, native_make_pte(paddr | pte_flags));
+}
+
 static void __init __sme_map_range(pgd_t *pgd, unsigned long vaddr,
 				   unsigned long vaddr_end,
-				   unsigned long paddr, pmdval_t pmd_flags)
+				   unsigned long paddr,
+				   pmdval_t pmd_flags, pteval_t pte_flags)
 {
-	while (vaddr < vaddr_end) {
-		sme_populate_pgd(pgd, vaddr, paddr, pmd_flags);
+	if (vaddr & ~PMD_PAGE_MASK) {
+		/* Start is not 2MB aligned, create PTE entries */
+		unsigned long pmd_start = ALIGN(vaddr, PMD_PAGE_SIZE);
+
+		while (vaddr < pmd_start) {
+			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
+
+			vaddr += PAGE_SIZE;
+			paddr += PAGE_SIZE;
+		}
+	}
+
+	while (vaddr < (vaddr_end & PMD_PAGE_MASK)) {
+		sme_populate_pgd_large(pgd, vaddr, paddr, pmd_flags);
 
 		vaddr += PMD_PAGE_SIZE;
 		paddr += PMD_PAGE_SIZE;
 	}
+
+	if (vaddr_end & ~PMD_PAGE_MASK) {
+		/* End is not 2MB aligned, create PTE entries */
+		while (vaddr < vaddr_end) {
+			sme_populate_pgd(pgd, vaddr, paddr, pte_flags);
+
+			vaddr += PAGE_SIZE;
+			paddr += PAGE_SIZE;
+		}
+	}
 }
 
 static void __init sme_map_range_encrypted(pgd_t *pgd,
@@ -582,7 +658,8 @@  static void __init sme_map_range_encrypted(pgd_t *pgd,
 					   unsigned long vaddr_end,
 					   unsigned long paddr)
 {
-	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_ENC);
+	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
+			PMD_FLAGS_ENC, PTE_FLAGS_ENC);
 }
 
 static void __init sme_map_range_decrypted(pgd_t *pgd,
@@ -590,7 +667,8 @@  static void __init sme_map_range_decrypted(pgd_t *pgd,
 					   unsigned long vaddr_end,
 					   unsigned long paddr)
 {
-	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC);
+	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
+			PMD_FLAGS_DEC, PTE_FLAGS_DEC);
 }
 
 static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
@@ -598,19 +676,20 @@  static void __init sme_map_range_decrypted_wp(pgd_t *pgd,
 					      unsigned long vaddr_end,
 					      unsigned long paddr)
 {
-	__sme_map_range(pgd, vaddr, vaddr_end, paddr, PMD_FLAGS_DEC_WP);
+	__sme_map_range(pgd, vaddr, vaddr_end, paddr,
+			PMD_FLAGS_DEC_WP, PTE_FLAGS_DEC_WP);
 }
 
 static unsigned long __init sme_pgtable_calc(unsigned long len)
 {
-	unsigned long p4d_size, pud_size, pmd_size;
+	unsigned long p4d_size, pud_size, pmd_size, pte_size;
 	unsigned long total;
 
 	/*
 	 * Perform a relatively simplistic calculation of the pagetable
-	 * entries that are needed. That mappings will be covered by 2MB
-	 * PMD entries so we can conservatively calculate the required
-	 * number of P4D, PUD and PMD structures needed to perform the
+	 * entries that are needed. That mappings will be covered mostly
+	 * by 2MB PMD entries so we can conservatively calculate the required
+	 * number of P4D, PUD, PMD and PTE structures needed to perform the
 	 * mappings. Incrementing the count for each covers the case where
 	 * the addresses cross entries.
 	 */
@@ -626,8 +705,9 @@  static unsigned long __init sme_pgtable_calc(unsigned long len)
 	}
 	pmd_size = (ALIGN(len, PUD_SIZE) / PUD_SIZE) + 1;
 	pmd_size *= sizeof(pmd_t) * PTRS_PER_PMD;
+	pte_size = 2 * sizeof(pte_t) * PTRS_PER_PTE;
 
-	total = p4d_size + pud_size + pmd_size;
+	total = p4d_size + pud_size + pmd_size + pte_size;
 
 	/*
 	 * Now calculate the added pagetable structures needed to populate
@@ -710,10 +790,13 @@  void __init sme_encrypt_kernel(void)
 
 	/*
 	 * The total workarea includes the executable encryption area and
-	 * the pagetable area.
+	 * the pagetable area. The start of the workarea is already 2MB
+	 * aligned, align the end of the workarea on a 2MB boundary so that
+	 * we don't try to create/allocate PTE entries from the workarea
+	 * before it is mapped.
 	 */
 	workarea_len = execute_len + pgtable_area_len;
-	workarea_end = workarea_start + workarea_len;
+	workarea_end = ALIGN(workarea_start + workarea_len, PMD_PAGE_SIZE);
 
 	/*
 	 * Set the address to the start of where newly created pagetable
diff --git a/arch/x86/mm/mem_encrypt_boot.S b/arch/x86/mm/mem_encrypt_boot.S
index 730e6d5..20cca86 100644
--- a/arch/x86/mm/mem_encrypt_boot.S
+++ b/arch/x86/mm/mem_encrypt_boot.S
@@ -120,23 +120,33 @@  ENTRY(__enc_copy)
 
 	wbinvd				/* Invalidate any cache entries */
 
+	push	%r12
+
 	/* Copy/encrypt 2MB at a time */
+	movq	$PMD_PAGE_SIZE, %r12
 1:
+	cmpq	%r12, %r9
+	jnb	2f
+	movq	%r9, %r12
+
+2:
 	movq	%r11, %rsi		/* Source - decrypted kernel */
 	movq	%r8, %rdi		/* Dest   - intermediate copy buffer */
-	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
+	movq	%r12, %rcx
 	rep	movsb
 
 	movq	%r8, %rsi		/* Source - intermediate copy buffer */
 	movq	%r10, %rdi		/* Dest   - encrypted kernel */
-	movq	$PMD_PAGE_SIZE, %rcx	/* 2MB length */
+	movq	%r12, %rcx
 	rep	movsb
 
-	addq	$PMD_PAGE_SIZE, %r11
-	addq	$PMD_PAGE_SIZE, %r10
-	subq	$PMD_PAGE_SIZE, %r9	/* Kernel length decrement */
+	addq	%r12, %r11
+	addq	%r12, %r10
+	subq	%r12, %r9		/* Kernel length decrement */
 	jnz	1b			/* Kernel length not zero? */
 
+	pop	%r12
+
 	/* Restore PAT register */
 	push	%rdx			/* Save original PAT value */
 	movl	$MSR_IA32_CR_PAT, %ecx