All of lore.kernel.org
 help / color / mirror / Atom feed
* [PATCH] Shared page tables during fork
@ 2021-07-01 13:46 Kaiyang Zhao
  2021-07-06 18:47 ` Souptick Joarder
                   ` (2 more replies)
  0 siblings, 3 replies; 4+ messages in thread
From: Kaiyang Zhao @ 2021-07-01 13:46 UTC (permalink / raw)
  To: akpm; +Cc: linux-mm, Kaiyang Zhao

In our research work [https://dl.acm.org/doi/10.1145/3447786.3456258], we
have identified a method that can significantly speed up the fork system
call for large applications (i.e., those using a few hundred MB of memory
or more). Currently, the time the fork system call takes to complete is
proportional to the amount of memory the process has allocated; in our
experiments, our design speeds up fork by up to 270x for a process with
50 GB of allocated memory.

Instead of copying the entire paging tree during fork, our design makes the
child and parent processes share the same set of last-level (PTE) page
tables, which are reference counted. To preserve copy-on-write semantics,
fork clears the write permission in the PMD entries, and the page fault
handler copies PTE tables on demand.

We tested a prototype with large workloads that call fork to take
snapshots, such as fuzzers (e.g., AFL); it more than doubled AFL's
execution throughput. The patch is an x86-only prototype that supports
neither huge pages nor swapping, and is meant to demonstrate the potential
performance gains for fork. Applications opt in per process through the
procfs switch /proc/<pid>/use_odf (write 1 to enable, 0 to disable).

On a side note, an approach that shares page tables was proposed by Dave
McCracken [http://lkml.iu.edu/hypermail/linux/kernel/0508.3/1623.html,
https://www.kernel.org/doc/ols/2006/ols2006v2-pages-125-130.pdf], but it
never made it into the kernel. We believe that, given the growing memory
consumption of modern applications and modern uses of fork such as
snapshotting, sharing page tables across fork is worth exploring.

Please let us know whether you are interested in this work, and share any
comments on the general design. Thank you.

Signed-off-by: Kaiyang Zhao <zhao776@purdue.edu>
---
 arch/x86/include/asm/pgtable.h |  19 +-
 fs/proc/base.c                 |  74 ++++++
 include/linux/mm.h             |  11 +
 include/linux/mm_types.h       |   2 +
 include/linux/pgtable.h        |  11 +
 include/linux/sched/coredump.h |   5 +-
 kernel/fork.c                  |   7 +-
 mm/gup.c                       |  61 ++++-
 mm/memory.c                    | 401 +++++++++++++++++++++++++++++++--
 mm/mmap.c                      |  91 +++++++-
 mm/mprotect.c                  |   6 +
 11 files changed, 668 insertions(+), 20 deletions(-)

diff --git a/arch/x86/include/asm/pgtable.h b/arch/x86/include/asm/pgtable.h
index b6c97b8f59ec..0fda05a5c7a1 100644
--- a/arch/x86/include/asm/pgtable.h
+++ b/arch/x86/include/asm/pgtable.h
@@ -410,6 +410,16 @@ static inline pmd_t pmd_clear_flags(pmd_t pmd, pmdval_t clear)
 	return native_make_pmd(v & ~clear);
 }
 
+static inline pmd_t pmd_mknonpresent(pmd_t pmd)
+{
+	return pmd_clear_flags(pmd, _PAGE_PRESENT);	/* only PRESENT cleared; PFN and other flags kept */
+}
+
+static inline pmd_t pmd_mkpresent(pmd_t pmd)
+{
+	return pmd_set_flags(pmd, _PAGE_PRESENT);	/* inverse of pmd_mknonpresent() */
+}
+
 #ifdef CONFIG_HAVE_ARCH_USERFAULTFD_WP
 static inline int pmd_uffd_wp(pmd_t pmd)
 {
@@ -798,6 +808,11 @@ static inline int pmd_present(pmd_t pmd)
 	return pmd_flags(pmd) & (_PAGE_PRESENT | _PAGE_PROTNONE | _PAGE_PSE);
 }
 
+static inline int pmd_iswrite(pmd_t pmd)
+{
+	return pmd_write(pmd);	/* same test as pmd_write(): nonzero iff _PAGE_RW set */
+}
+
 #ifdef CONFIG_NUMA_BALANCING
 /*
  * These work without NUMA balancing but the kernel does not care. See the
@@ -833,7 +848,7 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd)
  * Currently stuck as a macro due to indirect forward reference to
  * linux/mmzone.h's __section_mem_map_addr() definition:
  */
-#define pmd_page(pmd)	pfn_to_page(pmd_pfn(pmd))
+#define pmd_page(pmd)	pfn_to_page(pmd_pfn(pmd_mkpresent(pmd)))	/* PFN decoded as if _PAGE_PRESENT were set */
 
 /*
  * Conversion functions: convert a page and protection to a page entry,
@@ -846,7 +861,13 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd)
 
 static inline int pmd_bad(pmd_t pmd)
 {
-	return (pmd_flags(pmd) & ~_PAGE_USER) != _KERNPG_TABLE;
+	/*
+	 * _PAGE_RW and _PAGE_PRESENT are forced on before the comparison, so
+	 * a table PMD whose write (or present) bit was cleared in order to
+	 * share its PTE table across fork is not reported as bad.
+	 */
+	return ((pmd_flags(pmd) | _PAGE_RW | _PAGE_PRESENT) & ~_PAGE_USER) !=
+		_KERNPG_TABLE;
 }
 
 static inline unsigned long pages_to_mb(unsigned long npg)
diff --git a/fs/proc/base.c b/fs/proc/base.c
index e5b5f7709d48..936f33594539 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -2935,6 +2935,79 @@ static const struct file_operations proc_coredump_filter_operations = {
 };
 #endif
 
+static ssize_t proc_use_odf_read(struct file *file, char __user *buf,
+					 size_t count, loff_t *ppos)
+{
+	struct task_struct *task = get_proc_task(file_inode(file));
+	char buffer[PROC_NUMBUF];
+	struct mm_struct *mm;
+	ssize_t ret = 0;
+	size_t len;
+
+	if (!task)
+		return -ESRCH;
+
+	/* Without an mm there is nothing to report: the read returns 0 bytes. */
+	mm = get_task_mm(task);
+	if (mm) {
+		/* Report the MMF_USE_ODF bit of mm->flags as "0\n" or "1\n". */
+		len = snprintf(buffer, sizeof(buffer), "%d\n",
+			       test_bit(MMF_USE_ODF, &mm->flags) ? 1 : 0);
+		mmput(mm);
+		ret = simple_read_from_buffer(buf, count, ppos, buffer, len);
+	}
+
+	put_task_struct(task);
+	return ret;
+}
+
+static ssize_t proc_use_odf_write(struct file *file,
+					  const char __user *buf,
+					  size_t count,
+					  loff_t *ppos)
+{
+	struct task_struct *task;
+	struct mm_struct *mm;
+	unsigned int val;
+	int ret;
+
+	ret = kstrtouint_from_user(buf, count, 0, &val);
+	if (ret < 0)
+		return ret;
+	/* Only 0 (clear MMF_USE_ODF) and 1 (set it) are accepted. */
+	if (val > 1)
+		return -EINVAL;
+
+	ret = -ESRCH;
+	task = get_proc_task(file_inode(file));
+	if (!task)
+		goto out_no_task;
+
+	mm = get_task_mm(task);
+	if (!mm)
+		goto out_no_mm;
+	ret = 0;
+
+	if (val)
+		set_bit(MMF_USE_ODF, &mm->flags);
+	else
+		clear_bit(MMF_USE_ODF, &mm->flags);
+
+	mmput(mm);
+ out_no_mm:
+	put_task_struct(task);
+ out_no_task:
+	if (ret < 0)
+		return ret;
+	return count;
+}
+
+static const struct file_operations proc_use_odf_operations = {	/* /proc/<pid>/use_odf */
+	.read		= proc_use_odf_read,
+	.write		= proc_use_odf_write,
+	.llseek		= generic_file_llseek,
+};
+
 #ifdef CONFIG_TASK_IO_ACCOUNTING
 static int do_io_accounting(struct task_struct *task, struct seq_file *m, int whole)
 {
@@ -3253,6 +3326,7 @@ static const struct pid_entry tgid_base_stuff[] = {
 #ifdef CONFIG_ELF_CORE
 	REG("coredump_filter", S_IRUGO|S_IWUSR, proc_coredump_filter_operations),
 #endif
+	REG("use_odf", S_IRUGO|S_IWUSR, proc_use_odf_operations),
 #ifdef CONFIG_TASK_IO_ACCOUNTING
 	ONE("io",	S_IRUSR, proc_tgid_io_accounting),
 #endif
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 57453dba41b9..a30eca9e236a 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -664,6 +664,7 @@ static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm)
 	memset(vma, 0, sizeof(*vma));
 	vma->vm_mm = mm;
 	vma->vm_ops = &dummy_vm_ops;
+	vma->pte_table_counter_pending = true;
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
 }
 
@@ -2250,6 +2251,9 @@ static inline bool pgtable_pte_page_ctor(struct page *page)
 		return false;
 	__SetPageTable(page);
 	inc_lruvec_page_state(page, NR_PAGETABLE);
+
+	atomic64_set(&(page->pte_table_refcount), 0);
+
 	return true;
 }
 
@@ -2276,6 +2280,8 @@ static inline void pgtable_pte_page_dtor(struct page *page)
 
 #define pte_alloc(mm, pmd) (unlikely(pmd_none(*(pmd))) && __pte_alloc(mm, pmd))
 
+/* Unlike pte_alloc(), no pmd_none() check: always installs a new table. */
+#define tfork_pte_alloc(mm, pmd) (__tfork_pte_alloc(mm, pmd))
 #define pte_alloc_map(mm, pmd, address)			\
 	(pte_alloc(mm, pmd) ? NULL : pte_offset_map(pmd, address))
 
@@ -2283,6 +2289,10 @@ static inline void pgtable_pte_page_dtor(struct page *page)
 	(pte_alloc(mm, pmd) ?			\
 		 NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
 
+/* pte_alloc_map_lock() variant that always installs a new PTE table first. */
+#define tfork_pte_alloc_map_lock(mm, pmd, address, ptlp)	\
+	(tfork_pte_alloc(mm, pmd) ?			\
+		 NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
 #define pte_alloc_kernel(pmd, address)			\
 	((unlikely(pmd_none(*(pmd))) && __pte_alloc_kernel(pmd))? \
 		NULL: pte_offset_kernel(pmd, address))
@@ -2616,6 +2626,7 @@ extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in,
 #ifdef CONFIG_MMU
 extern int __mm_populate(unsigned long addr, unsigned long len,
 			 int ignore_errors);
+extern int __mm_populate_nolock(unsigned long addr, unsigned long len, int ignore_errors);
 static inline void mm_populate(unsigned long addr, unsigned long len)
 {
 	/* Ignore errors */
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index f37abb2d222e..e06c677ce279 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -158,6 +158,7 @@ struct page {
 			union {
 				struct mm_struct *pt_mm; /* x86 pgds only */
 				atomic_t pt_frag_refcount; /* powerpc */
+				atomic64_t pte_table_refcount;
 			};
 #if USE_SPLIT_PTE_PTLOCKS
 #if ALLOC_SPLIT_PTLOCKS
@@ -379,6 +380,7 @@ struct vm_area_struct {
 	struct mempolicy *vm_policy;	/* NUMA policy for the VMA */
 #endif
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
+	bool pte_table_counter_pending;
 } __randomize_layout;
 
 struct core_thread {
diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
index d147480cdefc..6afd77ff82e6 100644
--- a/include/linux/pgtable.h
+++ b/include/linux/pgtable.h
@@ -90,6 +90,17 @@ static inline pte_t *pte_offset_kernel(pmd_t *pmd, unsigned long address)
 	return (pte_t *)pmd_page_vaddr(*pmd) + pte_index(address);
 }
 #define pte_offset_kernel pte_offset_kernel
+/*
+ * Same as pte_offset_kernel(), but takes the PMD entry by value, so callers
+ * that hold a copy of the entry (rather than a pointer to it) can use it.
+ */
+static inline pte_t *tfork_pte_offset_kernel(pmd_t pmd_val, unsigned long address)
+{
+	pte_t *table = (pte_t *)pmd_page_vaddr(pmd_val);
+
+	return table + pte_index(address);
+}
+#define tfork_pte_offset_kernel tfork_pte_offset_kernel
 #endif
 
 #if defined(CONFIG_HIGHPTE)
@@ -782,6 +787,14 @@ static inline void arch_swap_restore(swp_entry_t entry, struct page *page)
 })
 #endif
 
+/*
+ * pte_table_start(addr): first address mapped by the PTE table covering addr.
+ * pte_table_end(addr): first address past that table, i.e. the next PMD_SIZE
+ * boundary above addr (wraps to 0 for the topmost PMD_SIZE region).
+ */
+#define pte_table_start(addr)	((addr) & PMD_MASK)
+#define pte_table_end(addr)	(((addr) + PMD_SIZE) & PMD_MASK)
+
 /*
  * When walking page tables, we usually want to skip any p?d_none entries;
  * and any p?d_bad entries - reporting the error before resetting to none.
diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h
index 4d9e3a656875..8f6e50bc04ab 100644
--- a/include/linux/sched/coredump.h
+++ b/include/linux/sched/coredump.h
@@ -83,7 +83,10 @@ static inline int get_dumpable(struct mm_struct *mm)
 #define MMF_HAS_PINNED		28	/* FOLL_PIN has run, never cleared */
 #define MMF_DISABLE_THP_MASK	(1 << MMF_DISABLE_THP)
 
+#define MMF_USE_ODF 29	/* share PTE tables on fork; toggled via /proc/<pid>/use_odf */
+#define MMF_USE_ODF_MASK (1 << MMF_USE_ODF)	/* mask form, tested against mm->flags */
+
 #define MMF_INIT_MASK		(MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\
-				 MMF_DISABLE_THP_MASK)
+				 MMF_DISABLE_THP_MASK | MMF_USE_ODF_MASK)
 
 #endif /* _LINUX_SCHED_COREDUMP_H */
diff --git a/kernel/fork.c b/kernel/fork.c
index d738aae40f9e..4f21ea4f4f38 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -594,8 +594,13 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
 		rb_parent = &tmp->vm_rb;
 
 		mm->map_count++;
-		if (!(tmp->vm_flags & VM_WIPEONFORK))
+		if (!(tmp->vm_flags & VM_WIPEONFORK)) {
 			retval = copy_page_range(tmp, mpnt);
+			if (oldmm->flags & MMF_USE_ODF_MASK) {
+				tmp->pte_table_counter_pending = false; // reference of the shared PTE table by the new VMA is counted in copy_pmd_range_tfork
+				mpnt->pte_table_counter_pending = false; // don't double count when forking again
+			}
+		}
 
 		if (tmp->vm_ops && tmp->vm_ops->open)
 			tmp->vm_ops->open(tmp);
diff --git a/mm/gup.c b/mm/gup.c
index 42b8b1fa6521..5768f339b0ff 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1489,8 +1489,11 @@ long populate_vma_page_range(struct vm_area_struct *vma,
 	 * to break COW, except for shared mappings because these don't COW
 	 * and we would not want to dirty them for nothing.
 	 */
-	if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE)
-		gup_flags |= FOLL_WRITE;
+	if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE) {
+		if (!(mm->flags & MMF_USE_ODF_MASK)) { //for ODF processes, only allocate page tables
+			gup_flags |= FOLL_WRITE;
+		}
+	}
 
 	/*
 	 * We want mlock to succeed for regions that have any permissions
@@ -1669,6 +1672,53 @@ static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
 }
 #endif /* !CONFIG_MMU */
 
+/*
+ * __mm_populate_nolock - fault in [start, start + len) of current->mm for a
+ * caller that already holds mmap_lock.
+ *
+ * The lock (read or write) must be held on entry and stays held: NULL is
+ * passed as @locked to populate_vma_page_range(), which therefore never
+ * drops it.  VM_IO and VM_PFNMAP VMAs are skipped.  With @ignore_errors a
+ * VMA that fails is skipped and 0 is returned; otherwise the first error
+ * is returned.
+ */
+int __mm_populate_nolock(unsigned long start, unsigned long len, int ignore_errors)
+{
+	struct mm_struct *mm = current->mm;
+	unsigned long end = start + len;
+	unsigned long nstart, nend;
+	struct vm_area_struct *vma = NULL;
+	long ret = 0;
+
+	mmap_assert_locked(mm);
+
+	for (nstart = start; nstart < end; nstart = nend) {
+		/* Find the first VMA that intersects [nstart, end). */
+		if (!vma)
+			vma = find_vma(mm, nstart);
+		else if (nstart >= vma->vm_end)
+			vma = vma->vm_next;
+		if (!vma || vma->vm_start >= end)
+			break;
+		nend = min(end, vma->vm_end);
+		if (vma->vm_flags & (VM_IO | VM_PFNMAP))
+			continue;
+		if (nstart < vma->vm_start)
+			nstart = vma->vm_start;
+		ret = populate_vma_page_range(vma, nstart, nend, NULL);
+		if (ret < 0) {
+			if (ignore_errors) {
+				ret = 0;
+				continue;	/* continue at next VMA */
+			}
+			break;
+		}
+		nend = nstart + ret * PAGE_SIZE;
+		ret = 0;
+	}
+	return ret;	/* 0 or negative error code */
+}
+
 /**
  * get_dump_page() - pin user page in memory while writing it to core dump
  * @addr: user address
diff --git a/mm/memory.c b/mm/memory.c
index db86558791f1..2b28766e4213 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -83,6 +83,9 @@
 #include <asm/tlb.h>
 #include <asm/tlbflush.h>
 
+static bool tfork_one_pte_table(struct mm_struct *, pmd_t *, unsigned long, unsigned long);
+static inline void init_rss_vec(int *rss);
+static inline void add_mm_rss_vec(struct mm_struct *mm, int *rss);
 #include "pgalloc-track.h"
 #include "internal.h"
 
@@ -227,7 +230,15 @@ static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
 			   unsigned long addr)
 {
 	pgtable_t token = pmd_pgtable(*pmd);
+	s64 counter;
 	pmd_clear(pmd);
+	counter = atomic64_read(&token->pte_table_refcount);
+	if (counter > 0) {
+		/* Other VMAs (possibly in other mms) still reference the table. */
+		pr_debug("free_pte_range: addr=%lx, counter=%lld, not freeing table\n",
+			 addr, counter);
+		return;
+	}
 	pte_free_tlb(tlb, token, addr);
 	mm_dec_nr_ptes(tlb->mm);
 }
@@ -433,6 +445,138 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
 	}
 }
 
+/*
+ * zap_one_pte_table - release every page mapped by a PTE table.
+ * @pmd_val: value of the PMD entry that points at the table
+ * @addr: any address inside the PMD_SIZE region mapped by the table
+ * @mm: mm whose RSS counters are charged for the released pages
+ *
+ * Walks the whole table (the full PMD_SIZE region, not a single VMA).  For
+ * each present, non-special entry that maps a normal page, drops the rmap
+ * and one page reference.  Empty, non-present and special entries (e.g.
+ * the system-wide vvar page) are skipped.
+ *
+ * NOTE(review): the PTEs are neither cleared nor TLB-flushed here; callers
+ * must ensure the table can no longer be used for translation.
+ */
+void zap_one_pte_table(pmd_t pmd_val, unsigned long addr, struct mm_struct *mm)
+{
+	int rss[NR_MM_COUNTERS];
+	unsigned long end;
+	pte_t *pte;
+
+	init_rss_vec(rss);
+	addr = pte_table_start(addr);
+	end = pte_table_end(addr);
+	pte = tfork_pte_offset_kernel(pmd_val, addr);
+	do {
+		pte_t ptent = *pte;
+		struct page *page;
+
+		if (pte_none(ptent) || !pte_present(ptent) || pte_special(ptent))
+			continue;
+		/*
+		 * NULL vma: special PTEs were filtered out above.
+		 * NOTE(review): confirm vm_normal_page() never needs the vma here.
+		 */
+		page = vm_normal_page(NULL, addr, ptent);
+		if (unlikely(!page))
+			continue;
+		rss[mm_counter(page)]--;
+		page_remove_rmap(page, false);
+		put_page(page);
+	} while (pte++, addr += PAGE_SIZE, addr != end);
+
+	add_mm_rss_vec(mm, rss);
+}
+
+/*
+ * Tear down a PTE table whose reference count has dropped to zero (or
+ * below): release the pages it maps and, if @free_table, free the table
+ * page itself and uncharge it from @mm.
+ */
+static void tfork_release_pte_table(struct page *table_page, pmd_t pmd_val,
+				    bool free_table, struct mm_struct *mm,
+				    unsigned long addr)
+{
+	zap_one_pte_table(pmd_val, addr, mm);
+	if (free_table) {
+		pgtable_pte_page_dtor(table_page);
+		__free_page(table_page);
+		mm_dec_nr_ptes(mm);
+	}
+}
+
+/*
+ * dereference_pte_table - drop one VMA reference to a PTE table.
+ * @pmd_val: PMD entry pointing at the table
+ * @free_table: also free the table page when the last reference goes away
+ * @mm: mm charged for the released pages (and for the table, if freed)
+ * @addr: address inside the region mapped by the table
+ *
+ * The PMD lock should be held.  Returns 1 if this dropped the last
+ * reference (the mapped pages have then been released), 0 otherwise.
+ */
+int dereference_pte_table(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr)
+{
+	struct page *table_page = pmd_page(pmd_val);
+
+	if (!atomic64_dec_and_test(&table_page->pte_table_refcount)) {
+		pr_debug("dereference_pte_table: addr=%lx, (after) pte_table_count=%lld\n",
+			 addr, atomic64_read(&table_page->pte_table_refcount));
+		return 0;
+	}
+
+	pr_debug("dereference_pte_table: addr=%lx, free_table=%d, pte table reached end of life\n",
+		 addr, free_table);
+	tfork_release_pte_table(table_page, pmd_val, free_table, mm, addr);
+	return 1;
+}
+
+/*
+ * dereference_pte_table_multiple - drop @num VMA references to a PTE table.
+ * Parameters as for dereference_pte_table().  The table is torn down when
+ * the count reaches zero or goes negative; returns 1 in that case, else 0.
+ */
+int dereference_pte_table_multiple(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr, int num)
+{
+	struct page *table_page = pmd_page(pmd_val);
+	/* s64 like the counter itself; an int would truncate the result. */
+	s64 count_after = atomic64_sub_return(num, &table_page->pte_table_refcount);
+
+	if (count_after > 0) {
+		pr_debug("dereference_pte_table_multiple: addr=%lx, num=%d, (after) count=%lld\n",
+			 addr, num, count_after);
+		return 0;
+	}
+
+	pr_debug("dereference_pte_table_multiple: addr=%lx, free_table=%d, num=%d, after count=%lld, table reached end of life\n",
+		 addr, free_table, num, count_after);
+	tfork_release_pte_table(table_page, pmd_val, free_table, mm, addr);
+	return 1;
+}
+
+/*
+ * __tfork_pte_alloc - install a newly allocated PTE table at @pmd.
+ *
+ * Unlike pte_alloc() this does not check whether @pmd already points at
+ * a table: it is used to replace a shared table with a private one, so the
+ * existing entry is overwritten on purpose.  Returns 0 or -ENOMEM.
+ */
+int __tfork_pte_alloc(struct mm_struct *mm, pmd_t *pmd)
+{
+	pgtable_t new = pte_alloc_one(mm);
+
+	if (!new)
+		return -ENOMEM;
+	/* Order the new table's initialization before it becomes reachable. */
+	smp_wmb();
+
+	mm_inc_nr_ptes(mm);
+	pmd_populate(mm, pmd, new);
+	return 0;
+}
+
 int __pte_alloc(struct mm_struct *mm, pmd_t *pmd)
 {
 	spinlock_t *ptl;
@@ -928,6 +1052,49 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma
 	return 0;
 }
 
+/*
+ * copy_one_pte_tfork - copy one entry of a shared PTE table into a private
+ * table of @dst_mm.
+ * @dst_mm: mm that owns the destination table
+ * @dst_pte: destination entry
+ * @src_pte: entry in the shared table; it is only read
+ * @vma: VMA that maps @addr; its flags select the transformation
+ * @addr: virtual address mapped by the entry
+ * @rss: RSS deltas, incremented for every normal page that gains a mapping
+ *
+ * The copy is write-protected for COW mappings, marked clean for shared
+ * mappings, and always marked old.  A normal page gets an extra reference
+ * and an extra rmap.  Always returns 0.
+ *
+ * NOTE(review): non-present (swap/migration) entries get no special
+ * handling here.
+ */
+static inline unsigned long
+copy_one_pte_tfork(struct mm_struct *dst_mm,
+		pte_t *dst_pte, pte_t *src_pte, struct vm_area_struct *vma,
+		unsigned long addr, int *rss)
+{
+	const unsigned long flags = vma->vm_flags;
+	struct page *page;
+	pte_t entry = *src_pte;
+
+	if (pte_write(entry) && is_cow_mapping(flags))
+		entry = pte_wrprotect(entry);
+	if (flags & VM_SHARED)
+		entry = pte_mkclean(entry);
+	entry = pte_mkold(entry);
+
+	page = vm_normal_page(vma, addr, entry);
+	if (page) {
+		get_page(page);
+		page_dup_rmap(page, false);
+		rss[mm_counter(page)]++;
+	}
+
+	set_pte_at(dst_mm, addr, dst_pte, entry);
+	return 0;
+}
+
 /*
  * Copy one pte.  Returns 0 if succeeded, or -EAGAIN if one preallocated page
  * is required to copy this pte.
@@ -999,6 +1162,71 @@ page_copy_prealloc(struct mm_struct *src_mm, struct vm_area_struct *vma,
 	return new_page;
 }
 
+/*
+ * copy_pte_range_tfork - copy [addr, end) of a shared PTE table into the
+ * private table at @dst_pmd.
+ * @dst_mm: mm that owns @dst_pmd
+ * @dst_pmd: PMD entry that receives the private table
+ * @src_pmd_val: PMD value pointing at the shared source table
+ * @vma: VMA covering [addr, end)
+ * @addr: start of the range (inclusive)
+ * @end: end of the range (exclusive), inside the same PTE table
+ *
+ * While @dst_pmd is write-protected it still points at the shared table,
+ * so a new table is installed first; once it is writable, the private
+ * table installed by an earlier call is reused.  Either way the
+ * destination table gains one reference for @vma.
+ *
+ * Returns 0, or -ENOMEM if the destination table cannot be allocated.
+ */
+static int copy_pte_range_tfork(struct mm_struct *dst_mm,
+		   pmd_t *dst_pmd, pmd_t src_pmd_val, struct vm_area_struct *vma,
+		   unsigned long addr, unsigned long end)
+{
+	pte_t *src_pte, *dst_pte, *orig_dst_pte;
+	spinlock_t *dst_ptl;
+	int rss[NR_MM_COUNTERS];
+	struct page *dst_pte_page;
+	bool new_table = !pmd_iswrite(*dst_pmd);
+
+	init_rss_vec(rss);
+
+	/*
+	 * src_pte comes from pmd_page_vaddr(), not pte_offset_map(), so it
+	 * must not be passed to pte_unmap().
+	 */
+	src_pte = tfork_pte_offset_kernel(src_pmd_val, addr);
+	if (new_table)
+		dst_pte = tfork_pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
+	else
+		dst_pte = pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
+	if (!dst_pte)
+		return -ENOMEM;
+	if (new_table)
+		pr_debug("copy_pte_range_tfork: allocated new table. addr=%lx, prev_table_page=%px, table_page=%px\n",
+			 addr, pmd_page(src_pmd_val), pmd_page(*dst_pmd));
+
+	/* Count @vma as a user of the destination table. */
+	dst_pte_page = pmd_page(*dst_pmd);
+	atomic64_inc(&dst_pte_page->pte_table_refcount);
+	pr_debug("copy_pte_range_tfork: addr = %lx, end = %lx, new pte table counter (after)=%lld\n",
+		 addr, end, atomic64_read(&dst_pte_page->pte_table_refcount));
+
+	orig_dst_pte = dst_pte;
+	arch_enter_lazy_mmu_mode();
+	do {
+		if (pte_none(*src_pte))
+			continue;
+		WARN_ON_ONCE(copy_one_pte_tfork(dst_mm, dst_pte, src_pte,
+						vma, addr, rss));
+	} while (dst_pte++, src_pte++, addr += PAGE_SIZE, addr != end);
+	arch_leave_lazy_mmu_mode();
+
+	add_mm_rss_vec(dst_mm, rss);
+	pte_unmap_unlock(orig_dst_pte, dst_ptl);
+	return 0;
+}
+
 static int
 copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 	       pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
@@ -1130,8 +1346,9 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 {
 	struct mm_struct *dst_mm = dst_vma->vm_mm;
 	struct mm_struct *src_mm = src_vma->vm_mm;
-	pmd_t *src_pmd, *dst_pmd;
+	pmd_t *src_pmd, *dst_pmd, src_pmd_value;
 	unsigned long next;
+	struct page *table_page;
 
 	dst_pmd = pmd_alloc(dst_mm, dst_pud, addr);
 	if (!dst_pmd)
@@ -1153,9 +1370,43 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 		}
 		if (pmd_none_or_clear_bad(src_pmd))
 			continue;
-		if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
-				   addr, next))
-			return -ENOMEM;
+		if (src_mm->flags & MMF_USE_ODF_MASK) {
+#ifdef CONFIG_DEBUG_VM
+			printk("copy_pmd_range: vm_start=%lx, addr=%lx, vm_end=%lx, end=%lx\n", src_vma->vm_start, addr, src_vma->vm_end, end);
+#endif
+
+			src_pmd_value = *src_pmd;
+			//kyz: sets write-protect to the pmd entry if the vma is writable
+			if (src_vma->vm_flags & VM_WRITE) {
+				src_pmd_value = pmd_wrprotect(src_pmd_value);
+				set_pmd_at(src_mm, addr, src_pmd, src_pmd_value);
+			}
+			table_page = pmd_page(*src_pmd);
+			if (src_vma->pte_table_counter_pending) { // kyz : the old VMA hasn't been counted in the PTE table, count it now
+				atomic64_add(2, &(table_page->pte_table_refcount));
+#ifdef CONFIG_DEBUG_VM
+				printk("copy_pmd_range: addr=%lx, pte table counter (after counting old&new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
+#endif
+			} else {
+				atomic64_inc(&(table_page->pte_table_refcount));  //increments the pte table counter
+				if (atomic64_read(&(table_page->pte_table_refcount)) == 1) {  //the VMA is old, but the pte table is new (created by a fault after the last odf call)
+					atomic64_set(&(table_page->pte_table_refcount), 2);
+#ifdef CONFIG_DEBUG_VM
+					printk("copy_pmd_range: addr=%lx, pte table counter (old VMA, new pte table)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
+#endif
+				}
+#ifdef CONFIG_DEBUG_VM
+				else {
+					printk("copy_pmd_range: addr=%lx, pte table counter (after counting new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
+				}
+#endif
+			}
+			set_pmd_at(dst_mm, addr, dst_pmd, src_pmd_value);  //shares the table with the child
+		} else {
+			if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
+						addr, next))
+				return -ENOMEM;
+		}
 	} while (dst_pmd++, src_pmd++, addr = next, addr != end);
 	return 0;
 }
@@ -1240,9 +1491,10 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
 	 * readonly mappings. The tradeoff is that copy_page_range is more
 	 * efficient than faulting.
 	 */
-	if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) &&
+/*	if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) &&
 	    !src_vma->anon_vma)
 		return 0;
+*/
 
 	if (is_vm_hugetlb_page(src_vma))
 		return copy_hugetlb_page_range(dst_mm, src_mm, src_vma);
@@ -1304,7 +1556,7 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
 static unsigned long zap_pte_range(struct mmu_gather *tlb,
 				struct vm_area_struct *vma, pmd_t *pmd,
 				unsigned long addr, unsigned long end,
-				struct zap_details *details)
+				struct zap_details *details, bool invalidate_pmd)
 {
 	struct mm_struct *mm = tlb->mm;
 	int force_flush = 0;
@@ -1343,8 +1595,10 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
 				    details->check_mapping != page_rmapping(page))
 					continue;
 			}
-			ptent = ptep_get_and_clear_full(mm, addr, pte,
-							tlb->fullmm);
+			if (!invalidate_pmd) {
+				ptent = ptep_get_and_clear_full(mm, addr, pte,
+						tlb->fullmm);
+			}
 			tlb_remove_tlb_entry(tlb, pte, addr);
 			if (unlikely(!page))
 				continue;
@@ -1358,8 +1612,12 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
 				    likely(!(vma->vm_flags & VM_SEQ_READ)))
 					mark_page_accessed(page);
 			}
-			rss[mm_counter(page)]--;
-			page_remove_rmap(page, false);
+			if (!invalidate_pmd) {
+				rss[mm_counter(page)]--;
+				page_remove_rmap(page, false);
+			} else {
+				continue;
+			}
 			if (unlikely(page_mapcount(page) < 0))
 				print_bad_pte(vma, addr, ptent, page);
 			if (unlikely(__tlb_remove_page(tlb, page))) {
@@ -1446,12 +1704,16 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
 				struct zap_details *details)
 {
 	pmd_t *pmd;
-	unsigned long next;
+	unsigned long next, table_start, table_end;
+	spinlock_t *ptl;
+	struct page *table_page;
+	bool got_new_table = false;
 
 	pmd = pmd_offset(pud, addr);
 	do {
+		ptl = pmd_lock(vma->vm_mm, pmd);
 		next = pmd_addr_end(addr, end);
-		if (is_swap_pmd(*pmd) || pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
+		if (pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
 			if (next - addr != HPAGE_PMD_SIZE)
 				__split_huge_pmd(vma, pmd, addr, false, NULL);
 			else if (zap_huge_pmd(tlb, vma, pmd, addr))
@@ -1478,8 +1740,49 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
 		 */
 		if (pmd_none_or_trans_huge_or_clear_bad(pmd))
 			goto next;
-		next = zap_pte_range(tlb, vma, pmd, addr, next, details);
+		//kyz: copy if the pte table is shared and VMA does not cover fully the 2MB region
+		table_page = pmd_page(*pmd);
+		table_start = pte_table_start(addr);
+
+		if ((!pmd_iswrite(*pmd)) && (!vma->pte_table_counter_pending)) {//shared pte table. vma has gone through odf
+			table_end = pte_table_end(addr);
+			if (table_start < vma->vm_start || table_end > vma->vm_end) {
+#ifdef CONFIG_DEBUG_VM
+				printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, copy then zap\n", __func__, addr, end, table_start, table_end);
+#endif
+				if (dereference_pte_table(*pmd, false, vma->vm_mm, addr) != 1) { //dec the counter of the shared table. tfork_one_pte_table cannot find the current VMA (which is being unmapped)
+					got_new_table = tfork_one_pte_table(vma->vm_mm, pmd, addr, vma->vm_end);
+					if (got_new_table) {
+						next = zap_pte_range(tlb, vma, pmd, addr, next, details, false);
+					} else {
+#ifdef CONFIG_DEBUG_VM
+						printk("zap_pmd_range: no more VMAs in this process are using the table, but there are other processes using it\n");
+#endif
+						pmd_clear(pmd);
+					}
+				} else {
+#ifdef CONFIG_DEBUG_VM
+					printk("zap_pmd_range: the shared table is dead. NOT copying after all.\n");
+#endif
+					// the shared table will be freed by unmap_single_vma()
+				}
+			} else {
+#ifdef CONFIG_DEBUG_VM
+				printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, zap while preserving pte entries\n", __func__, addr, end, table_start, table_end);
+#endif
+				//kyz: shared and fully covered by the VMA, preserve the pte entries
+				next = zap_pte_range(tlb, vma, pmd, addr, next, details, true);
+				dereference_pte_table(*pmd, true, vma->vm_mm, addr);
+				pmd_clear(pmd);
+			}
+		} else {
+			next = zap_pte_range(tlb, vma, pmd, addr, next, details, false);
+			if (!vma->pte_table_counter_pending) {
+				atomic64_dec(&(table_page->pte_table_refcount));
+			}
+		}
 next:
+		spin_unlock(ptl);
 		cond_resched();
 	} while (pmd++, addr = next, addr != end);
 
@@ -4476,6 +4779,78 @@ static vm_fault_t wp_huge_pud(struct vm_fault *vmf, pud_t orig_pud)
 	return VM_FAULT_FALLBACK;
 }
 
+/*
+ * tfork_one_pte_table - replace the shared PTE table at @dst_pmd with a
+ * private copy for @mm.
+ * @mm: mm that owns @dst_pmd
+ * @dst_pmd: PMD entry pointing at the shared table; must not be none
+ * @addr: address inside the PMD_SIZE region mapped by the table
+ * @exclude: if non-zero, a VMA that starts at or below @addr and ends at
+ *           or above @exclude is not copied (used for a VMA being unmapped)
+ *
+ * Every VMA of @mm overlapping the table's region is copied into one new
+ * table, except the excluded VMA and VMAs whose pte_table_counter_pending
+ * is set (newly mapped, never counted in the shared table).  The shared
+ * table then loses one reference per copied VMA.
+ *
+ * Returns true if a new table was installed.  Returns false if no VMA was
+ * copied or if allocating the new table failed; in the failure case the
+ * PMD and the shared table's count are left untouched.
+ */
+static bool tfork_one_pte_table(struct mm_struct *mm, pmd_t *dst_pmd, unsigned long addr, unsigned long exclude)
+{
+	unsigned long table_end, end, orig_addr;
+	struct vm_area_struct *vma;
+	struct page *orig_pte_page;
+	pmd_t orig_pmd_val;
+	bool copied = false;
+	int num_vmas = 0;
+
+	BUG_ON(pmd_none(*dst_pmd));
+	orig_pmd_val = *dst_pmd;
+	orig_pte_page = pmd_page(orig_pmd_val);
+
+	/* Start from the beginning of the range covered by the table. */
+	orig_addr = addr;
+	table_end = pte_table_end(addr);
+	addr = pte_table_start(addr);
+	pr_debug("tfork_one_pte_table: shared pte table counter=%lld, Covered Range: start=%lx, end=%lx\n",
+		 atomic64_read(&orig_pte_page->pte_table_refcount), addr, table_end);
+	do {
+		vma = find_vma(mm, addr);
+		if (!vma || vma->vm_start >= table_end)
+			break;
+		end = pmd_addr_end(addr, vma->vm_end);
+		if (vma->pte_table_counter_pending) {
+			/* Newly mapped VMA: not counted in the shared table. */
+			addr = end;
+			continue;
+		}
+		if (vma->vm_start > addr)
+			addr = vma->vm_start;
+		if (exclude && vma->vm_start <= orig_addr && vma->vm_end >= exclude) {
+			addr = end;
+			continue;
+		}
+		pr_debug("tfork_one_pte_table: vm_start=%lx, vm_end=%lx\n",
+			 vma->vm_start, vma->vm_end);
+		/*
+		 * A failure can only come from allocating the new table, i.e. on
+		 * the first copy while the PMD still points at the shared table
+		 * (later copies reuse the new one).  Bail out before any reference
+		 * is dropped.  NOTE(review): -ENOMEM is not reported to callers.
+		 */
+		if (copy_pte_range_tfork(mm, dst_pmd, orig_pmd_val, vma, addr, end))
+			return false;
+		num_vmas++;
+		copied = true;
+		addr = end;
+	} while (addr < table_end);
+
+	dereference_pte_table_multiple(orig_pmd_val, true, mm, orig_addr, num_vmas);
+	return copied;
+}
+
 /*
  * These routines also need to handle stuff like marking pages dirty
  * and/or accessed for architectures that don't do it in hardware (most
@@ -4610,6 +4973,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 	pgd_t *pgd;
 	p4d_t *p4d;
 	vm_fault_t ret;
+	spinlock_t *ptl;
 
 	pgd = pgd_offset(mm, address);
 	p4d = p4d_alloc(mm, pgd, address);
@@ -4659,6 +5023,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 		vmf.orig_pmd = *vmf.pmd;
 
 		barrier();
+		/*
 		if (unlikely(is_swap_pmd(vmf.orig_pmd))) {
 			VM_BUG_ON(thp_migration_supported() &&
 					  !is_pmd_migration_entry(vmf.orig_pmd));
@@ -4666,6 +5031,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 				pmd_migration_entry_wait(mm, vmf.pmd);
 			return 0;
 		}
+		*/
 		if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) {
 			if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma))
 				return do_huge_pmd_numa_page(&vmf);
@@ -4679,6 +5045,15 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 				return 0;
 			}
 		}
+		//kyz: checks if the pmd entry prohibits writes
+		if ((!pmd_none(vmf.orig_pmd)) && (!pmd_iswrite(vmf.orig_pmd)) && (vma->vm_flags & VM_WRITE)) {
+#ifdef CONFIG_DEBUG_VM
+			printk("__handle_mm_fault: PID=%d, addr=%lx\n", current->pid, address);
+#endif
+			ptl = pmd_lock(mm, vmf.pmd);
+			tfork_one_pte_table(mm, vmf.pmd, vmf.address, 0u);
+			spin_unlock(ptl);
+		}
 	}
 
 	return handle_pte_fault(&vmf);
diff --git a/mm/mmap.c b/mm/mmap.c
index ca54d36d203a..308d86cfe544 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -47,6 +47,7 @@
 #include <linux/pkeys.h>
 #include <linux/oom.h>
 #include <linux/sched/mm.h>
+#include <linux/pagewalk.h>
 
 #include <linux/uaccess.h>
 #include <asm/cacheflush.h>
@@ -276,6 +277,9 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
 
 success:
 	populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0;
+	if (mm->flags & MMF_USE_ODF_MASK) { //for ODF
+		populate = true;
+	}
 	if (downgraded)
 		mmap_read_unlock(mm);
 	else
@@ -1115,6 +1119,50 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
 	return 0;
 }
 
+/*
+ * ->pmd_entry() for pgtable_counter_fixup_walk_ops: add one reference to the
+ * PTE table that *@pmd points to, then keep the walker from descending to the
+ * PTE level (only the table itself is being counted, not its entries).
+ * NOTE(review): a huge pmd would be treated as a PTE table here; the patch
+ * states huge pages are not supported.
+ */
+static int pgtable_counter_fixup_pmd_entry(pmd_t *pmd, unsigned long addr,
+			       unsigned long next, struct mm_walk *walk)
+{
+	struct page *pte_table = pmd_page(*pmd);
+
+	atomic64_inc(&pte_table->pte_table_refcount);
+#ifdef CONFIG_DEBUG_VM
+	printk("fixup inc: addr=%lx\n", addr);
+#endif
+	walk->action = ACTION_CONTINUE;
+	return 0;
+}
+
+/*
+ * ->test_walk() for pgtable_counter_fixup_walk_ops: always 0, i.e. no VMA in
+ * the requested range is skipped by the walker.
+ */
+static int pgtable_counter_fixup_test(unsigned long addr, unsigned long next,
+				      struct mm_walk *walk)
+{
+	return 0;
+}
+
+/* Page-walk callbacks that bump the PTE-table counters under a VMA range. */
+static const struct mm_walk_ops pgtable_counter_fixup_walk_ops = {
+	.test_walk	= pgtable_counter_fixup_test,
+	.pmd_entry	= pgtable_counter_fixup_pmd_entry,
+};
+
+/*
+ * Take a reference on each PTE table newly covered by @vma after
+ * vma_merge() extended it over [@start, @end).
+ *
+ * Nothing is done while @vma->pte_table_counter_pending is set (its tables
+ * are not being counted yet). Otherwise only tables lying entirely inside
+ * [@start, @end) are counted: the old part of @vma lies outside that range,
+ * so it cannot already hold a reference on them.
+ * NOTE(review): a table straddling a boundary is skipped. That is right on
+ * the side adjacent to the old part of @vma, but the table at the far end
+ * is newly mapped by @vma and may need a reference as well.
+ *
+ * Always returns 0.
+ */
+int merge_vma_pgtable_counter_fixup(struct vm_area_struct *vma, unsigned long start, unsigned long end)
+{
+	if (vma->pte_table_counter_pending)
+		return 0;
+
+#ifdef CONFIG_DEBUG_VM
+	printk("merge fixup: vm_start=%lx, vm_end=%lx, inc start=%lx, inc end=%lx\n", vma->vm_start, vma->vm_end, start, end);
+#endif
+	/*
+	 * Round @start up and @end down to PTE-table (PMD) boundaries.
+	 * ALIGN() keeps a PMD-aligned @start as is; pte_table_end() would
+	 * have skipped the whole table beginning at @start and left it
+	 * under-counted.
+	 */
+	start = ALIGN(start, PMD_SIZE);
+	end = pte_table_start(end);
+	if (start >= end)
+		return 0;	/* the new range contains no whole PTE table */
+
+	/* populate tables for the extended address range so that we can increment counters */
+	__mm_populate_nolock(start, end - start, 1);
+	walk_page_range(vma->vm_mm, start, end, &pgtable_counter_fixup_walk_ops, NULL);
+	return 0;
+}
+
 /*
  * Given a mapping request (addr,end,vm_flags,file,pgoff), figure out
  * whether that can be merged with its predecessor or its successor.
@@ -1215,6 +1263,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
 		if (err)
 			return NULL;
 		khugepaged_enter_vma_merge(prev, vm_flags);
+
+		merge_vma_pgtable_counter_fixup(prev, addr, end);
+
 		return prev;
 	}
 
@@ -1242,6 +1293,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
 		if (err)
 			return NULL;
 		khugepaged_enter_vma_merge(area, vm_flags);
+
+		merge_vma_pgtable_counter_fixup(area, addr, end);
+
 		return area;
 	}
 
@@ -1584,8 +1638,15 @@ unsigned long do_mmap(struct file *file, unsigned long addr,
 	addr = mmap_region(file, addr, len, vm_flags, pgoff, uf);
 	if (!IS_ERR_VALUE(addr) &&
 	    ((vm_flags & VM_LOCKED) ||
-	     (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE))
+	     (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE ||
+	     (mm->flags & MMF_USE_ODF_MASK))) {
+#ifdef CONFIG_DEBUG_VM
+		if (mm->flags & MMF_USE_ODF_MASK) {
+			printk("mmap: force populate, addr=%lx, len=%lx\n", addr, len);
+		}
+#endif
 		*populate = len;
+	}
 	return addr;
 }
 
@@ -2799,6 +2860,31 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
 	return __split_vma(mm, vma, addr, new_below);
 }
 
+/* left and right vma after the split, address of split */
+/*
+ * Fix up PTE-table reference counts after a VMA was split into @lvma (the
+ * lower part) and @rvma (the upper part, starting at the split address).
+ *
+ * @orig_pending_flag is pte_table_counter_pending of the VMA before the
+ * split. If it was set, nothing is counted yet and there is nothing to do.
+ * Otherwise both halves are marked as counted and, unless the split fell on
+ * a PTE-table (PMD) boundary, the table containing the split address is now
+ * mapped by one more VMA (@rvma) and gains exactly one reference.
+ *
+ * Always returns 0.
+ */
+int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag)
+{
+	unsigned long split_addr = rvma->vm_start;
+
+	if (orig_pending_flag)
+		return 0;  //the new vma will have pending flag as true by default, just as the old vma
+
+#ifdef CONFIG_DEBUG_VM
+	printk("split fixup: set vma flag to false, rvma_start=%lx\n", split_addr);
+#endif
+	lvma->pte_table_counter_pending = false;
+	rvma->pte_table_counter_pending = false;
+
+	if (pte_table_start(split_addr) == split_addr)
+		return 0;  /* split on a table boundary: no table is shared by the halves */
+
+#ifdef CONFIG_DEBUG_VM
+	printk("split fixup: rvma_start=%lx\n", split_addr);
+#endif
+	/*
+	 * walk_page_range() calls ->pmd_entry() once for every VMA that
+	 * overlaps the range. Walking the whole table would therefore count
+	 * @lvma (and any other VMA in that table) again. Walk only @rvma's
+	 * part of the table so exactly one reference is added.
+	 */
+	walk_page_range(rvma->vm_mm, split_addr,
+			min(pte_table_end(split_addr), rvma->vm_end),
+			&pgtable_counter_fixup_walk_ops, NULL);
+	return 0;
+}
+
 static inline void
 unlock_range(struct vm_area_struct *start, unsigned long limit)
 {
@@ -2869,6 +2955,8 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 		if (error)
 			return error;
 		prev = vma;
+
+		split_vma_pgtable_counter_fixup(prev, prev->vm_next, prev->pte_table_counter_pending);
 	}
 
 	/* Does it split the last one? */
@@ -2877,6 +2965,7 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 		int error = __split_vma(mm, last, end, 1);
 		if (error)
 			return error;
+		split_vma_pgtable_counter_fixup(last->vm_prev, last, last->pte_table_counter_pending);
 	}
 	vma = vma_next(mm, prev);
 
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 4cb240fd9936..d396b1d38fab 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -445,6 +445,8 @@ static const struct mm_walk_ops prot_none_walk_ops = {
 	.test_walk		= prot_none_test,
 };
 
+/* Defined in mm/mmap.c: adjusts PTE-table reference counts after split_vma(). */
+/* NOTE(review): this prototype would be better placed in a shared header (e.g. mm/internal.h). */
+int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag);
+
 int
 mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
 	unsigned long start, unsigned long end, unsigned long newflags)
@@ -517,12 +519,16 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
 		error = split_vma(mm, vma, start, 1);
 		if (error)
 			goto fail;
+
+		split_vma_pgtable_counter_fixup(vma->vm_prev, vma, vma->pte_table_counter_pending);
 	}
 
 	if (end != vma->vm_end) {
 		error = split_vma(mm, vma, end, 0);
 		if (error)
 			goto fail;
+
+		split_vma_pgtable_counter_fixup(vma, vma->vm_next, vma->pte_table_counter_pending);
 	}
 
 success:
-- 
2.30.2



^ permalink raw reply related	[flat|nested] 4+ messages in thread

* Re: [PATCH] Shared page tables during fork
  2021-07-01 13:46 [PATCH] Shared page tables during fork Kaiyang Zhao
@ 2021-07-06 18:47 ` Souptick Joarder
  2021-07-06 23:01 ` Dave Hansen
  2021-07-07  7:00 ` Hillf Danton
  2 siblings, 0 replies; 4+ messages in thread
From: Souptick Joarder @ 2021-07-06 18:47 UTC (permalink / raw)
  To: Kaiyang Zhao; +Cc: Andrew Morton, Linux-MM

On Thu, Jul 1, 2021 at 7:17 PM Kaiyang Zhao <zhao776@purdue.edu> wrote:
>
> In our research work [https://dl.acm.org/doi/10.1145/3447786.3456258], we
> have identified a method that for large applications (i.e., a few hundred
> MBs and larger), can significantly speed up the fork system call. Currently
> the amount of time that the fork system call takes to complete is
> proportional to the size of allocated memory of a process, and our design
> speeds up fork invocation by up to 270x at 50GB in our experiments.
>
> The design is that instead of copying the entire paging tree during the
> fork invocation, we make the child and the parent process share the same
> set of last-level page tables, which will be reference counted. To preserve
> the copy-on-write semantics, we disable the write permission in PMD entries
> in fork, and copy PTE tables as needed in the page fault handler.

Does an application have the option to choose between the default fork()
and the new on-demand fork()?

>
> We tested a prototype with large workloads that call fork to take snapshots
> such as fuzzers (e.g., AFL), and it yielded over 2x the execution
> throughput for AFL. The patch is a prototype for x86 only and does not
> support huge pages and swapping, and is meant to demonstrate the potential
> performance gains to fork. Applications can opt-in by a switch use_odf in
> procfs.
>
> On a side note, an approach that shares page tables was proposed by Dave
> McCracken [http://lkml.iu.edu/hypermail/linux/kernel/0508.3/1623.html,
> https://www.kernel.org/doc/ols/2006/ols2006v2-pages-125-130.pdf], but never
> made it into the kernel.  We believe that with the increasing memory
> consumption of modern applications and modern use cases of fork such as
> snapshotting, the shared page table approach in the context of fork is
> worth exploring.
>
> Please let us know your level of interest in this or comments on the
> general design. Thank you.
>
> Signed-off-by: Kaiyang Zhao <zhao776@purdue.edu>
> ---
>  arch/x86/include/asm/pgtable.h |  19 +-
>  fs/proc/base.c                 |  74 ++++++
>  include/linux/mm.h             |  11 +
>  include/linux/mm_types.h       |   2 +
>  include/linux/pgtable.h        |  11 +
>  include/linux/sched/coredump.h |   5 +-
>  kernel/fork.c                  |   7 +-
>  mm/gup.c                       |  61 ++++-
>  mm/memory.c                    | 401 +++++++++++++++++++++++++++++++--
>  mm/mmap.c                      |  91 +++++++-
>  mm/mprotect.c                  |   6 +
>  11 files changed, 668 insertions(+), 20 deletions(-)
>
> diff --git a/arch/x86/include/asm/pgtable.h b/arch/x86/include/asm/pgtable.h
> index b6c97b8f59ec..0fda05a5c7a1 100644
> --- a/arch/x86/include/asm/pgtable.h
> +++ b/arch/x86/include/asm/pgtable.h
> @@ -410,6 +410,16 @@ static inline pmd_t pmd_clear_flags(pmd_t pmd, pmdval_t clear)
>         return native_make_pmd(v & ~clear);
>  }
>
> +/* Return @pmd with _PAGE_PRESENT cleared; every other bit is preserved. */
> +static inline pmd_t pmd_mknonpresent(pmd_t pmd)
> +{
> +       pmdval_t val = native_pmd_val(pmd);
> +
> +       return native_make_pmd(val & ~(pmdval_t)_PAGE_PRESENT);
> +}
> +
> +/*
> + * Return @pmd with _PAGE_PRESENT set. pmd_page() applies this before
> + * extracting the pfn, presumably so it also works on non-present pmds.
> + */
> +static inline pmd_t pmd_mkpresent(pmd_t pmd)
> +{
> +       pmdval_t val = native_pmd_val(pmd);
> +
> +       return native_make_pmd(val | (pmdval_t)_PAGE_PRESENT);
> +}
> +
>  #ifdef CONFIG_HAVE_ARCH_USERFAULTFD_WP
>  static inline int pmd_uffd_wp(pmd_t pmd)
>  {
> @@ -798,6 +808,11 @@ static inline int pmd_present(pmd_t pmd)
>         return pmd_flags(pmd) & (_PAGE_PRESENT | _PAGE_PROTNONE | _PAGE_PSE);
>  }
>
> +/*
> + * Non-zero when the pmd has _PAGE_RW set (the same test as pmd_write()).
> + * ODF write-protects the pmd of a shared PTE table during fork, and the
> + * fault path treats a cleared bit on a writable VMA as "copy this table".
> + */
> +static inline int pmd_iswrite(pmd_t pmd)
> +{
> +       return pmd_write(pmd);
> +}
> +
>  #ifdef CONFIG_NUMA_BALANCING
>  /*
>   * These work without NUMA balancing but the kernel does not care. See the
> @@ -833,7 +848,7 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd)
>   * Currently stuck as a macro due to indirect forward reference to
>   * linux/mmzone.h's __section_mem_map_addr() definition:
>   */
> -#define pmd_page(pmd)  pfn_to_page(pmd_pfn(pmd))
> +#define pmd_page(pmd)  pfn_to_page(pmd_pfn(pmd_mkpresent(pmd)))
>
>  /*
>   * Conversion functions: convert a page and protection to a page entry,
> @@ -846,7 +861,7 @@ static inline unsigned long pmd_page_vaddr(pmd_t pmd)
>
>  static inline int pmd_bad(pmd_t pmd)
>  {
> +       /*
> +        * ODF clears _PAGE_RW (and pmd_mknonpresent() can clear
> +        * _PAGE_PRESENT) on pmds of shared PTE tables. Force both bits on
> +        * before comparing so such pmds are not reported as bad.
> +        */
> -       return (pmd_flags(pmd) & ~_PAGE_USER) != _KERNPG_TABLE;
> +       return ((pmd_flags(pmd) & ~(_PAGE_USER)) | (_PAGE_RW | _PAGE_PRESENT)) != _KERNPG_TABLE;
>  }
>
>  static inline unsigned long pages_to_mb(unsigned long npg)
> diff --git a/fs/proc/base.c b/fs/proc/base.c
> index e5b5f7709d48..936f33594539 100644
> --- a/fs/proc/base.c
> +++ b/fs/proc/base.c
> @@ -2935,6 +2935,79 @@ static const struct file_operations proc_coredump_filter_operations = {
>  };
>  #endif
>
> +/*
> + * Read handler for /proc/<pid>/use_odf: reports "1\n" when MMF_USE_ODF is
> + * set on the task's mm and "0\n" otherwise. Returns 0 (EOF) when the task
> + * has no mm and -ESRCH when the task is gone.
> + */
> +static ssize_t proc_use_odf_read(struct file *file, char __user *buf,
> +                                        size_t count, loff_t *ppos)
> +{
> +       struct task_struct *task = get_proc_task(file_inode(file));
> +       struct mm_struct *mm;
> +       char buffer[PROC_NUMBUF];
> +       unsigned long enabled;
> +       size_t len;
> +
> +       if (!task)
> +               return -ESRCH;
> +
> +       mm = get_task_mm(task);
> +       put_task_struct(task);  /* only the mm is needed from here on */
> +       if (!mm)
> +               return 0;
> +
> +       enabled = (mm->flags & MMF_USE_ODF_MASK) >> MMF_USE_ODF;
> +       mmput(mm);
> +
> +       len = snprintf(buffer, sizeof(buffer), "%lu\n", enabled);
> +       return simple_read_from_buffer(buf, count, ppos, buffer, len);
> +}
> +
> +/*
> + * Write handler for /proc/<pid>/use_odf: "1" sets MMF_USE_ODF on the task's
> + * mm, "0" clears it. Any other number is rejected with -EINVAL instead of
> + * being silently accepted and ignored. Returns @count on success, -ESRCH if
> + * the task or its mm is gone, or the error from kstrtouint_from_user().
> + */
> +static ssize_t proc_use_odf_write(struct file *file,
> +                                         const char __user *buf,
> +                                         size_t count,
> +                                         loff_t *ppos)
> +{
> +       struct task_struct *task;
> +       struct mm_struct *mm;
> +       unsigned int val;
> +       int ret;
> +
> +       ret = kstrtouint_from_user(buf, count, 0, &val);
> +       if (ret < 0)
> +               return ret;
> +       if (val > 1)
> +               return -EINVAL;
> +
> +       ret = -ESRCH;
> +       task = get_proc_task(file_inode(file));
> +       if (!task)
> +               goto out_no_task;
> +
> +       mm = get_task_mm(task);
> +       if (!mm)
> +               goto out_no_mm;
> +       ret = 0;
> +
> +       if (val)
> +               set_bit(MMF_USE_ODF, &mm->flags);
> +       else
> +               clear_bit(MMF_USE_ODF, &mm->flags);
> +
> +       mmput(mm);
> + out_no_mm:
> +       put_task_struct(task);
> + out_no_task:
> +       if (ret < 0)
> +               return ret;
> +       return count;
> +}
> +
> +static const struct file_operations proc_use_odf_operations = {
> +       .read           = proc_use_odf_read,
> +       .write          = proc_use_odf_write,
> +       .llseek         = generic_file_llseek,
> +};
> +
>  #ifdef CONFIG_TASK_IO_ACCOUNTING
>  static int do_io_accounting(struct task_struct *task, struct seq_file *m, int whole)
>  {
> @@ -3253,6 +3326,7 @@ static const struct pid_entry tgid_base_stuff[] = {
>  #ifdef CONFIG_ELF_CORE
>         REG("coredump_filter", S_IRUGO|S_IWUSR, proc_coredump_filter_operations),
>  #endif
> +       REG("use_odf", S_IRUGO|S_IWUSR, proc_use_odf_operations),
>  #ifdef CONFIG_TASK_IO_ACCOUNTING
>         ONE("io",       S_IRUSR, proc_tgid_io_accounting),
>  #endif
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index 57453dba41b9..a30eca9e236a 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -664,6 +664,7 @@ static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm)
>         memset(vma, 0, sizeof(*vma));
>         vma->vm_mm = mm;
>         vma->vm_ops = &dummy_vm_ops;
> +       vma->pte_table_counter_pending = true;
>         INIT_LIST_HEAD(&vma->anon_vma_chain);
>  }
>
> @@ -2250,6 +2251,9 @@ static inline bool pgtable_pte_page_ctor(struct page *page)
>                 return false;
>         __SetPageTable(page);
>         inc_lruvec_page_state(page, NR_PAGETABLE);
> +
> +       atomic64_set(&(page->pte_table_refcount), 0);
> +
>         return true;
>  }
>
> @@ -2276,6 +2280,8 @@ static inline void pgtable_pte_page_dtor(struct page *page)
>
>  #define pte_alloc(mm, pmd) (unlikely(pmd_none(*(pmd))) && __pte_alloc(mm, pmd))
>
> +/*
> + * Unlike pte_alloc(), install a fresh PTE table even when *pmd is already
> + * populated (used to replace a shared table). Evaluates to 0 on success.
> + */
> +#define tfork_pte_alloc(mm, pmd) (__tfork_pte_alloc((mm), (pmd)))
> +
>  #define pte_alloc_map(mm, pmd, address)                        \
>         (pte_alloc(mm, pmd) ? NULL : pte_offset_map(pmd, address))
>
> @@ -2283,6 +2289,10 @@ static inline void pgtable_pte_page_dtor(struct page *page)
>         (pte_alloc(mm, pmd) ?                   \
>                  NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
>
> +/* pte_alloc_map_lock() counterpart built on tfork_pte_alloc(); NULL on failure. */
> +#define tfork_pte_alloc_map_lock(mm, pmd, address, ptlp)       \
> +       (tfork_pte_alloc((mm), (pmd)) ? NULL :                  \
> +        pte_offset_map_lock((mm), (pmd), (address), (ptlp)))
> +
>  #define pte_alloc_kernel(pmd, address)                 \
>         ((unlikely(pmd_none(*(pmd))) && __pte_alloc_kernel(pmd))? \
>                 NULL: pte_offset_kernel(pmd, address))
> @@ -2616,6 +2626,7 @@ extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in,
>  #ifdef CONFIG_MMU
>  extern int __mm_populate(unsigned long addr, unsigned long len,
>                          int ignore_errors);
> +/*
> + * Like __mm_populate(), but never takes mmap_lock itself; the caller is
> + * expected to hold it already (the lock calls are commented out in gup.c).
> + */
> +extern int __mm_populate_nolock(unsigned long addr, unsigned long len, int ignore_errors);
>  static inline void mm_populate(unsigned long addr, unsigned long len)
>  {
>         /* Ignore errors */
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index f37abb2d222e..e06c677ce279 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -158,6 +158,7 @@ struct page {
>                         union {
>                                 struct mm_struct *pt_mm; /* x86 pgds only */
>                                 atomic_t pt_frag_refcount; /* powerpc */
> +                               atomic64_t pte_table_refcount;
>                         };
>  #if USE_SPLIT_PTE_PTLOCKS
>  #if ALLOC_SPLIT_PTLOCKS
> @@ -379,6 +380,7 @@ struct vm_area_struct {
>         struct mempolicy *vm_policy;    /* NUMA policy for the VMA */
>  #endif
>         struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
> +       bool pte_table_counter_pending;
>  } __randomize_layout;
>
>  struct core_thread {
> diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
> index d147480cdefc..6afd77ff82e6 100644
> --- a/include/linux/pgtable.h
> +++ b/include/linux/pgtable.h
> @@ -90,6 +90,11 @@ static inline pte_t *pte_offset_kernel(pmd_t *pmd, unsigned long address)
>         return (pte_t *)pmd_page_vaddr(*pmd) + pte_index(address);
>  }
>  #define pte_offset_kernel pte_offset_kernel
> +/*
> + * pte_offset_kernel() variant that takes a pmd *value*, so it can index a
> + * table whose pmd entry has since been overwritten.
> + */
> +static inline pte_t *tfork_pte_offset_kernel(pmd_t pmd_val, unsigned long address)
> +{
> +       pte_t *table = (pte_t *)pmd_page_vaddr(pmd_val);
> +
> +       return table + pte_index(address);
> +}
> +#define tfork_pte_offset_kernel tfork_pte_offset_kernel
>  #endif
>
>  #if defined(CONFIG_HIGHPTE)
> @@ -782,6 +787,12 @@ static inline void arch_swap_restore(swp_entry_t entry, struct page *page)
>  })
>  #endif
>
> +/*
> + * Start of the PTE table (PMD-sized region) that maps @addr. @addr is
> + * parenthesized so expression arguments such as `c ? a : b` or `a | b`
> + * are masked as a whole.
> + */
> +#define pte_table_start(addr)                                          \
> +((addr) & PMD_MASK)
> +
> +/*
> + * End (exclusive) of the PTE table that maps @addr. For a PMD-aligned @addr
> + * this is @addr + PMD_SIZE, so it is not a round-up of @addr.
> + */
> +#define pte_table_end(addr)                                             \
> +(((addr) + PMD_SIZE) & PMD_MASK)
> +
>  /*
>   * When walking page tables, we usually want to skip any p?d_none entries;
>   * and any p?d_bad entries - reporting the error before resetting to none.
> diff --git a/include/linux/sched/coredump.h b/include/linux/sched/coredump.h
> index 4d9e3a656875..8f6e50bc04ab 100644
> --- a/include/linux/sched/coredump.h
> +++ b/include/linux/sched/coredump.h
> @@ -83,7 +83,10 @@ static inline int get_dumpable(struct mm_struct *mm)
>  #define MMF_HAS_PINNED         28      /* FOLL_PIN has run, never cleared */
>  #define MMF_DISABLE_THP_MASK   (1 << MMF_DISABLE_THP)
>
> +/*
> + * Per-mm opt-in to on-demand fork (shared last-level page tables),
> + * toggled through /proc/<pid>/use_odf.
> + */
> +#define MMF_USE_ODF 29
> +#define MMF_USE_ODF_MASK (1 << MMF_USE_ODF)
> +
> +/* ODF is added to MMF_INIT_MASK, presumably so a child mm inherits the opt-in; confirm against mm_init(). */
>  #define MMF_INIT_MASK          (MMF_DUMPABLE_MASK | MMF_DUMP_FILTER_MASK |\
> -                                MMF_DISABLE_THP_MASK)
> +                                MMF_DISABLE_THP_MASK | MMF_USE_ODF_MASK)
>
>  #endif /* _LINUX_SCHED_COREDUMP_H */
> diff --git a/kernel/fork.c b/kernel/fork.c
> index d738aae40f9e..4f21ea4f4f38 100644
> --- a/kernel/fork.c
> +++ b/kernel/fork.c
> @@ -594,8 +594,13 @@ static __latent_entropy int dup_mmap(struct mm_struct *mm,
>                 rb_parent = &tmp->vm_rb;
>
>                 mm->map_count++;
> -               if (!(tmp->vm_flags & VM_WIPEONFORK))
> +               if (!(tmp->vm_flags & VM_WIPEONFORK)) {
>                         retval = copy_page_range(tmp, mpnt);
> +                       if (oldmm->flags & MMF_USE_ODF_MASK) {
> +                               tmp->pte_table_counter_pending = false; // reference of the shared PTE table by the new VMA is counted in copy_pmd_range_tfork
> +                               mpnt->pte_table_counter_pending = false; // don't double count when forking again
> +                       }
> +               }
>
>                 if (tmp->vm_ops && tmp->vm_ops->open)
>                         tmp->vm_ops->open(tmp);
> diff --git a/mm/gup.c b/mm/gup.c
> index 42b8b1fa6521..5768f339b0ff 100644
> --- a/mm/gup.c
> +++ b/mm/gup.c
> @@ -1489,8 +1489,11 @@ long populate_vma_page_range(struct vm_area_struct *vma,
>          * to break COW, except for shared mappings because these don't COW
>          * and we would not want to dirty them for nothing.
>          */
> -       if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE)
> -               gup_flags |= FOLL_WRITE;
> +       if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE) {
> +               if (!(mm->flags & MMF_USE_ODF_MASK)) { //for ODF processes, only allocate page tables
> +                       gup_flags |= FOLL_WRITE;
> +               }
> +       }
>
>         /*
>          * We want mlock to succeed for regions that have any permissions
> @@ -1669,6 +1672,60 @@ static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
>  }
>  #endif /* !CONFIG_MMU */
>
> +/*
> + * Variant of __mm_populate() for callers that already hold mmap_lock (its
> + * caller in this patch runs from vma_merge(), presumably with the lock held
> + * for write). Faults in [@start, @start + @len) VMA by VMA, skipping
> + * VM_IO/VM_PFNMAP VMAs.
> + *
> + * It never takes or drops mmap_lock. It therefore passes a NULL @locked to
> + * populate_vma_page_range(). With a non-NULL pointer the fault path may
> + * release mmap_lock to retry, which would drop the caller's lock behind its
> + * back, and this function never re-takes it.
> + *
> + * Returns 0, or the negative error of the first failing VMA when
> + * @ignore_errors is 0.
> + */
> +int __mm_populate_nolock(unsigned long start, unsigned long len, int ignore_errors)
> +{
> +       struct mm_struct *mm = current->mm;
> +       unsigned long end, nstart, nend;
> +       struct vm_area_struct *vma = NULL;
> +       long ret = 0;
> +
> +       end = start + len;
> +
> +       for (nstart = start; nstart < end; nstart = nend) {
> +               /*
> +                * We want to fault in pages for [nstart; end) address range.
> +                * Look up the first VMA once, then step forward through the
> +                * list (vma is only NULL before the first lookup).
> +                */
> +               if (!vma)
> +                       vma = find_vma(mm, nstart);
> +               else if (nstart >= vma->vm_end)
> +                       vma = vma->vm_next;
> +               if (!vma || vma->vm_start >= end)
> +                       break;
> +               /*
> +                * Set [nstart; nend) to intersection of desired address
> +                * range with the first VMA. Also, skip undesirable VMA types.
> +                */
> +               nend = min(end, vma->vm_end);
> +               if (vma->vm_flags & (VM_IO | VM_PFNMAP))
> +                       continue;
> +               if (nstart < vma->vm_start)
> +                       nstart = vma->vm_start;
> +               /*
> +                * Now fault in a range of pages. NULL @locked: the fault
> +                * path must never drop the caller's mmap_lock.
> +                */
> +               ret = populate_vma_page_range(vma, nstart, nend, NULL);
> +               if (ret < 0) {
> +                       if (ignore_errors) {
> +                               ret = 0;
> +                               continue;       /* continue at next VMA */
> +                       }
> +                       break;
> +               }
> +               nend = nstart + ret * PAGE_SIZE;
> +               ret = 0;
> +       }
> +       return ret;     /* 0 or negative error code */
> +}
> +
>  /**
>   * get_dump_page() - pin user page in memory while writing it to core dump
>   * @addr: user address
> diff --git a/mm/memory.c b/mm/memory.c
> index db86558791f1..2b28766e4213 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -83,6 +83,9 @@
>  #include <asm/tlb.h>
>  #include <asm/tlbflush.h>
>
> +/*
> + * Forward declarations of helpers defined later in this file, presumably
> + * needed because ODF code added above their definitions calls them.
> + */
> +static bool tfork_one_pte_table(struct mm_struct *, pmd_t *, unsigned long, unsigned long);
> +static inline void init_rss_vec(int *rss);
> +static inline void add_mm_rss_vec(struct mm_struct *mm, int *rss);
>  #include "pgalloc-track.h"
>  #include "internal.h"
>
> @@ -227,7 +230,16 @@ static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
>                            unsigned long addr)
>  {
>         pgtable_t token = pmd_pgtable(*pmd);
> +       long counter;
>         pmd_clear(pmd);
> +       counter = atomic64_read(&(token->pte_table_refcount));
> +       if (counter > 0) {
> +               //the pte table can only be shared in this case
> +#ifdef CONFIG_DEBUG_VM
> +               printk("free_pte_range: addr=%lx, counter=%ld, not freeing table", addr, counter);
> +#endif
> +               return;  //pte table is still in use
> +       }
>         pte_free_tlb(tlb, token, addr);
>         mm_dec_nr_ptes(tlb->mm);
>  }
> @@ -433,6 +445,118 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
>         }
>  }
>
> +/*
> + * Drop the page references held by the PTE table that @pmd_val points to.
> + *
> + * @addr is rounded down to the start of its PTE table and every entry of
> + * that table is visited. For each present, non-special PTE that maps a
> + * normal page, the rmap is removed, one page reference is dropped and the
> + * matching RSS counter of @mm is decremented. The PTEs themselves are not
> + * cleared and no TLB flush is issued here.
> + *
> + * NOTE(review): non-present entries (swap/migration) are skipped without
> + * releasing anything; the cover letter states swap is not supported.
> + */
> +void zap_one_pte_table(pmd_t pmd_val, unsigned long addr, struct mm_struct *mm)
> +{
> +       int rss[NR_MM_COUNTERS];
> +       pte_t *pte;
> +       unsigned long end;
> +
> +       init_rss_vec(rss);
> +       addr = pte_table_start(addr);
> +       end = pte_table_end(addr);
> +       pte = tfork_pte_offset_kernel(pmd_val, addr);
> +       do {
> +               pte_t ptent = *pte;
> +
> +               if (pte_none(ptent))
> +                       continue;
> +
> +               if (pte_present(ptent)) {
> +                       struct page *page;
> +
> +                       if (pte_special(ptent)) { //known special pte: vvar VMA, which has just one page shared system-wide. Shouldn't matter
> +                               continue;
> +                       }
> +                       page = vm_normal_page(NULL, addr, ptent); //kyz : vma is not important
> +                       if (unlikely(!page))
> +                               continue;
> +                       rss[mm_counter(page)]--;
> +#ifdef CONFIG_DEBUG_VM
> +                       //   printk("zap_one_pte_table: addr=%lx, end=%lx, (before) mapcount=%d, refcount=%d\n", addr, end, page_mapcount(page), page_ref_count(page));
> +#endif
> +                       page_remove_rmap(page, false);
> +                       put_page(page);
> +               }
> +       } while (pte++, addr += PAGE_SIZE, addr != end);
> +
> +       add_mm_rss_vec(mm, rss);
> +}
> +
> +/*
> + * Drop one reference on the PTE table that @pmd_val points to.
> + *
> + * The pmd lock should be held. When this was the last reference, the pages
> + * mapped by the table are released via zap_one_pte_table() and, if
> + * @free_table is true, the table page is freed and @mm's page-table count
> + * is decremented. @addr is any address covered by the table.
> + *
> + * Returns 1 if the table becomes unused, 0 otherwise.
> + */
> +int dereference_pte_table(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr)
> +{
> +       struct page *table_page;
> +
> +       table_page = pmd_page(pmd_val);
> +
> +       if (atomic64_dec_and_test(&(table_page->pte_table_refcount))) {
> +#ifdef CONFIG_DEBUG_VM
> +               printk("dereference_pte_table: addr=%lx, free_table=%d, pte table reached end of life\n", addr, free_table);
> +#endif
> +
> +               zap_one_pte_table(pmd_val, addr, mm);
> +               if (free_table) {
> +                       pgtable_pte_page_dtor(table_page);
> +                       __free_page(table_page);
> +                       mm_dec_nr_ptes(mm);
> +               }
> +               return 1;
> +       } else {
> +#ifdef CONFIG_DEBUG_VM
> +               printk("dereference_pte_table: addr=%lx, (after) pte_table_count=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
> +#endif
> +       }
> +       return 0;
> +}
> +
> +/*
> + * Drop @num references on the PTE table that @pmd_val points to.
> + *
> + * If the count drops to zero or below, the table is at end of life: the
> + * pages it maps are released via zap_one_pte_table() and, when @free_table
> + * is true, the table page is freed and @mm's page-table count decremented.
> + * @addr is any address covered by the table.
> + *
> + * Returns 1 if the table reached end of life, 0 otherwise.
> + */
> +int dereference_pte_table_multiple(pmd_t pmd_val, bool free_table, struct mm_struct *mm, unsigned long addr, int num)
> +{
> +       struct page *table_page;
> +       s64 count_after;        /* full atomic64 result; an int could truncate it */
> +
> +       table_page = pmd_page(pmd_val);
> +       count_after = atomic64_sub_return(num, &(table_page->pte_table_refcount));
> +       if (count_after <= 0) {
> +#ifdef CONFIG_DEBUG_VM
> +               printk("dereference_pte_table_multiple: addr=%lx, free_table=%d, num=%d, after count=%lld, table reached end of life\n", addr, free_table, num, count_after);
> +#endif
> +
> +               zap_one_pte_table(pmd_val, addr, mm);
> +               if (free_table) {
> +                       pgtable_pte_page_dtor(table_page);
> +                       __free_page(table_page);
> +                       mm_dec_nr_ptes(mm);
> +               }
> +               return 1;
> +       }
> +#ifdef CONFIG_DEBUG_VM
> +       printk("dereference_pte_table_multiple: addr=%lx, num=%d, (after) count=%lld\n", addr, num, atomic64_read(&(table_page->pte_table_refcount)));
> +#endif
> +       return 0;
> +}
> +
> +/*
> + * Allocate a new PTE table and install it in *@pmd.
> + *
> + * Unlike __pte_alloc(), this neither checks whether *@pmd already points to
> + * a table nor takes a lock: it is meant to replace a shared table.
> + * NOTE(review): on the fault path the caller (__handle_mm_fault()) holds the
> + * pmd lock around tfork_one_pte_table(); confirm this for every caller.
> + *
> + * Returns 0 on success or -ENOMEM.
> + */
> +int __tfork_pte_alloc(struct mm_struct *mm, pmd_t *pmd)
> +{
> +       pgtable_t new = pte_alloc_one(mm);
> +
> +       if (!new)
> +               return -ENOMEM;
> +       /*
> +        * Make the table's initialisation visible before it is published
> +        * through the pmd (same ordering as in __pte_alloc()).
> +        */
> +       smp_wmb();
> +
> +       mm_inc_nr_ptes(mm);
> +       pmd_populate(mm, pmd, new);
> +       return 0;
> +}
> +
> +
>  int __pte_alloc(struct mm_struct *mm, pmd_t *pmd)
>  {
>         spinlock_t *ptl;
> @@ -928,6 +1052,45 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma
>         return 0;
>  }
>
> +/*
> + * Copy one PTE from the shared table (@src_pte) into the faulting process's
> + * private table (@dst_pte).
> + *
> + * The copy is write-protected for writable COW mappings, marked clean for
> + * shared mappings and always marked old; the source PTE is not modified.
> + * If it maps a normal page, that page gains a reference and an rmap entry
> + * and @rss is incremented. Always returns 0.
> + *
> + * NOTE(review): there is no pte_present() check. A swap or migration entry
> + * would be passed to vm_normal_page() as if it were a present PTE; handling
> + * along the lines of copy_nonpresent_pte() looks necessary before this can
> + * coexist with swap or page migration.
> + */
> +static inline unsigned long
> +copy_one_pte_tfork(struct mm_struct *dst_mm,
> +               pte_t *dst_pte, pte_t *src_pte, struct vm_area_struct *vma,
> +               unsigned long addr, int *rss)
> +{
> +       unsigned long vm_flags = vma->vm_flags;
> +       pte_t pte = *src_pte;
> +       struct page *page;
> +
> +       /*
> +        * If it's a COW mapping
> +        * only protect in the child (the faulting process)
> +        */
> +       if (is_cow_mapping(vm_flags) && pte_write(pte)) {
> +               pte = pte_wrprotect(pte);
> +       }
> +
> +       /*
> +        * If it's a shared mapping, mark it clean in
> +        * the child
> +        */
> +       if (vm_flags & VM_SHARED)
> +               pte = pte_mkclean(pte);
> +       pte = pte_mkold(pte);
> +
> +       page = vm_normal_page(vma, addr, pte);
> +       if (page) {
> +               get_page(page);
> +               page_dup_rmap(page, false);
> +               rss[mm_counter(page)]++;
> +#ifdef CONFIG_DEBUG_VM
> +//             printk("copy_one_pte_tfork: addr=%lx, (after) mapcount=%d, refcount=%d\n", addr, page_mapcount(page), page_ref_count(page));
> +#endif
> +       }
> +
> +       set_pte_at(dst_mm, addr, dst_pte, pte);
> +       return 0;
> +}
> +
>  /*
>   * Copy one pte.  Returns 0 if succeeded, or -EAGAIN if one preallocated page
>   * is required to copy this pte.
> @@ -999,6 +1162,59 @@ page_copy_prealloc(struct mm_struct *src_mm, struct vm_area_struct *vma,
>         return new_page;
>  }
>
> +/*
> + * Copy the PTEs of @vma in [@addr, @end) from the table referenced by
> + * @src_pmd_val into the table behind @dst_pmd.
> + *
> + * If *@dst_pmd is write-protected (still pointing at the shared table), a
> + * new private table is allocated and installed first; otherwise the
> + * existing table is used (allocated if the pmd is empty). Either way the
> + * destination table's pte_table_refcount is incremented once for @vma.
> + *
> + * Returns 0, or -ENOMEM if no destination table could be obtained.
> + * NOTE(review): tfork_one_pte_table() ignores this return value. Also,
> + * pte_unmap() is applied to a pointer obtained from
> + * tfork_pte_offset_kernel() rather than pte_offset_map(). That is harmless
> + * on x86-64 but suspicious with CONFIG_HIGHPTE.
> + */
> +static int copy_pte_range_tfork(struct mm_struct *dst_mm,
> +                  pmd_t *dst_pmd, pmd_t src_pmd_val, struct vm_area_struct *vma,
> +                  unsigned long addr, unsigned long end)
> +{
> +       pte_t *orig_src_pte, *orig_dst_pte;
> +       pte_t *src_pte, *dst_pte;
> +       spinlock_t *dst_ptl;
> +       int rss[NR_MM_COUNTERS];
> +       swp_entry_t entry = (swp_entry_t){0};
> +       struct page *dst_pte_page;
> +
> +       init_rss_vec(rss);
> +
> +       src_pte = tfork_pte_offset_kernel(src_pmd_val, addr); //src_pte points to the old table
> +       if (!pmd_iswrite(*dst_pmd)) {
> +               dst_pte = tfork_pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);  //dst_pte points to a new table
> +#ifdef CONFIG_DEBUG_VM
> +               printk("copy_pte_range_tfork: allocated new table. addr=%lx, prev_table_page=%px, table_page=%px\n", addr, pmd_page(src_pmd_val), pmd_page(*dst_pmd));
> +#endif
> +       } else {
> +               dst_pte = pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
> +       }
> +       if (!dst_pte)
> +               return -ENOMEM;
> +
> +       dst_pte_page = pmd_page(*dst_pmd);
> +       atomic64_inc(&(dst_pte_page->pte_table_refcount)); //kyz: associates the VMA with the new table
> +#ifdef CONFIG_DEBUG_VM
> +       printk("copy_pte_range_tfork: addr = %lx, end = %lx, new pte table counter (after)=%lld\n", addr, end, atomic64_read(&(dst_pte_page->pte_table_refcount)));
> +#endif
> +
> +       orig_src_pte = src_pte;
> +       orig_dst_pte = dst_pte;
> +       arch_enter_lazy_mmu_mode();
> +
> +       do {
> +               if (pte_none(*src_pte)) {
> +                       continue;
> +               }
> +               entry.val = copy_one_pte_tfork(dst_mm, dst_pte, src_pte,
> +                                                       vma, addr, rss);
> +               if (entry.val)
> +                       printk("kyz: failed copy_one_pte_tfork call\n");
> +       } while (dst_pte++, src_pte++, addr += PAGE_SIZE, addr != end);
> +
> +       arch_leave_lazy_mmu_mode();
> +       pte_unmap(orig_src_pte);
> +       add_mm_rss_vec(dst_mm, rss);
> +       pte_unmap_unlock(orig_dst_pte, dst_ptl);
> +
> +       return 0;
> +}
> +
>  static int
>  copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
> @@ -1130,8 +1346,9 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>  {
>         struct mm_struct *dst_mm = dst_vma->vm_mm;
>         struct mm_struct *src_mm = src_vma->vm_mm;
> -       pmd_t *src_pmd, *dst_pmd;
> +       pmd_t *src_pmd, *dst_pmd, src_pmd_value;
>         unsigned long next;
> +       struct page *table_page;
>
>         dst_pmd = pmd_alloc(dst_mm, dst_pud, addr);
>         if (!dst_pmd)
> @@ -1153,9 +1370,43 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                 }
>                 if (pmd_none_or_clear_bad(src_pmd))
>                         continue;
> -               if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> -                                  addr, next))
> -                       return -ENOMEM;
> +               if (src_mm->flags & MMF_USE_ODF_MASK) {
> +#ifdef CONFIG_DEBUG_VM
> +                       printk("copy_pmd_range: vm_start=%lx, addr=%lx, vm_end=%lx, end=%lx\n", src_vma->vm_start, addr, src_vma->vm_end, end);
> +#endif
> +
> +                       src_pmd_value = *src_pmd;
> +                       //kyz: sets write-protect to the pmd entry if the vma is writable
> +                       if (src_vma->vm_flags & VM_WRITE) {
> +                               src_pmd_value = pmd_wrprotect(src_pmd_value);
> +                               set_pmd_at(src_mm, addr, src_pmd, src_pmd_value);
> +                       }
> +                       table_page = pmd_page(*src_pmd);
> +                       if (src_vma->pte_table_counter_pending) { // kyz : the old VMA hasn't been counted in the PTE table, count it now
> +                               atomic64_add(2, &(table_page->pte_table_refcount));
> +#ifdef CONFIG_DEBUG_VM
> +                               printk("copy_pmd_range: addr=%lx, pte table counter (after counting old&new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
> +#endif
> +                       } else {
> +                               atomic64_inc(&(table_page->pte_table_refcount));  //increments the pte table counter
> +                               if (atomic64_read(&(table_page->pte_table_refcount)) == 1) {  //the VMA is old, but the pte table is new (created by a fault after the last odf call)
> +                                       atomic64_set(&(table_page->pte_table_refcount), 2);
> +#ifdef CONFIG_DEBUG_VM
> +                                       printk("copy_pmd_range: addr=%lx, pte table counter (old VMA, new pte table)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
> +#endif
> +                               }
> +#ifdef CONFIG_DEBUG_VM
> +                               else {
> +                                       printk("copy_pmd_range: addr=%lx, pte table counter (after counting new)=%lld\n", addr, atomic64_read(&(table_page->pte_table_refcount)));
> +                               }
> +#endif
> +                       }
> +                       set_pmd_at(dst_mm, addr, dst_pmd, src_pmd_value);  //shares the table with the child
> +               } else {
> +                       if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> +                                               addr, next))
> +                               return -ENOMEM;
> +               }
>         } while (dst_pmd++, src_pmd++, addr = next, addr != end);
>         return 0;
>  }
> @@ -1240,9 +1491,10 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
>          * readonly mappings. The tradeoff is that copy_page_range is more
>          * efficient than faulting.
>          */
> -       if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) &&
> +/*     if (!(src_vma->vm_flags & (VM_HUGETLB | VM_PFNMAP | VM_MIXEDMAP)) &&
>             !src_vma->anon_vma)
>                 return 0;
> +*/
>
>         if (is_vm_hugetlb_page(src_vma))
>                 return copy_hugetlb_page_range(dst_mm, src_mm, src_vma);
> @@ -1304,7 +1556,7 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
>  static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                                 struct vm_area_struct *vma, pmd_t *pmd,
>                                 unsigned long addr, unsigned long end,
> -                               struct zap_details *details)
> +                               struct zap_details *details, bool invalidate_pmd)
>  {
>         struct mm_struct *mm = tlb->mm;
>         int force_flush = 0;
> @@ -1343,8 +1595,10 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                                     details->check_mapping != page_rmapping(page))
>                                         continue;
>                         }
> -                       ptent = ptep_get_and_clear_full(mm, addr, pte,
> -                                                       tlb->fullmm);
> +                       if (!invalidate_pmd) {
> +                               ptent = ptep_get_and_clear_full(mm, addr, pte,
> +                                               tlb->fullmm);
> +                       }
>                         tlb_remove_tlb_entry(tlb, pte, addr);
>                         if (unlikely(!page))
>                                 continue;
> @@ -1358,8 +1612,12 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                                     likely(!(vma->vm_flags & VM_SEQ_READ)))
>                                         mark_page_accessed(page);
>                         }
> -                       rss[mm_counter(page)]--;
> -                       page_remove_rmap(page, false);
> +                       if (!invalidate_pmd) {
> +                               rss[mm_counter(page)]--;
> +                               page_remove_rmap(page, false);
> +                       } else {
> +                               continue;
> +                       }
>                         if (unlikely(page_mapcount(page) < 0))
>                                 print_bad_pte(vma, addr, ptent, page);
>                         if (unlikely(__tlb_remove_page(tlb, page))) {
> @@ -1446,12 +1704,16 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
>                                 struct zap_details *details)
>  {
>         pmd_t *pmd;
> -       unsigned long next;
> +       unsigned long next, table_start, table_end;
> +       spinlock_t *ptl;
> +       struct page *table_page;
> +       bool got_new_table = false;
>
>         pmd = pmd_offset(pud, addr);
>         do {
> +               ptl = pmd_lock(vma->vm_mm, pmd);
>                 next = pmd_addr_end(addr, end);
> -               if (is_swap_pmd(*pmd) || pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
> +               if (pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
>                         if (next - addr != HPAGE_PMD_SIZE)
>                                 __split_huge_pmd(vma, pmd, addr, false, NULL);
>                         else if (zap_huge_pmd(tlb, vma, pmd, addr))
> @@ -1478,8 +1740,49 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
>                  */
>                 if (pmd_none_or_trans_huge_or_clear_bad(pmd))
>                         goto next;
> -               next = zap_pte_range(tlb, vma, pmd, addr, next, details);
> +               //kyz: copy if the pte table is shared and VMA does not cover fully the 2MB region
> +               table_page = pmd_page(*pmd);
> +               table_start = pte_table_start(addr);
> +
> +               if ((!pmd_iswrite(*pmd)) && (!vma->pte_table_counter_pending)) {//shared pte table. vma has gone through odf
> +                       table_end = pte_table_end(addr);
> +                       if (table_start < vma->vm_start || table_end > vma->vm_end) {
> +#ifdef CONFIG_DEBUG_VM
> +                               printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, copy then zap\n", __func__, addr, end, table_start, table_end);
> +#endif
> +                               if (dereference_pte_table(*pmd, false, vma->vm_mm, addr) != 1) { //dec the counter of the shared table. tfork_one_pte_table cannot find the current VMA (which is being unmapped)
> +                                       got_new_table = tfork_one_pte_table(vma->vm_mm, pmd, addr, vma->vm_end);
> +                                       if (got_new_table) {
> +                                               next = zap_pte_range(tlb, vma, pmd, addr, next, details, false);
> +                                       } else {
> +#ifdef CONFIG_DEBUG_VM
> +                                               printk("zap_pmd_range: no more VMAs in this process are using the table, but there are other processes using it\n");
> +#endif
> +                                               pmd_clear(pmd);
> +                                       }
> +                               } else {
> +#ifdef CONFIG_DEBUG_VM
> +                                       printk("zap_pmd_range: the shared table is dead. NOT copying after all.\n");
> +#endif
> +                                       // the shared table will be freed by unmap_single_vma()
> +                               }
> +                       } else {
> +#ifdef CONFIG_DEBUG_VM
> +                               printk("%s: addr=%lx, end=%lx, table_start=%lx, table_end=%lx, zap while preserving pte entries\n", __func__, addr, end, table_start, table_end);
> +#endif
> +                               //kyz: shared and fully covered by the VMA, preserve the pte entries
> +                               next = zap_pte_range(tlb, vma, pmd, addr, next, details, true);
> +                               dereference_pte_table(*pmd, true, vma->vm_mm, addr);
> +                               pmd_clear(pmd);
> +                       }
> +               } else {
> +                       next = zap_pte_range(tlb, vma, pmd, addr, next, details, false);
> +                       if (!vma->pte_table_counter_pending) {
> +                               atomic64_dec(&(table_page->pte_table_refcount));
> +                       }
> +               }
>  next:
> +               spin_unlock(ptl);
>                 cond_resched();
>         } while (pmd++, addr = next, addr != end);
>
> @@ -4476,6 +4779,66 @@ static vm_fault_t wp_huge_pud(struct vm_fault *vmf, pud_t orig_pud)
>         return VM_FAULT_FALLBACK;
>  }
>
> +/*  kyz: Handles an entire pte-level page table, covering multiple VMAs (if they exist)
> + *  Returns true if a new table is put in place, false otherwise.
> + *  if exclude is not 0, the vma that covers addr to exclude will not be copied
> + */
> +static bool tfork_one_pte_table(struct mm_struct *mm, pmd_t *dst_pmd, unsigned long addr, unsigned long exclude)
> +{
> +       unsigned long table_end, end, orig_addr;
> +       struct vm_area_struct *vma;
> +       pmd_t orig_pmd_val;
> +       bool copied = false;
> +       struct page *orig_pte_page;
> +       int num_vmas = 0;
> +
> +       if (!pmd_none(*dst_pmd)) {
> +               orig_pmd_val = *dst_pmd;
> +       } else {
> +               BUG();
> +       }
> +
> +       //kyz: Starts from the beginning of the range covered by the table
> +       orig_addr = addr;
> +       table_end = pte_table_end(addr);
> +       addr = pte_table_start(addr);
> +#ifdef CONFIG_DEBUG_VM
> +       orig_pte_page = pmd_page(orig_pmd_val);
> +       printk("tfork_one_pte_table: shared pte table counter=%lld, Covered Range: start=%lx, end=%lx\n", atomic64_read(&(orig_pte_page->pte_table_refcount)), addr, table_end);
> +#endif
> +       do {
> +               vma = find_vma(mm, addr);
> +               if (!vma) {
> +                       break;  //inexplicable
> +               }
> +               if (vma->vm_start >= table_end) {
> +                       break;
> +               }
> +               end = pmd_addr_end(addr, vma->vm_end);
> +               if (vma->pte_table_counter_pending) {  //this vma is newly mapped (clean) and (fully/partly) described by this pte table
> +                       addr = end;
> +                       continue;
> +               }
> +               if (vma->vm_start > addr) {
> +                       addr = vma->vm_start;
> +               }
> +               if (exclude > 0 && vma->vm_start <= orig_addr && vma->vm_end >= exclude) {
> +                       addr = end;
> +                       continue;
> +               }
> +#ifdef CONFIG_DEBUG_VM
> +               printk("tfork_one_pte_table: vm_start=%lx, vm_end=%lx\n", vma->vm_start, vma->vm_end);
> +#endif
> +               num_vmas++;
> +               copy_pte_range_tfork(mm, dst_pmd, orig_pmd_val, vma, addr, end);
> +               copied = true;
> +               addr = end;
> +       } while (addr < table_end);
> +
> +       dereference_pte_table_multiple(orig_pmd_val, true, mm, orig_addr, num_vmas);
> +       return copied;
> +}
> +
>  /*
>   * These routines also need to handle stuff like marking pages dirty
>   * and/or accessed for architectures that don't do it in hardware (most
> @@ -4610,6 +4973,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>         pgd_t *pgd;
>         p4d_t *p4d;
>         vm_fault_t ret;
> +       spinlock_t *ptl;
>
>         pgd = pgd_offset(mm, address);
>         p4d = p4d_alloc(mm, pgd, address);
> @@ -4659,6 +5023,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>                 vmf.orig_pmd = *vmf.pmd;
>
>                 barrier();
> +               /*
>                 if (unlikely(is_swap_pmd(vmf.orig_pmd))) {
>                         VM_BUG_ON(thp_migration_supported() &&
>                                           !is_pmd_migration_entry(vmf.orig_pmd));
> @@ -4666,6 +5031,7 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>                                 pmd_migration_entry_wait(mm, vmf.pmd);
>                         return 0;
>                 }
> +               */
>                 if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) {
>                         if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma))
>                                 return do_huge_pmd_numa_page(&vmf);
> @@ -4679,6 +5045,15 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>                                 return 0;
>                         }
>                 }
> +               //kyz: checks if the pmd entry prohibits writes
> +               if ((!pmd_none(vmf.orig_pmd)) && (!pmd_iswrite(vmf.orig_pmd)) && (vma->vm_flags & VM_WRITE)) {
> +#ifdef CONFIG_DEBUG_VM
> +                       printk("__handle_mm_fault: PID=%d, addr=%lx\n", current->pid, address);
> +#endif
> +                       ptl = pmd_lock(mm, vmf.pmd);
> +                       tfork_one_pte_table(mm, vmf.pmd, vmf.address, 0u);
> +                       spin_unlock(ptl);
> +               }
>         }
>
>         return handle_pte_fault(&vmf);
> diff --git a/mm/mmap.c b/mm/mmap.c
> index ca54d36d203a..308d86cfe544 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -47,6 +47,7 @@
>  #include <linux/pkeys.h>
>  #include <linux/oom.h>
>  #include <linux/sched/mm.h>
> +#include <linux/pagewalk.h>
>
>  #include <linux/uaccess.h>
>  #include <asm/cacheflush.h>
> @@ -276,6 +277,9 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
>
>  success:
>         populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0;
> +       if (mm->flags & MMF_USE_ODF_MASK) { //for ODF
> +               populate = true;
> +       }
>         if (downgraded)
>                 mmap_read_unlock(mm);
>         else
> @@ -1115,6 +1119,50 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
>         return 0;
>  }
>
> +static int pgtable_counter_fixup_pmd_entry(pmd_t *pmd, unsigned long addr,
> +                              unsigned long next, struct mm_walk *walk)
> +{
> +       struct page *table_page;
> +
> +       table_page = pmd_page(*pmd);
> +       atomic64_inc(&(table_page->pte_table_refcount));
> +
> +#ifdef CONFIG_DEBUG_VM
> +       printk("fixup inc: addr=%lx\n", addr);
> +#endif
> +
> +       walk->action = ACTION_CONTINUE;  //skip pte level
> +       return 0;
> +}
> +
> +static int pgtable_counter_fixup_test(unsigned long addr, unsigned long next,
> +                         struct mm_walk *walk)
> +{
> +       return 0;
> +}
> +
> +static const struct mm_walk_ops pgtable_counter_fixup_walk_ops = {
> +.pmd_entry = pgtable_counter_fixup_pmd_entry,
> +.test_walk = pgtable_counter_fixup_test
> +};
> +
> +int merge_vma_pgtable_counter_fixup(struct vm_area_struct *vma, unsigned long start, unsigned long end)
> +{
> +       if (vma->pte_table_counter_pending) {
> +               return 0;
> +       } else {
> +#ifdef CONFIG_DEBUG_VM
> +               printk("merge fixup: vm_start=%lx, vm_end=%lx, inc start=%lx, inc end=%lx\n", vma->vm_start, vma->vm_end, start, end);
> +#endif
> +               start = pte_table_end(start);
> +               end = pte_table_start(end);
> +               __mm_populate_nolock(start, end-start, 1); //popuate tables for extended address range so that we can increment counters
> +               walk_page_range(vma->vm_mm, start, end, &pgtable_counter_fixup_walk_ops, NULL);
> +       }
> +
> +       return 0;
> +}
> +
>  /*
>   * Given a mapping request (addr,end,vm_flags,file,pgoff), figure out
>   * whether that can be merged with its predecessor or its successor.
> @@ -1215,6 +1263,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
>                 if (err)
>                         return NULL;
>                 khugepaged_enter_vma_merge(prev, vm_flags);
> +
> +               merge_vma_pgtable_counter_fixup(prev, addr, end);
> +
>                 return prev;
>         }
>
> @@ -1242,6 +1293,9 @@ struct vm_area_struct *vma_merge(struct mm_struct *mm,
>                 if (err)
>                         return NULL;
>                 khugepaged_enter_vma_merge(area, vm_flags);
> +
> +               merge_vma_pgtable_counter_fixup(area, addr, end);
> +
>                 return area;
>         }
>
> @@ -1584,8 +1638,15 @@ unsigned long do_mmap(struct file *file, unsigned long addr,
>         addr = mmap_region(file, addr, len, vm_flags, pgoff, uf);
>         if (!IS_ERR_VALUE(addr) &&
>             ((vm_flags & VM_LOCKED) ||
> -            (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE))
> +            (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE ||
> +            (mm->flags & MMF_USE_ODF_MASK))) {
> +#ifdef CONFIG_DEBUG_VM
> +               if (mm->flags & MMF_USE_ODF_MASK) {
> +                       printk("mmap: force populate, addr=%lx, len=%lx\n", addr, len);
> +               }
> +#endif
>                 *populate = len;
> +       }
>         return addr;
>  }
>
> @@ -2799,6 +2860,31 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
>         return __split_vma(mm, vma, addr, new_below);
>  }
>
> +/* left and right vma after the split, address of split */
> +int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag)
> +{
> +       if (orig_pending_flag) {
> +               return 0;  //the new vma will have pending flag as true by default, just as the old vma
> +       } else {
> +#ifdef CONFIG_DEBUG_VM
> +               printk("split fixup: set vma flag to false, rvma_start=%lx\n", rvma->vm_start);
> +#endif
> +               lvma->pte_table_counter_pending = false;
> +               rvma->pte_table_counter_pending = false;
> +
> +               if (pte_table_start(rvma->vm_start) == rvma->vm_start) {  //the split was right at the pte table boundary
> +                       return 0;  //the only case where we don't increment pte table counter
> +               } else {
> +#ifdef CONFIG_DEBUG_VM
> +                       printk("split fixup: rvma_start=%lx\n", rvma->vm_start);
> +#endif
> +                       walk_page_range(rvma->vm_mm, pte_table_start(rvma->vm_start), pte_table_end(rvma->vm_start), &pgtable_counter_fixup_walk_ops, NULL);
> +               }
> +       }
> +
> +       return 0;
> +}
> +
>  static inline void
>  unlock_range(struct vm_area_struct *start, unsigned long limit)
>  {
> @@ -2869,6 +2955,8 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
>                 if (error)
>                         return error;
>                 prev = vma;
> +
> +               split_vma_pgtable_counter_fixup(prev, prev->vm_next, prev->pte_table_counter_pending);
>         }
>
>         /* Does it split the last one? */
> @@ -2877,6 +2965,7 @@ int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
>                 int error = __split_vma(mm, last, end, 1);
>                 if (error)
>                         return error;
> +               split_vma_pgtable_counter_fixup(last->vm_prev, last, last->pte_table_counter_pending);
>         }
>         vma = vma_next(mm, prev);
>
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index 4cb240fd9936..d396b1d38fab 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -445,6 +445,8 @@ static const struct mm_walk_ops prot_none_walk_ops = {
>         .test_walk              = prot_none_test,
>  };
>
> +int split_vma_pgtable_counter_fixup(struct vm_area_struct *lvma, struct vm_area_struct *rvma, bool orig_pending_flag);
> +
>  int
>  mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
>         unsigned long start, unsigned long end, unsigned long newflags)
> @@ -517,12 +519,16 @@ mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
>                 error = split_vma(mm, vma, start, 1);
>                 if (error)
>                         goto fail;
> +
> +               split_vma_pgtable_counter_fixup(vma->vm_prev, vma, vma->pte_table_counter_pending);
>         }
>
>         if (end != vma->vm_end) {
>                 error = split_vma(mm, vma, end, 0);
>                 if (error)
>                         goto fail;
> +
> +               split_vma_pgtable_counter_fixup(vma, vma->vm_next, vma->pte_table_counter_pending);
>         }
>
>  success:
> --
> 2.30.2
>
>


^ permalink raw reply	[flat|nested] 4+ messages in thread

* Re: [PATCH] Shared page tables during fork
  2021-07-01 13:46 [PATCH] Shared page tables during fork Kaiyang Zhao
  2021-07-06 18:47 ` Souptick Joarder
@ 2021-07-06 23:01 ` Dave Hansen
  2021-07-07  7:00 ` Hillf Danton
  2 siblings, 0 replies; 4+ messages in thread
From: Dave Hansen @ 2021-07-06 23:01 UTC (permalink / raw)
  To: Kaiyang Zhao, akpm; +Cc: linux-mm

On 7/1/21 6:46 AM, Kaiyang Zhao wrote:
> The design is that instead of copying the entire paging tree during the
> fork invocation, we make the child and the parent process share the same
> set of last-level page tables, which will be reference counted. To preserve
> the copy-on-write semantics, we disable the write permission in PMD entries
> in fork, and copy PTE tables as needed in the page fault handler.

That's clever.  But I'm not sure it's comprehensive.  How, for
instance, do you handle get_user_pages() users that don't actually write
to the mappings?  Or memory reclaim, where the kernel itself goes and
zaps page table entries without accessing the mapping that's being
zapped?

I would have expected a *lot* more pervasive changes to page table
walkers across the kernel.

Oh, and the code itself makes my eyes bleed.  You might want to spend a
bit of time to clean out the debug printk()s and make sure this gets
somewhere close to passing checkpatch.pl if you want to be taken more
seriously.  For example:

> +		if (pte_present(ptent)) {
> +			struct page *page;
> +
> +			if (pte_special(ptent)) { //known special pte: vvar VMA, which has just one page shared system-wide. Shouldn't matter
> +				continue;
> +			}
> +			page = vm_normal_page(NULL, addr, ptent); //kyz : vma is not important
> +			if (unlikely(!page))
> +				continue;
> +			rss[mm_counter(page)]--;
> +#ifdef CONFIG_DEBUG_VM
> +			//   printk("zap_one_pte_table: addr=%lx, end=%lx, (before) mapcount=%d, refcount=%d\n", addr, end, page_mapcount(page), page_ref_count(page));
> +#endif
> +			page_remove_rmap(page, false);
> +			put_page(page);
> +		}



^ permalink raw reply	[flat|nested] 4+ messages in thread

* Re: [PATCH] Shared page tables during fork
  2021-07-01 13:46 [PATCH] Shared page tables during fork Kaiyang Zhao
  2021-07-06 18:47 ` Souptick Joarder
  2021-07-06 23:01 ` Dave Hansen
@ 2021-07-07  7:00 ` Hillf Danton
  2 siblings, 0 replies; 4+ messages in thread
From: Hillf Danton @ 2021-07-07  7:00 UTC (permalink / raw)
  To: Kaiyang Zhao; +Cc: linux-mm

On Thu,  1 Jul 2021 09:46:18 -0400 Kaiyang Zhao <zhao776@purdue.edu> wrote:
>+
>+int __tfork_pte_alloc(struct mm_struct *mm, pmd_t *pmd)
>+{
>+	pgtable_t new = pte_alloc_one(mm);
>+
>+	if (!new)
>+		return -ENOMEM;
>+	smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */

Adding a comment that points to the matching smp_rmb() would help more.
>+
>+	mm_inc_nr_ptes(mm);
>+	//kyz: won't check if the pte table already exists
>+	pmd_populate(mm, pmd, new);
>+	new = NULL;
>+	if (new)
>+		pte_free(mm, new);
>+	return 0;
>+}
>+
>+
> int __pte_alloc(struct mm_struct *mm, pmd_t *pmd)
> {
> 	spinlock_t *ptl;
>@@ -928,6 +1052,45 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma
> 	return 0;
> }
>
>+static inline unsigned long
>+copy_one_pte_tfork(struct mm_struct *dst_mm,
>+		pte_t *dst_pte, pte_t *src_pte, struct vm_area_struct *vma,
>+		unsigned long addr, int *rss)
>+{
>+	unsigned long vm_flags = vma->vm_flags;
>+	pte_t pte = *src_pte;
>+	struct page *page;
>+
>+	/*
>+	 * If it's a COW mapping
>+	 * only protect in the child (the faulting process)
>+	 */
>+	if (is_cow_mapping(vm_flags) && pte_write(pte)) {
>+		pte = pte_wrprotect(pte);

Could this be a loophole that leaks the parent's info to the child?
>+	}
>+
>+	/*
>+	 * If it's a shared mapping, mark it clean in
>+	 * the child
>+	 */
>+	if (vm_flags & VM_SHARED)
>+		pte = pte_mkclean(pte);
>+	pte = pte_mkold(pte);


^ permalink raw reply	[flat|nested] 4+ messages in thread

end of thread, other threads:[~2021-07-07  7:01 UTC | newest]

Thread overview: 4+ messages (download: mbox.gz / follow: Atom feed)
-- links below jump to the message on this page --
2021-07-01 13:46 [PATCH] Shared page tables during fork Kaiyang Zhao
2021-07-06 18:47 ` Souptick Joarder
2021-07-06 23:01 ` Dave Hansen
2021-07-07  7:00 ` Hillf Danton

This is an external index of several public inboxes,
see mirroring instructions on how to clone and mirror
all data and code used by this external index.