]> git.karo-electronics.de Git - mv-sheeva.git/blobdiff - arch/x86/kvm/mmu.c
KVM: MMU: fix mmu notifier invalidate handler for huge spte
[mv-sheeva.git] / arch / x86 / kvm / mmu.c
index a6f695d76928675008a99f2030b00299856840c4..812770cddc8d5456f59056996531357cf9592c4f 100644 (file)
@@ -7,6 +7,7 @@
  * MMU support
  *
  * Copyright (C) 2006 Qumranet, Inc.
+ * Copyright 2010 Red Hat, Inc. and/or its affilates.
  *
  * Authors:
  *   Yaniv Kamay  <yaniv@qumranet.com>
@@ -32,6 +33,7 @@
 #include <linux/compiler.h>
 #include <linux/srcu.h>
 #include <linux/slab.h>
+#include <linux/uaccess.h>
 
 #include <asm/page.h>
 #include <asm/cmpxchg.h>
@@ -90,8 +92,6 @@ module_param(oos_shadow, bool, 0644);
 #define PT_FIRST_AVAIL_BITS_SHIFT 9
 #define PT64_SECOND_AVAIL_BITS_SHIFT 52
 
-#define VALID_PAGE(x) ((x) != INVALID_PAGE)
-
 #define PT64_LEVEL_BITS 9
 
 #define PT64_LEVEL_SHIFT(level) \
@@ -173,7 +173,7 @@ struct kvm_shadow_walk_iterator {
             shadow_walk_okay(&(_walker));                      \
             shadow_walk_next(&(_walker)))
 
-typedef int (*mmu_parent_walk_fn) (struct kvm_mmu_page *sp);
+typedef void (*mmu_parent_walk_fn) (struct kvm_mmu_page *sp, u64 *spte);
 
 static struct kmem_cache *pte_chain_cache;
 static struct kmem_cache *rmap_desc_cache;
@@ -288,6 +288,34 @@ static void __set_spte(u64 *sptep, u64 spte)
 #endif
 }
 
+static u64 __xchg_spte(u64 *sptep, u64 new_spte)
+{
+#ifdef CONFIG_X86_64
+       return xchg(sptep, new_spte);
+#else
+       u64 old_spte;
+
+       do {
+               old_spte = *sptep;
+       } while (cmpxchg64(sptep, old_spte, new_spte) != old_spte);
+
+       return old_spte;
+#endif
+}
+
+static void update_spte(u64 *sptep, u64 new_spte)
+{
+       u64 old_spte;
+
+       if (!shadow_accessed_mask || (new_spte & shadow_accessed_mask)) {
+               __set_spte(sptep, new_spte);
+       } else {
+               old_spte = __xchg_spte(sptep, new_spte);
+               if (old_spte & shadow_accessed_mask)
+                       mark_page_accessed(pfn_to_page(spte_to_pfn(old_spte)));
+       }
+}
+
 static int mmu_topup_memory_cache(struct kvm_mmu_memory_cache *cache,
                                  struct kmem_cache *base_cache, int min)
 {
@@ -304,10 +332,11 @@ static int mmu_topup_memory_cache(struct kvm_mmu_memory_cache *cache,
        return 0;
 }
 
-static void mmu_free_memory_cache(struct kvm_mmu_memory_cache *mc)
+static void mmu_free_memory_cache(struct kvm_mmu_memory_cache *mc,
+                                 struct kmem_cache *cache)
 {
        while (mc->nobjs)
-               kfree(mc->objects[--mc->nobjs]);
+               kmem_cache_free(cache, mc->objects[--mc->nobjs]);
 }
 
 static int mmu_topup_memory_cache_page(struct kvm_mmu_memory_cache *cache,
@@ -355,10 +384,11 @@ out:
 
 static void mmu_free_memory_caches(struct kvm_vcpu *vcpu)
 {
-       mmu_free_memory_cache(&vcpu->arch.mmu_pte_chain_cache);
-       mmu_free_memory_cache(&vcpu->arch.mmu_rmap_desc_cache);
+       mmu_free_memory_cache(&vcpu->arch.mmu_pte_chain_cache, pte_chain_cache);
+       mmu_free_memory_cache(&vcpu->arch.mmu_rmap_desc_cache, rmap_desc_cache);
        mmu_free_memory_cache_page(&vcpu->arch.mmu_page_cache);
-       mmu_free_memory_cache(&vcpu->arch.mmu_page_header_cache);
+       mmu_free_memory_cache(&vcpu->arch.mmu_page_header_cache,
+                               mmu_page_header_cache);
 }
 
 static void *mmu_memory_cache_alloc(struct kvm_mmu_memory_cache *mc,
@@ -379,7 +409,7 @@ static struct kvm_pte_chain *mmu_alloc_pte_chain(struct kvm_vcpu *vcpu)
 
 static void mmu_free_pte_chain(struct kvm_pte_chain *pc)
 {
-       kfree(pc);
+       kmem_cache_free(pte_chain_cache, pc);
 }
 
 static struct kvm_rmap_desc *mmu_alloc_rmap_desc(struct kvm_vcpu *vcpu)
@@ -390,7 +420,23 @@ static struct kvm_rmap_desc *mmu_alloc_rmap_desc(struct kvm_vcpu *vcpu)
 
 static void mmu_free_rmap_desc(struct kvm_rmap_desc *rd)
 {
-       kfree(rd);
+       kmem_cache_free(rmap_desc_cache, rd);
+}
+
+static gfn_t kvm_mmu_page_get_gfn(struct kvm_mmu_page *sp, int index)
+{
+       if (!sp->role.direct)
+               return sp->gfns[index];
+
+       return sp->gfn + (index << ((sp->role.level - 1) * PT64_LEVEL_BITS));
+}
+
+static void kvm_mmu_page_set_gfn(struct kvm_mmu_page *sp, int index, gfn_t gfn)
+{
+       if (sp->role.direct)
+               BUG_ON(gfn != kvm_mmu_page_get_gfn(sp, index));
+       else
+               sp->gfns[index] = gfn;
 }
 
 /*
@@ -403,8 +449,8 @@ static int *slot_largepage_idx(gfn_t gfn,
 {
        unsigned long idx;
 
-       idx = (gfn / KVM_PAGES_PER_HPAGE(level)) -
-             (slot->base_gfn / KVM_PAGES_PER_HPAGE(level));
+       idx = (gfn >> KVM_HPAGE_GFN_SHIFT(level)) -
+             (slot->base_gfn >> KVM_HPAGE_GFN_SHIFT(level));
        return &slot->lpage_info[level - 2][idx].write_count;
 }
 
@@ -414,9 +460,7 @@ static void account_shadowed(struct kvm *kvm, gfn_t gfn)
        int *write_count;
        int i;
 
-       gfn = unalias_gfn(kvm, gfn);
-
-       slot = gfn_to_memslot_unaliased(kvm, gfn);
+       slot = gfn_to_memslot(kvm, gfn);
        for (i = PT_DIRECTORY_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
                write_count   = slot_largepage_idx(gfn, slot, i);
@@ -430,8 +474,7 @@ static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
        int *write_count;
        int i;
 
-       gfn = unalias_gfn(kvm, gfn);
-       slot = gfn_to_memslot_unaliased(kvm, gfn);
+       slot = gfn_to_memslot(kvm, gfn);
        for (i = PT_DIRECTORY_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
                write_count   = slot_largepage_idx(gfn, slot, i);
@@ -447,8 +490,7 @@ static int has_wrprotected_page(struct kvm *kvm,
        struct kvm_memory_slot *slot;
        int *largepage_idx;
 
-       gfn = unalias_gfn(kvm, gfn);
-       slot = gfn_to_memslot_unaliased(kvm, gfn);
+       slot = gfn_to_memslot(kvm, gfn);
        if (slot) {
                largepage_idx = slot_largepage_idx(gfn, slot, level);
                return *largepage_idx;
@@ -501,7 +543,6 @@ static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
 
 /*
  * Take gfn and return the reverse mapping to it.
- * Note: gfn must be unaliased before this function get called
  */
 
 static unsigned long *gfn_to_rmap(struct kvm *kvm, gfn_t gfn, int level)
@@ -513,8 +554,8 @@ static unsigned long *gfn_to_rmap(struct kvm *kvm, gfn_t gfn, int level)
        if (likely(level == PT_PAGE_TABLE_LEVEL))
                return &slot->rmap[gfn - slot->base_gfn];
 
-       idx = (gfn / KVM_PAGES_PER_HPAGE(level)) -
-               (slot->base_gfn / KVM_PAGES_PER_HPAGE(level));
+       idx = (gfn >> KVM_HPAGE_GFN_SHIFT(level)) -
+               (slot->base_gfn >> KVM_HPAGE_GFN_SHIFT(level));
 
        return &slot->lpage_info[level - 2][idx].rmap_pde;
 }
@@ -541,9 +582,8 @@ static int rmap_add(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
        if (!is_rmap_spte(*spte))
                return count;
-       gfn = unalias_gfn(vcpu->kvm, gfn);
        sp = page_header(__pa(spte));
-       sp->gfns[spte - sp->spt] = gfn;
+       kvm_mmu_page_set_gfn(sp, spte - sp->spt, gfn);
        rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp->role.level);
        if (!*rmapp) {
                rmap_printk("rmap_add: %p %llx 0->1\n", spte, *spte);
@@ -600,19 +640,13 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
        struct kvm_rmap_desc *desc;
        struct kvm_rmap_desc *prev_desc;
        struct kvm_mmu_page *sp;
-       pfn_t pfn;
+       gfn_t gfn;
        unsigned long *rmapp;
        int i;
 
-       if (!is_rmap_spte(*spte))
-               return;
        sp = page_header(__pa(spte));
-       pfn = spte_to_pfn(*spte);
-       if (*spte & shadow_accessed_mask)
-               kvm_set_pfn_accessed(pfn);
-       if (is_writable_pte(*spte))
-               kvm_set_pfn_dirty(pfn);
-       rmapp = gfn_to_rmap(kvm, sp->gfns[spte - sp->spt], sp->role.level);
+       gfn = kvm_mmu_page_get_gfn(sp, spte - sp->spt);
+       rmapp = gfn_to_rmap(kvm, gfn, sp->role.level);
        if (!*rmapp) {
                printk(KERN_ERR "rmap_remove: %p %llx 0->BUG\n", spte, *spte);
                BUG();
@@ -644,6 +678,22 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
        }
 }
 
+static void drop_spte(struct kvm *kvm, u64 *sptep, u64 new_spte)
+{
+       pfn_t pfn;
+       u64 old_spte;
+
+       old_spte = __xchg_spte(sptep, new_spte);
+       if (!is_rmap_spte(old_spte))
+               return;
+       pfn = spte_to_pfn(old_spte);
+       if (old_spte & shadow_accessed_mask)
+               kvm_set_pfn_accessed(pfn);
+       if (is_writable_pte(old_spte))
+               kvm_set_pfn_dirty(pfn);
+       rmap_remove(kvm, sptep);
+}
+
 static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
 {
        struct kvm_rmap_desc *desc;
@@ -676,7 +726,6 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
        u64 *spte;
        int i, write_protected = 0;
 
-       gfn = unalias_gfn(kvm, gfn);
        rmapp = gfn_to_rmap(kvm, gfn, PT_PAGE_TABLE_LEVEL);
 
        spte = rmap_next(kvm, rmapp, NULL);
@@ -685,7 +734,7 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
                BUG_ON(!(*spte & PT_PRESENT_MASK));
                rmap_printk("rmap_write_protect: spte %p %llx\n", spte, *spte);
                if (is_writable_pte(*spte)) {
-                       __set_spte(spte, *spte & ~PT_WRITABLE_MASK);
+                       update_spte(spte, *spte & ~PT_WRITABLE_MASK);
                        write_protected = 1;
                }
                spte = rmap_next(kvm, rmapp, spte);
@@ -709,9 +758,9 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
                        BUG_ON((*spte & (PT_PAGE_SIZE_MASK|PT_PRESENT_MASK)) != (PT_PAGE_SIZE_MASK|PT_PRESENT_MASK));
                        pgprintk("rmap_write_protect(large): spte %p %llx %lld\n", spte, *spte, gfn);
                        if (is_writable_pte(*spte)) {
-                               rmap_remove(kvm, spte);
+                               drop_spte(kvm, spte,
+                                         shadow_trap_nonpresent_pte);
                                --kvm->stat.lpages;
-                               __set_spte(spte, shadow_trap_nonpresent_pte);
                                spte = NULL;
                                write_protected = 1;
                        }
@@ -731,8 +780,7 @@ static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
        while ((spte = rmap_next(kvm, rmapp, NULL))) {
                BUG_ON(!(*spte & PT_PRESENT_MASK));
                rmap_printk("kvm_rmap_unmap_hva: spte %p %llx\n", spte, *spte);
-               rmap_remove(kvm, spte);
-               __set_spte(spte, shadow_trap_nonpresent_pte);
+               drop_spte(kvm, spte, shadow_trap_nonpresent_pte);
                need_tlb_flush = 1;
        }
        return need_tlb_flush;
@@ -742,7 +790,7 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
                             unsigned long data)
 {
        int need_flush = 0;
-       u64 *spte, new_spte;
+       u64 *spte, new_spte, old_spte;
        pte_t *ptep = (pte_t *)data;
        pfn_t new_pfn;
 
@@ -754,8 +802,7 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
                rmap_printk("kvm_set_pte_rmapp: spte %p %llx\n", spte, *spte);
                need_flush = 1;
                if (pte_write(*ptep)) {
-                       rmap_remove(kvm, spte);
-                       __set_spte(spte, shadow_trap_nonpresent_pte);
+                       drop_spte(kvm, spte, shadow_trap_nonpresent_pte);
                        spte = rmap_next(kvm, rmapp, NULL);
                } else {
                        new_spte = *spte &~ (PT64_BASE_ADDR_MASK);
@@ -763,9 +810,13 @@ static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
 
                        new_spte &= ~PT_WRITABLE_MASK;
                        new_spte &= ~SPTE_HOST_WRITEABLE;
+                       new_spte &= ~shadow_accessed_mask;
                        if (is_writable_pte(*spte))
                                kvm_set_pfn_dirty(spte_to_pfn(*spte));
-                       __set_spte(spte, new_spte);
+                       old_spte = __xchg_spte(spte, new_spte);
+                       if (is_shadow_present_pte(old_spte)
+                           && (old_spte & shadow_accessed_mask))
+                               mark_page_accessed(pfn_to_page(spte_to_pfn(old_spte)));
                        spte = rmap_next(kvm, rmapp, spte);
                }
        }
@@ -799,8 +850,12 @@ static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
                        ret = handler(kvm, &memslot->rmap[gfn_offset], data);
 
                        for (j = 0; j < KVM_NR_PAGE_SIZES - 1; ++j) {
-                               int idx = gfn_offset;
-                               idx /= KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL + j);
+                               unsigned long idx;
+                               int sh;
+
+                               sh = KVM_HPAGE_GFN_SHIFT(PT_DIRECTORY_LEVEL+j);
+                               idx = ((memslot->base_gfn+gfn_offset) >> sh) -
+                                       (memslot->base_gfn >> sh);
                                ret |= handler(kvm,
                                        &memslot->lpage_info[j][idx].rmap_pde,
                                        data);
@@ -863,7 +918,6 @@ static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 
        sp = page_header(__pa(spte));
 
-       gfn = unalias_gfn(vcpu->kvm, gfn);
        rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp->role.level);
 
        kvm_unmap_rmapp(vcpu->kvm, rmapp, 0);
@@ -894,10 +948,12 @@ static int is_empty_shadow_page(u64 *spt)
 static void kvm_mmu_free_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        ASSERT(is_empty_shadow_page(sp->spt));
+       hlist_del(&sp->hash_link);
        list_del(&sp->link);
        __free_page(virt_to_page(sp->spt));
-       __free_page(virt_to_page(sp->gfns));
-       kfree(sp);
+       if (!sp->role.direct)
+               __free_page(virt_to_page(sp->gfns));
+       kmem_cache_free(mmu_page_header_cache, sp);
        ++kvm->arch.n_free_mmu_pages;
 }
 
@@ -907,13 +963,15 @@ static unsigned kvm_page_table_hashfn(gfn_t gfn)
 }
 
 static struct kvm_mmu_page *kvm_mmu_alloc_page(struct kvm_vcpu *vcpu,
-                                              u64 *parent_pte)
+                                              u64 *parent_pte, int direct)
 {
        struct kvm_mmu_page *sp;
 
        sp = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_header_cache, sizeof *sp);
        sp->spt = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_cache, PAGE_SIZE);
-       sp->gfns = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_cache, PAGE_SIZE);
+       if (!direct)
+               sp->gfns = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_cache,
+                                                 PAGE_SIZE);
        set_page_private(virt_to_page(sp->spt), (unsigned long)sp);
        list_add(&sp->link, &vcpu->kvm->arch.active_mmu_pages);
        bitmap_zero(sp->slot_bitmap, KVM_MEMORY_SLOTS + KVM_PRIVATE_MEM_SLOTS);
@@ -998,7 +1056,6 @@ static void mmu_page_remove_parent_pte(struct kvm_mmu_page *sp,
        BUG();
 }
 
-
 static void mmu_parent_walk(struct kvm_mmu_page *sp, mmu_parent_walk_fn fn)
 {
        struct kvm_pte_chain *pte_chain;
@@ -1008,63 +1065,37 @@ static void mmu_parent_walk(struct kvm_mmu_page *sp, mmu_parent_walk_fn fn)
 
        if (!sp->multimapped && sp->parent_pte) {
                parent_sp = page_header(__pa(sp->parent_pte));
-               fn(parent_sp);
-               mmu_parent_walk(parent_sp, fn);
+               fn(parent_sp, sp->parent_pte);
                return;
        }
+
        hlist_for_each_entry(pte_chain, node, &sp->parent_ptes, link)
                for (i = 0; i < NR_PTE_CHAIN_ENTRIES; ++i) {
-                       if (!pte_chain->parent_ptes[i])
+                       u64 *spte = pte_chain->parent_ptes[i];
+
+                       if (!spte)
                                break;
-                       parent_sp = page_header(__pa(pte_chain->parent_ptes[i]));
-                       fn(parent_sp);
-                       mmu_parent_walk(parent_sp, fn);
+                       parent_sp = page_header(__pa(spte));
+                       fn(parent_sp, spte);
                }
 }
 
-static void kvm_mmu_update_unsync_bitmap(u64 *spte)
+static void mark_unsync(struct kvm_mmu_page *sp, u64 *spte);
+static void kvm_mmu_mark_parents_unsync(struct kvm_mmu_page *sp)
 {
-       unsigned int index;
-       struct kvm_mmu_page *sp = page_header(__pa(spte));
-
-       index = spte - sp->spt;
-       if (!__test_and_set_bit(index, sp->unsync_child_bitmap))
-               sp->unsync_children++;
-       WARN_ON(!sp->unsync_children);
+       mmu_parent_walk(sp, mark_unsync);
 }
 
-static void kvm_mmu_update_parents_unsync(struct kvm_mmu_page *sp)
+static void mark_unsync(struct kvm_mmu_page *sp, u64 *spte)
 {
-       struct kvm_pte_chain *pte_chain;
-       struct hlist_node *node;
-       int i;
+       unsigned int index;
 
-       if (!sp->parent_pte)
+       index = spte - sp->spt;
+       if (__test_and_set_bit(index, sp->unsync_child_bitmap))
                return;
-
-       if (!sp->multimapped) {
-               kvm_mmu_update_unsync_bitmap(sp->parent_pte);
+       if (sp->unsync_children++)
                return;
-       }
-
-       hlist_for_each_entry(pte_chain, node, &sp->parent_ptes, link)
-               for (i = 0; i < NR_PTE_CHAIN_ENTRIES; ++i) {
-                       if (!pte_chain->parent_ptes[i])
-                               break;
-                       kvm_mmu_update_unsync_bitmap(pte_chain->parent_ptes[i]);
-               }
-}
-
-static int unsync_walk_fn(struct kvm_mmu_page *sp)
-{
-       kvm_mmu_update_parents_unsync(sp);
-       return 1;
-}
-
-static void kvm_mmu_mark_parents_unsync(struct kvm_mmu_page *sp)
-{
-       mmu_parent_walk(sp, unsync_walk_fn);
-       kvm_mmu_update_parents_unsync(sp);
+       kvm_mmu_mark_parents_unsync(sp);
 }
 
 static void nonpaging_prefetch_page(struct kvm_vcpu *vcpu,
@@ -1077,7 +1108,7 @@ static void nonpaging_prefetch_page(struct kvm_vcpu *vcpu,
 }
 
 static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
-                              struct kvm_mmu_page *sp)
+                              struct kvm_mmu_page *sp, bool clear_unsync)
 {
        return 1;
 }
@@ -1123,35 +1154,40 @@ static int __mmu_unsync_walk(struct kvm_mmu_page *sp,
        int i, ret, nr_unsync_leaf = 0;
 
        for_each_unsync_children(sp->unsync_child_bitmap, i) {
+               struct kvm_mmu_page *child;
                u64 ent = sp->spt[i];
 
-               if (is_shadow_present_pte(ent) && !is_large_pte(ent)) {
-                       struct kvm_mmu_page *child;
-                       child = page_header(ent & PT64_BASE_ADDR_MASK);
-
-                       if (child->unsync_children) {
-                               if (mmu_pages_add(pvec, child, i))
-                                       return -ENOSPC;
-
-                               ret = __mmu_unsync_walk(child, pvec);
-                               if (!ret)
-                                       __clear_bit(i, sp->unsync_child_bitmap);
-                               else if (ret > 0)
-                                       nr_unsync_leaf += ret;
-                               else
-                                       return ret;
-                       }
+               if (!is_shadow_present_pte(ent) || is_large_pte(ent))
+                       goto clear_child_bitmap;
+
+               child = page_header(ent & PT64_BASE_ADDR_MASK);
+
+               if (child->unsync_children) {
+                       if (mmu_pages_add(pvec, child, i))
+                               return -ENOSPC;
+
+                       ret = __mmu_unsync_walk(child, pvec);
+                       if (!ret)
+                               goto clear_child_bitmap;
+                       else if (ret > 0)
+                               nr_unsync_leaf += ret;
+                       else
+                               return ret;
+               } else if (child->unsync) {
+                       nr_unsync_leaf++;
+                       if (mmu_pages_add(pvec, child, i))
+                               return -ENOSPC;
+               } else
+                        goto clear_child_bitmap;
 
-                       if (child->unsync) {
-                               nr_unsync_leaf++;
-                               if (mmu_pages_add(pvec, child, i))
-                                       return -ENOSPC;
-                       }
-               }
+               continue;
+
+clear_child_bitmap:
+               __clear_bit(i, sp->unsync_child_bitmap);
+               sp->unsync_children--;
+               WARN_ON((int)sp->unsync_children < 0);
        }
 
-       if (find_first_bit(sp->unsync_child_bitmap, 512) == 512)
-               sp->unsync_children = 0;
 
        return nr_unsync_leaf;
 }
@@ -1166,26 +1202,6 @@ static int mmu_unsync_walk(struct kvm_mmu_page *sp,
        return __mmu_unsync_walk(sp, pvec);
 }
 
-static struct kvm_mmu_page *kvm_mmu_lookup_page(struct kvm *kvm, gfn_t gfn)
-{
-       unsigned index;
-       struct hlist_head *bucket;
-       struct kvm_mmu_page *sp;
-       struct hlist_node *node;
-
-       pgprintk("%s: looking for gfn %lx\n", __func__, gfn);
-       index = kvm_page_table_hashfn(gfn);
-       bucket = &kvm->arch.mmu_page_hash[index];
-       hlist_for_each_entry(sp, node, bucket, hash_link)
-               if (sp->gfn == gfn && !sp->role.direct
-                   && !sp->role.invalid) {
-                       pgprintk("%s: found role %x\n",
-                                __func__, sp->role.word);
-                       return sp;
-               }
-       return NULL;
-}
-
 static void kvm_unlink_unsync_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        WARN_ON(!sp->unsync);
@@ -1194,20 +1210,36 @@ static void kvm_unlink_unsync_page(struct kvm *kvm, struct kvm_mmu_page *sp)
        --kvm->stat.mmu_unsync;
 }
 
-static int kvm_mmu_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp);
+static int kvm_mmu_prepare_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp,
+                                   struct list_head *invalid_list);
+static void kvm_mmu_commit_zap_page(struct kvm *kvm,
+                                   struct list_head *invalid_list);
+
+#define for_each_gfn_sp(kvm, sp, gfn, pos)                             \
+  hlist_for_each_entry(sp, pos,                                                \
+   &(kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)], hash_link)  \
+       if ((sp)->gfn != (gfn)) {} else
+
+#define for_each_gfn_indirect_valid_sp(kvm, sp, gfn, pos)              \
+  hlist_for_each_entry(sp, pos,                                                \
+   &(kvm)->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)], hash_link)  \
+               if ((sp)->gfn != (gfn) || (sp)->role.direct ||          \
+                       (sp)->role.invalid) {} else
 
-static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+/* @sp->gfn should be write-protected at the call site */
+static int __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+                          struct list_head *invalid_list, bool clear_unsync)
 {
        if (sp->role.cr4_pae != !!is_pae(vcpu)) {
-               kvm_mmu_zap_page(vcpu->kvm, sp);
+               kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
                return 1;
        }
 
-       if (rmap_write_protect(vcpu->kvm, sp->gfn))
-               kvm_flush_remote_tlbs(vcpu->kvm);
-       kvm_unlink_unsync_page(vcpu->kvm, sp);
-       if (vcpu->arch.mmu.sync_page(vcpu, sp)) {
-               kvm_mmu_zap_page(vcpu->kvm, sp);
+       if (clear_unsync)
+               kvm_unlink_unsync_page(vcpu->kvm, sp);
+
+       if (vcpu->arch.mmu.sync_page(vcpu, sp, clear_unsync)) {
+               kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
                return 1;
        }
 
@@ -1215,6 +1247,52 @@ static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        return 0;
 }
 
+static int kvm_sync_page_transient(struct kvm_vcpu *vcpu,
+                                  struct kvm_mmu_page *sp)
+{
+       LIST_HEAD(invalid_list);
+       int ret;
+
+       ret = __kvm_sync_page(vcpu, sp, &invalid_list, false);
+       if (ret)
+               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+
+       return ret;
+}
+
+static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+                        struct list_head *invalid_list)
+{
+       return __kvm_sync_page(vcpu, sp, invalid_list, true);
+}
+
+/* @gfn should be write-protected at the call site */
+static void kvm_sync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+{
+       struct kvm_mmu_page *s;
+       struct hlist_node *node;
+       LIST_HEAD(invalid_list);
+       bool flush = false;
+
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn, node) {
+               if (!s->unsync)
+                       continue;
+
+               WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
+               if ((s->role.cr4_pae != !!is_pae(vcpu)) ||
+                       (vcpu->arch.mmu.sync_page(vcpu, s, true))) {
+                       kvm_mmu_prepare_zap_page(vcpu->kvm, s, &invalid_list);
+                       continue;
+               }
+               kvm_unlink_unsync_page(vcpu->kvm, s);
+               flush = true;
+       }
+
+       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       if (flush)
+               kvm_mmu_flush_tlb(vcpu);
+}
+
 struct mmu_page_path {
        struct kvm_mmu_page *parent[PT64_ROOT_LEVEL-1];
        unsigned int idx[PT64_ROOT_LEVEL-1];
@@ -1281,6 +1359,7 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
        struct kvm_mmu_page *sp;
        struct mmu_page_path parents;
        struct kvm_mmu_pages pages;
+       LIST_HEAD(invalid_list);
 
        kvm_mmu_pages_init(parent, &parents, &pages);
        while (mmu_unsync_walk(parent, &pages)) {
@@ -1293,9 +1372,10 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
                        kvm_flush_remote_tlbs(vcpu->kvm);
 
                for_each_sp(pages, sp, parents, i) {
-                       kvm_sync_page(vcpu, sp);
+                       kvm_sync_page(vcpu, sp, &invalid_list);
                        mmu_pages_clear_parents(&parents);
                }
+               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
                cond_resched_lock(&vcpu->kvm->mmu_lock);
                kvm_mmu_pages_init(parent, &parents, &pages);
        }
@@ -1310,11 +1390,10 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                                             u64 *parent_pte)
 {
        union kvm_mmu_page_role role;
-       unsigned index;
        unsigned quadrant;
-       struct hlist_head *bucket;
        struct kvm_mmu_page *sp;
-       struct hlist_node *node, *tmp;
+       struct hlist_node *node;
+       bool need_sync = false;
 
        role = vcpu->arch.mmu.base_role;
        role.level = level;
@@ -1322,40 +1401,45 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        if (role.direct)
                role.cr4_pae = 0;
        role.access = access;
-       if (vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
+       if (!tdp_enabled && vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
                quadrant &= (1 << ((PT32_PT_BITS - PT64_PT_BITS) * level)) - 1;
                role.quadrant = quadrant;
        }
-       index = kvm_page_table_hashfn(gfn);
-       bucket = &vcpu->kvm->arch.mmu_page_hash[index];
-       hlist_for_each_entry_safe(sp, node, tmp, bucket, hash_link)
-               if (sp->gfn == gfn) {
-                       if (sp->unsync)
-                               if (kvm_sync_page(vcpu, sp))
-                                       continue;
+       for_each_gfn_sp(vcpu->kvm, sp, gfn, node) {
+               if (!need_sync && sp->unsync)
+                       need_sync = true;
 
-                       if (sp->role.word != role.word)
-                               continue;
+               if (sp->role.word != role.word)
+                       continue;
 
-                       mmu_page_add_parent_pte(vcpu, sp, parent_pte);
-                       if (sp->unsync_children) {
-                               set_bit(KVM_REQ_MMU_SYNC, &vcpu->requests);
-                               kvm_mmu_mark_parents_unsync(sp);
-                       }
-                       trace_kvm_mmu_get_page(sp, false);
-                       return sp;
-               }
+               if (sp->unsync && kvm_sync_page_transient(vcpu, sp))
+                       break;
+
+               mmu_page_add_parent_pte(vcpu, sp, parent_pte);
+               if (sp->unsync_children) {
+                       kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
+                       kvm_mmu_mark_parents_unsync(sp);
+               } else if (sp->unsync)
+                       kvm_mmu_mark_parents_unsync(sp);
+
+               trace_kvm_mmu_get_page(sp, false);
+               return sp;
+       }
        ++vcpu->kvm->stat.mmu_cache_miss;
-       sp = kvm_mmu_alloc_page(vcpu, parent_pte);
+       sp = kvm_mmu_alloc_page(vcpu, parent_pte, direct);
        if (!sp)
                return sp;
        sp->gfn = gfn;
        sp->role = role;
-       hlist_add_head(&sp->hash_link, bucket);
+       hlist_add_head(&sp->hash_link,
+               &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)]);
        if (!direct) {
                if (rmap_write_protect(vcpu->kvm, gfn))
                        kvm_flush_remote_tlbs(vcpu->kvm);
+               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
+                       kvm_sync_pages(vcpu, gfn);
+
                account_shadowed(vcpu->kvm, gfn);
        }
        if (shadow_trap_nonpresent_pte != shadow_notrap_nonpresent_pte)
@@ -1402,6 +1486,47 @@ static void shadow_walk_next(struct kvm_shadow_walk_iterator *iterator)
        --iterator->level;
 }
 
+static void link_shadow_page(u64 *sptep, struct kvm_mmu_page *sp)
+{
+       u64 spte;
+
+       spte = __pa(sp->spt)
+               | PT_PRESENT_MASK | PT_ACCESSED_MASK
+               | PT_WRITABLE_MASK | PT_USER_MASK;
+       __set_spte(sptep, spte);
+}
+
+static void drop_large_spte(struct kvm_vcpu *vcpu, u64 *sptep)
+{
+       if (is_large_pte(*sptep)) {
+               drop_spte(vcpu->kvm, sptep, shadow_trap_nonpresent_pte);
+               kvm_flush_remote_tlbs(vcpu->kvm);
+       }
+}
+
+static void validate_direct_spte(struct kvm_vcpu *vcpu, u64 *sptep,
+                                  unsigned direct_access)
+{
+       if (is_shadow_present_pte(*sptep) && !is_large_pte(*sptep)) {
+               struct kvm_mmu_page *child;
+
+               /*
+                * For the direct sp, if the guest pte's dirty bit
+                * changed form clean to dirty, it will corrupt the
+                * sp's access: allow writable in the read-only sp,
+                * so we should update the spte at this point to get
+                * a new sp with the correct access.
+                */
+               child = page_header(*sptep & PT64_BASE_ADDR_MASK);
+               if (child->role.access == direct_access)
+                       return;
+
+               mmu_page_remove_parent_pte(child, sptep);
+               __set_spte(sptep, shadow_trap_nonpresent_pte);
+               kvm_flush_remote_tlbs(vcpu->kvm);
+       }
+}
+
 static void kvm_mmu_page_unlink_children(struct kvm *kvm,
                                         struct kvm_mmu_page *sp)
 {
@@ -1422,7 +1547,8 @@ static void kvm_mmu_page_unlink_children(struct kvm *kvm,
                        } else {
                                if (is_large_pte(ent))
                                        --kvm->stat.lpages;
-                               rmap_remove(kvm, &pt[i]);
+                               drop_spte(kvm, &pt[i],
+                                         shadow_trap_nonpresent_pte);
                        }
                }
                pt[i] = shadow_trap_nonpresent_pte;
@@ -1464,7 +1590,8 @@ static void kvm_mmu_unlink_parents(struct kvm *kvm, struct kvm_mmu_page *sp)
 }
 
 static int mmu_zap_unsync_children(struct kvm *kvm,
-                                  struct kvm_mmu_page *parent)
+                                  struct kvm_mmu_page *parent,
+                                  struct list_head *invalid_list)
 {
        int i, zapped = 0;
        struct mmu_page_path parents;
@@ -1478,7 +1605,7 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
                struct kvm_mmu_page *sp;
 
                for_each_sp(pages, sp, parents, i) {
-                       kvm_mmu_zap_page(kvm, sp);
+                       kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
                        mmu_pages_clear_parents(&parents);
                        zapped++;
                }
@@ -1488,32 +1615,52 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
        return zapped;
 }
 
-static int kvm_mmu_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp)
+static int kvm_mmu_prepare_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp,
+                                   struct list_head *invalid_list)
 {
        int ret;
 
-       trace_kvm_mmu_zap_page(sp);
+       trace_kvm_mmu_prepare_zap_page(sp);
        ++kvm->stat.mmu_shadow_zapped;
-       ret = mmu_zap_unsync_children(kvm, sp);
+       ret = mmu_zap_unsync_children(kvm, sp, invalid_list);
        kvm_mmu_page_unlink_children(kvm, sp);
        kvm_mmu_unlink_parents(kvm, sp);
-       kvm_flush_remote_tlbs(kvm);
        if (!sp->role.invalid && !sp->role.direct)
                unaccount_shadowed(kvm, sp->gfn);
        if (sp->unsync)
                kvm_unlink_unsync_page(kvm, sp);
        if (!sp->root_count) {
-               hlist_del(&sp->hash_link);
-               kvm_mmu_free_page(kvm, sp);
+               /* Count self */
+               ret++;
+               list_move(&sp->link, invalid_list);
        } else {
-               sp->role.invalid = 1;
                list_move(&sp->link, &kvm->arch.active_mmu_pages);
                kvm_reload_remote_mmus(kvm);
        }
+
+       sp->role.invalid = 1;
        kvm_mmu_reset_last_pte_updated(kvm);
        return ret;
 }
 
+static void kvm_mmu_commit_zap_page(struct kvm *kvm,
+                                   struct list_head *invalid_list)
+{
+       struct kvm_mmu_page *sp;
+
+       if (list_empty(invalid_list))
+               return;
+
+       kvm_flush_remote_tlbs(kvm);
+
+       do {
+               sp = list_first_entry(invalid_list, struct kvm_mmu_page, link);
+               WARN_ON(!sp->role.invalid || sp->root_count);
+               kvm_mmu_free_page(kvm, sp);
+       } while (!list_empty(invalid_list));
+
+}
+
 /*
  * Changing the number of mmu pages allocated to the vm
  * Note: if kvm_nr_mmu_pages is too small, you will get dead lock
@@ -1521,6 +1668,7 @@ static int kvm_mmu_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
 {
        int used_pages;
+       LIST_HEAD(invalid_list);
 
        used_pages = kvm->arch.n_alloc_mmu_pages - kvm->arch.n_free_mmu_pages;
        used_pages = max(0, used_pages);
@@ -1538,9 +1686,10 @@ void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
 
                        page = container_of(kvm->arch.active_mmu_pages.prev,
                                            struct kvm_mmu_page, link);
-                       used_pages -= kvm_mmu_zap_page(kvm, page);
-                       used_pages--;
+                       used_pages -= kvm_mmu_prepare_zap_page(kvm, page,
+                                                              &invalid_list);
                }
+               kvm_mmu_commit_zap_page(kvm, &invalid_list);
                kvm_nr_mmu_pages = used_pages;
                kvm->arch.n_free_mmu_pages = 0;
        }
@@ -1553,47 +1702,36 @@ void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
 
 static int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
 {
-       unsigned index;
-       struct hlist_head *bucket;
        struct kvm_mmu_page *sp;
-       struct hlist_node *node, *n;
+       struct hlist_node *node;
+       LIST_HEAD(invalid_list);
        int r;
 
        pgprintk("%s: looking for gfn %lx\n", __func__, gfn);
        r = 0;
-       index = kvm_page_table_hashfn(gfn);
-       bucket = &kvm->arch.mmu_page_hash[index];
-restart:
-       hlist_for_each_entry_safe(sp, node, n, bucket, hash_link)
-               if (sp->gfn == gfn && !sp->role.direct) {
-                       pgprintk("%s: gfn %lx role %x\n", __func__, gfn,
-                                sp->role.word);
-                       r = 1;
-                       if (kvm_mmu_zap_page(kvm, sp))
-                               goto restart;
-               }
+
+       for_each_gfn_indirect_valid_sp(kvm, sp, gfn, node) {
+               pgprintk("%s: gfn %lx role %x\n", __func__, gfn,
+                        sp->role.word);
+               r = 1;
+               kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
+       }
+       kvm_mmu_commit_zap_page(kvm, &invalid_list);
        return r;
 }
 
 static void mmu_unshadow(struct kvm *kvm, gfn_t gfn)
 {
-       unsigned index;
-       struct hlist_head *bucket;
        struct kvm_mmu_page *sp;
-       struct hlist_node *node, *nn;
+       struct hlist_node *node;
+       LIST_HEAD(invalid_list);
 
-       index = kvm_page_table_hashfn(gfn);
-       bucket = &kvm->arch.mmu_page_hash[index];
-restart:
-       hlist_for_each_entry_safe(sp, node, nn, bucket, hash_link) {
-               if (sp->gfn == gfn && !sp->role.direct
-                   && !sp->role.invalid) {
-                       pgprintk("%s: zap %lx %x\n",
-                                __func__, gfn, sp->role.word);
-                       if (kvm_mmu_zap_page(kvm, sp))
-                               goto restart;
-               }
+       for_each_gfn_indirect_valid_sp(kvm, sp, gfn, node) {
+               pgprintk("%s: zap %lx %x\n",
+                        __func__, gfn, sp->role.word);
+               kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list);
        }
+       kvm_mmu_commit_zap_page(kvm, &invalid_list);
 }
 
 static void page_header_update_slot(struct kvm *kvm, void *pte, gfn_t gfn)
@@ -1723,47 +1861,51 @@ u8 kvm_get_guest_memory_type(struct kvm_vcpu *vcpu, gfn_t gfn)
 }
 EXPORT_SYMBOL_GPL(kvm_get_guest_memory_type);
 
-static int kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
-       unsigned index;
-       struct hlist_head *bucket;
-       struct kvm_mmu_page *s;
-       struct hlist_node *node, *n;
-
-       index = kvm_page_table_hashfn(sp->gfn);
-       bucket = &vcpu->kvm->arch.mmu_page_hash[index];
-       /* don't unsync if pagetable is shadowed with multiple roles */
-       hlist_for_each_entry_safe(s, node, n, bucket, hash_link) {
-               if (s->gfn != sp->gfn || s->role.direct)
-                       continue;
-               if (s->role.word != sp->role.word)
-                       return 1;
-       }
        trace_kvm_mmu_unsync_page(sp);
        ++vcpu->kvm->stat.mmu_unsync;
        sp->unsync = 1;
 
        kvm_mmu_mark_parents_unsync(sp);
-
        mmu_convert_notrap(sp);
-       return 0;
+}
+
+static void kvm_unsync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+{
+       struct kvm_mmu_page *s;
+       struct hlist_node *node;
+
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn, node) {
+               if (s->unsync)
+                       continue;
+               WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
+               __kvm_unsync_page(vcpu, s);
+       }
 }
 
 static int mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
                                  bool can_unsync)
 {
-       struct kvm_mmu_page *shadow;
+       struct kvm_mmu_page *s;
+       struct hlist_node *node;
+       bool need_unsync = false;
 
-       shadow = kvm_mmu_lookup_page(vcpu->kvm, gfn);
-       if (shadow) {
-               if (shadow->role.level != PT_PAGE_TABLE_LEVEL)
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn, node) {
+               if (!can_unsync)
                        return 1;
-               if (shadow->unsync)
-                       return 0;
-               if (can_unsync && oos_shadow)
-                       return kvm_unsync_page(vcpu, shadow);
-               return 1;
+
+               if (s->role.level != PT_PAGE_TABLE_LEVEL)
+                       return 1;
+
+               if (!need_unsync && !s->unsync) {
+                       if (!oos_shadow)
+                               return 1;
+                       need_unsync = true;
+               }
        }
+       if (need_unsync)
+               kvm_unsync_pages(vcpu, gfn);
        return 0;
 }
 
@@ -1804,13 +1946,14 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        spte |= (u64)pfn << PAGE_SHIFT;
 
        if ((pte_access & ACC_WRITE_MASK)
-           || (write_fault && !is_write_protection(vcpu) && !user_fault)) {
+           || (!tdp_enabled && write_fault && !is_write_protection(vcpu)
+               && !user_fault)) {
 
                if (level > PT_PAGE_TABLE_LEVEL &&
                    has_wrprotected_page(vcpu->kvm, gfn, level)) {
                        ret = 1;
-                       spte = shadow_trap_nonpresent_pte;
-                       goto set_pte;
+                       drop_spte(vcpu->kvm, sptep, shadow_trap_nonpresent_pte);
+                       goto done;
                }
 
                spte |= PT_WRITABLE_MASK;
@@ -1841,7 +1984,8 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                mark_page_dirty(vcpu->kvm, gfn);
 
 set_pte:
-       __set_spte(sptep, spte);
+       update_spte(sptep, spte);
+done:
        return ret;
 }
 
@@ -1878,7 +2022,8 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                } else if (pfn != spte_to_pfn(*sptep)) {
                        pgprintk("hfn old %lx new %lx\n",
                                 spte_to_pfn(*sptep), pfn);
-                       rmap_remove(vcpu->kvm, sptep);
+                       drop_spte(vcpu->kvm, sptep, shadow_trap_nonpresent_pte);
+                       kvm_flush_remote_tlbs(vcpu->kvm);
                } else
                        was_rmapped = 1;
        }
@@ -1888,7 +2033,7 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                      reset_host_protection)) {
                if (write_fault)
                        *ptwrite = 1;
-               kvm_x86_ops->tlb_flush(vcpu);
+               kvm_mmu_flush_tlb(vcpu);
        }
 
        pgprintk("%s: setting spte %llx\n", __func__, *sptep);
@@ -1939,7 +2084,10 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
                }
 
                if (*iterator.sptep == shadow_trap_nonpresent_pte) {
-                       pseudo_gfn = (iterator.addr & PT64_DIR_BASE_ADDR_MASK) >> PAGE_SHIFT;
+                       u64 base_addr = iterator.addr;
+
+                       base_addr &= PT64_LVL_ADDR_MASK(iterator.level);
+                       pseudo_gfn = base_addr >> PAGE_SHIFT;
                        sp = kvm_mmu_get_page(vcpu, pseudo_gfn, iterator.addr,
                                              iterator.level - 1,
                                              1, ACC_ALL, iterator.sptep);
@@ -1958,6 +2106,29 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
        return pt_write;
 }
 
+static void kvm_send_hwpoison_signal(struct kvm *kvm, gfn_t gfn)
+{
+       char buf[1];
+       void __user *hva;
+       int r;
+
+       /* Touch the page, so send SIGBUS */
+       hva = (void __user *)gfn_to_hva(kvm, gfn);
+       r = copy_from_user(buf, hva, 1);
+}
+
+static int kvm_handle_bad_page(struct kvm *kvm, gfn_t gfn, pfn_t pfn)
+{
+       kvm_release_pfn_clean(pfn);
+       if (is_hwpoison_pfn(pfn)) {
+               kvm_send_hwpoison_signal(kvm, gfn);
+               return 0;
+       } else if (is_fault_pfn(pfn))
+               return -EFAULT;
+
+       return 1;
+}
+
 static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
 {
        int r;
@@ -1981,10 +2152,8 @@ static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
        pfn = gfn_to_pfn(vcpu->kvm, gfn);
 
        /* mmio */
-       if (is_error_pfn(pfn)) {
-               kvm_release_pfn_clean(pfn);
-               return 1;
-       }
+       if (is_error_pfn(pfn))
+               return kvm_handle_bad_page(vcpu->kvm, gfn, pfn);
 
        spin_lock(&vcpu->kvm->mmu_lock);
        if (mmu_notifier_retry(vcpu, mmu_seq))
@@ -2007,6 +2176,7 @@ static void mmu_free_roots(struct kvm_vcpu *vcpu)
 {
        int i;
        struct kvm_mmu_page *sp;
+       LIST_HEAD(invalid_list);
 
        if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
                return;
@@ -2016,8 +2186,10 @@ static void mmu_free_roots(struct kvm_vcpu *vcpu)
 
                sp = page_header(root);
                --sp->root_count;
-               if (!sp->root_count && sp->role.invalid)
-                       kvm_mmu_zap_page(vcpu->kvm, sp);
+               if (!sp->root_count && sp->role.invalid) {
+                       kvm_mmu_prepare_zap_page(vcpu->kvm, sp, &invalid_list);
+                       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+               }
                vcpu->arch.mmu.root_hpa = INVALID_PAGE;
                spin_unlock(&vcpu->kvm->mmu_lock);
                return;
@@ -2030,10 +2202,12 @@ static void mmu_free_roots(struct kvm_vcpu *vcpu)
                        sp = page_header(root);
                        --sp->root_count;
                        if (!sp->root_count && sp->role.invalid)
-                               kvm_mmu_zap_page(vcpu->kvm, sp);
+                               kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
+                                                        &invalid_list);
                }
                vcpu->arch.mmu.pae_root[i] = INVALID_PAGE;
        }
+       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
        spin_unlock(&vcpu->kvm->mmu_lock);
        vcpu->arch.mmu.root_hpa = INVALID_PAGE;
 }
@@ -2043,7 +2217,7 @@ static int mmu_check_root(struct kvm_vcpu *vcpu, gfn_t root_gfn)
        int ret = 0;
 
        if (!kvm_is_visible_gfn(vcpu->kvm, root_gfn)) {
-               set_bit(KVM_REQ_TRIPLE_FAULT, &vcpu->requests);
+               kvm_make_request(KVM_REQ_TRIPLE_FAULT, vcpu);
                ret = 1;
        }
 
@@ -2071,6 +2245,7 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                        root_gfn = 0;
                }
                spin_lock(&vcpu->kvm->mmu_lock);
+               kvm_mmu_free_some_pages(vcpu);
                sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
                                      PT64_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
@@ -2101,6 +2276,7 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                        root_gfn = i << 30;
                }
                spin_lock(&vcpu->kvm->mmu_lock);
+               kvm_mmu_free_some_pages(vcpu);
                sp = kvm_mmu_get_page(vcpu, root_gfn, i << 30,
                                      PT32_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
@@ -2196,10 +2372,8 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
        mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
        pfn = gfn_to_pfn(vcpu->kvm, gfn);
-       if (is_error_pfn(pfn)) {
-               kvm_release_pfn_clean(pfn);
-               return 1;
-       }
+       if (is_error_pfn(pfn))
+               return kvm_handle_bad_page(vcpu->kvm, gfn, pfn);
        spin_lock(&vcpu->kvm->mmu_lock);
        if (mmu_notifier_retry(vcpu, mmu_seq))
                goto out_unlock;
@@ -2241,7 +2415,7 @@ static int nonpaging_init_context(struct kvm_vcpu *vcpu)
 void kvm_mmu_flush_tlb(struct kvm_vcpu *vcpu)
 {
        ++vcpu->stat.tlb_flush;
-       kvm_x86_ops->tlb_flush(vcpu);
+       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
 }
 
 static void paging_new_cr3(struct kvm_vcpu *vcpu)
@@ -2455,10 +2629,9 @@ static int init_kvm_mmu(struct kvm_vcpu *vcpu)
 static void destroy_kvm_mmu(struct kvm_vcpu *vcpu)
 {
        ASSERT(vcpu);
-       if (VALID_PAGE(vcpu->arch.mmu.root_hpa)) {
+       if (VALID_PAGE(vcpu->arch.mmu.root_hpa))
+               /* mmu.free() should set root_hpa = INVALID_PAGE */
                vcpu->arch.mmu.free(vcpu);
-               vcpu->arch.mmu.root_hpa = INVALID_PAGE;
-       }
 }
 
 int kvm_mmu_reset_context(struct kvm_vcpu *vcpu)
@@ -2475,9 +2648,6 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
        r = mmu_topup_memory_caches(vcpu);
        if (r)
                goto out;
-       spin_lock(&vcpu->kvm->mmu_lock);
-       kvm_mmu_free_some_pages(vcpu);
-       spin_unlock(&vcpu->kvm->mmu_lock);
        r = mmu_alloc_roots(vcpu);
        spin_lock(&vcpu->kvm->mmu_lock);
        mmu_sync_roots(vcpu);
@@ -2506,7 +2676,7 @@ static void mmu_pte_write_zap_pte(struct kvm_vcpu *vcpu,
        pte = *spte;
        if (is_shadow_present_pte(pte)) {
                if (is_last_spte(pte, sp->role.level))
-                       rmap_remove(vcpu->kvm, spte);
+                       drop_spte(vcpu->kvm, spte, shadow_trap_nonpresent_pte);
                else {
                        child = page_header(pte & PT64_BASE_ADDR_MASK);
                        mmu_page_remove_parent_pte(child, spte);
@@ -2547,11 +2717,15 @@ static bool need_remote_flush(u64 old, u64 new)
        return (old & ~new & PT64_PERM_MASK) != 0;
 }
 
-static void mmu_pte_write_flush_tlb(struct kvm_vcpu *vcpu, u64 old, u64 new)
+static void mmu_pte_write_flush_tlb(struct kvm_vcpu *vcpu, bool zap_page,
+                                   bool remote_flush, bool local_flush)
 {
-       if (need_remote_flush(old, new))
+       if (zap_page)
+               return;
+
+       if (remote_flush)
                kvm_flush_remote_tlbs(vcpu->kvm);
-       else
+       else if (local_flush)
                kvm_mmu_flush_tlb(vcpu);
 }
 
@@ -2602,9 +2776,8 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 {
        gfn_t gfn = gpa >> PAGE_SHIFT;
        struct kvm_mmu_page *sp;
-       struct hlist_node *node, *n;
-       struct hlist_head *bucket;
-       unsigned index;
+       struct hlist_node *node;
+       LIST_HEAD(invalid_list);
        u64 entry, gentry;
        u64 *spte;
        unsigned offset = offset_in_page(gpa);
@@ -2617,6 +2790,9 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        int npte;
        int r;
        int invlpg_counter;
+       bool remote_flush, local_flush, zap_page;
+
+       zap_page = remote_flush = local_flush = false;
 
        pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
 
@@ -2672,13 +2848,8 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                        vcpu->arch.last_pte_updated = NULL;
                }
        }
-       index = kvm_page_table_hashfn(gfn);
-       bucket = &vcpu->kvm->arch.mmu_page_hash[index];
 
-restart:
-       hlist_for_each_entry_safe(sp, node, n, bucket, hash_link) {
-               if (sp->gfn != gfn || sp->role.direct || sp->role.invalid)
-                       continue;
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn, node) {
                pte_size = sp->role.cr4_pae ? 8 : 4;
                misaligned = (offset ^ (offset + bytes - 1)) & ~(pte_size - 1);
                misaligned |= bytes < 4;
@@ -2695,8 +2866,8 @@ restart:
                         */
                        pgprintk("misaligned: gpa %llx bytes %d role %x\n",
                                 gpa, bytes, sp->role.word);
-                       if (kvm_mmu_zap_page(vcpu->kvm, sp))
-                               goto restart;
+                       zap_page |= !!kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
+                                                    &invalid_list);
                        ++vcpu->kvm->stat.mmu_flooded;
                        continue;
                }
@@ -2720,16 +2891,20 @@ restart:
                        if (quadrant != sp->role.quadrant)
                                continue;
                }
+               local_flush = true;
                spte = &sp->spt[page_offset / sizeof(*spte)];
                while (npte--) {
                        entry = *spte;
                        mmu_pte_write_zap_pte(vcpu, sp, spte);
                        if (gentry)
                                mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
-                       mmu_pte_write_flush_tlb(vcpu, entry, *spte);
+                       if (!remote_flush && need_remote_flush(entry, *spte))
+                               remote_flush = true;
                        ++spte;
                }
        }
+       mmu_pte_write_flush_tlb(vcpu, zap_page, remote_flush, local_flush);
+       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
        kvm_mmu_audit(vcpu, "post pte write");
        spin_unlock(&vcpu->kvm->mmu_lock);
        if (!is_error_pfn(vcpu->arch.update_pte.pfn)) {
@@ -2757,15 +2932,21 @@ EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page_virt);
 
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu)
 {
-       while (vcpu->kvm->arch.n_free_mmu_pages < KVM_REFILL_PAGES &&
+       int free_pages;
+       LIST_HEAD(invalid_list);
+
+       free_pages = vcpu->kvm->arch.n_free_mmu_pages;
+       while (free_pages < KVM_REFILL_PAGES &&
               !list_empty(&vcpu->kvm->arch.active_mmu_pages)) {
                struct kvm_mmu_page *sp;
 
                sp = container_of(vcpu->kvm->arch.active_mmu_pages.prev,
                                  struct kvm_mmu_page, link);
-               kvm_mmu_zap_page(vcpu->kvm, sp);
+               free_pages += kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
+                                                      &invalid_list);
                ++vcpu->kvm->stat.mmu_recycled;
        }
+       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
 }
 
 int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
@@ -2793,11 +2974,8 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
                return 1;
        case EMULATE_DO_MMIO:
                ++vcpu->stat.mmio_exits;
-               return 0;
+               /* fall through */
        case EMULATE_FAIL:
-               vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
-               vcpu->run->internal.suberror = KVM_INTERNAL_ERROR_EMULATION;
-               vcpu->run->internal.ndata = 0;
                return 0;
        default:
                BUG();
@@ -2894,7 +3072,7 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot)
                pt = sp->spt;
                for (i = 0; i < PT64_ENT_PER_PAGE; ++i)
                        /* avoid RMW */
-                       if (pt[i] & PT_WRITABLE_MASK)
+                       if (is_writable_pte(pt[i]))
                                pt[i] &= ~PT_WRITABLE_MASK;
        }
        kvm_flush_remote_tlbs(kvm);
@@ -2903,28 +3081,29 @@ void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot)
 void kvm_mmu_zap_all(struct kvm *kvm)
 {
        struct kvm_mmu_page *sp, *node;
+       LIST_HEAD(invalid_list);
 
        spin_lock(&kvm->mmu_lock);
 restart:
        list_for_each_entry_safe(sp, node, &kvm->arch.active_mmu_pages, link)
-               if (kvm_mmu_zap_page(kvm, sp))
+               if (kvm_mmu_prepare_zap_page(kvm, sp, &invalid_list))
                        goto restart;
 
+       kvm_mmu_commit_zap_page(kvm, &invalid_list);
        spin_unlock(&kvm->mmu_lock);
-
-       kvm_flush_remote_tlbs(kvm);
 }
 
-static int kvm_mmu_remove_some_alloc_mmu_pages(struct kvm *kvm)
+static int kvm_mmu_remove_some_alloc_mmu_pages(struct kvm *kvm,
+                                              struct list_head *invalid_list)
 {
        struct kvm_mmu_page *page;
 
        page = container_of(kvm->arch.active_mmu_pages.prev,
                            struct kvm_mmu_page, link);
-       return kvm_mmu_zap_page(kvm, page) + 1;
+       return kvm_mmu_prepare_zap_page(kvm, page, invalid_list);
 }
 
-static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
+static int mmu_shrink(struct shrinker *shrink, int nr_to_scan, gfp_t gfp_mask)
 {
        struct kvm *kvm;
        struct kvm *kvm_freed = NULL;
@@ -2934,6 +3113,7 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
 
        list_for_each_entry(kvm, &vm_list, vm_list) {
                int npages, idx, freed_pages;
+               LIST_HEAD(invalid_list);
 
                idx = srcu_read_lock(&kvm->srcu);
                spin_lock(&kvm->mmu_lock);
@@ -2941,12 +3121,14 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
                         kvm->arch.n_free_mmu_pages;
                cache_count += npages;
                if (!kvm_freed && nr_to_scan > 0 && npages > 0) {
-                       freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm);
+                       freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm,
+                                                         &invalid_list);
                        cache_count -= freed_pages;
                        kvm_freed = kvm;
                }
                nr_to_scan--;
 
+               kvm_mmu_commit_zap_page(kvm, &invalid_list);
                spin_unlock(&kvm->mmu_lock);
                srcu_read_unlock(&kvm->srcu, idx);
        }
@@ -3072,7 +3254,7 @@ static int kvm_pv_mmu_write(struct kvm_vcpu *vcpu,
 
 static int kvm_pv_mmu_flush_tlb(struct kvm_vcpu *vcpu)
 {
-       kvm_set_cr3(vcpu, vcpu->arch.cr3);
+       (void)kvm_set_cr3(vcpu, vcpu->arch.cr3);
        return 1;
 }
 
@@ -3329,9 +3511,9 @@ void inspect_spte_has_rmap(struct kvm *kvm, u64 *sptep)
        struct kvm_mmu_page *rev_sp;
        gfn_t gfn;
 
-       if (*sptep & PT_WRITABLE_MASK) {
+       if (is_writable_pte(*sptep)) {
                rev_sp = page_header(__pa(sptep));
-               gfn = rev_sp->gfns[sptep - rev_sp->spt];
+               gfn = kvm_mmu_page_get_gfn(rev_sp, sptep - rev_sp->spt);
 
                if (!gfn_to_memslot(kvm, gfn)) {
                        if (!printk_ratelimit())
@@ -3345,8 +3527,7 @@ void inspect_spte_has_rmap(struct kvm *kvm, u64 *sptep)
                        return;
                }
 
-               rmapp = gfn_to_rmap(kvm, rev_sp->gfns[sptep - rev_sp->spt],
-                                   rev_sp->role.level);
+               rmapp = gfn_to_rmap(kvm, gfn, rev_sp->role.level);
                if (!*rmapp) {
                        if (!printk_ratelimit())
                                return;
@@ -3379,7 +3560,7 @@ static void check_writable_mappings_rmap(struct kvm_vcpu *vcpu)
 
                        if (!(ent & PT_PRESENT_MASK))
                                continue;
-                       if (!(ent & PT_WRITABLE_MASK))
+                       if (!is_writable_pte(ent))
                                continue;
                        inspect_spte_has_rmap(vcpu->kvm, &pt[i]);
                }
@@ -3407,13 +3588,12 @@ static void audit_write_protection(struct kvm_vcpu *vcpu)
                if (sp->unsync)
                        continue;
 
-               gfn = unalias_gfn(vcpu->kvm, sp->gfn);
-               slot = gfn_to_memslot_unaliased(vcpu->kvm, sp->gfn);
+               slot = gfn_to_memslot(vcpu->kvm, sp->gfn);
                rmapp = &slot->rmap[gfn - slot->base_gfn];
 
                spte = rmap_next(vcpu->kvm, rmapp, NULL);
                while (spte) {
-                       if (*spte & PT_WRITABLE_MASK)
+                       if (is_writable_pte(*spte))
                                printk(KERN_ERR "%s: (%s) shadow page has "
                                "writable mappings: gfn %lx role %x\n",
                               __func__, audit_msg, sp->gfn,