KVM: MMU: Remove user access when allowing kernel access to gpte.w=0 page
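
With shadow paging, a guest that runs with CR0.WP=0 may legitimately
write, from kernel mode, through a guest pte that is marked read-only
(gpte.w=0).  The change to set_spte() below emulates this by making the
shadow pte writable on such a write fault and then dropping its user
bit, so guest user mode does not inherit access through the now-writable
shadow pte:

        if (!tdp_enabled && !(pte_access & ACC_WRITE_MASK))
                spte &= ~PT_USER_MASK;

As a rough, standalone illustration of that permission logic (plain C
with made-up helper and parameter names, not KVM's own interfaces):

        #include <stdbool.h>
        #include <stdint.h>

        #define PT_WRITABLE_MASK (1ULL << 1)
        #define PT_USER_MASK     (1ULL << 2)

        /*
         * Illustrative sketch only; the names here are invented and the
         * real decision lives in set_spte() in the diff below.
         */
        static uint64_t adjust_spte_perms(uint64_t spte, bool gpte_writable,
                                          bool guest_cr0_wp, bool user_fault,
                                          bool write_fault, bool tdp_enabled)
        {
                if (gpte_writable ||
                    (write_fault && !guest_cr0_wp && !user_fault)) {
                        /* Grant the write the guest kernel is entitled to. */
                        spte |= PT_WRITABLE_MASK;
                        /*
                         * If the guest pte itself is read-only, remove user
                         * access so guest user mode cannot go through the
                         * writable shadow pte (shadow paging only; with TDP
                         * the guest's own page tables still enforce
                         * gpte.w=0).
                         */
                        if (!tdp_enabled && !gpte_writable)
                                spte &= ~PT_USER_MASK;
                }
                return spte;
        }
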
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index b93ad2c..a6f695d 100644
@@ -18,6 +18,7 @@
  */
 
 #include "mmu.h"
+#include "x86.h"
 #include "kvm_cache_regs.h"
 
 #include <linux/kvm_host.h>
@@ -29,6 +30,8 @@
 #include <linux/swap.h>
 #include <linux/hugetlb.h>
 #include <linux/compiler.h>
+#include <linux/srcu.h>
+#include <linux/slab.h>
 
 #include <asm/page.h>
 #include <asm/cmpxchg.h>
@@ -108,6 +111,9 @@ module_param(oos_shadow, bool, 0644);
 
 #define PT32_LEVEL_MASK(level) \
                (((1ULL << PT32_LEVEL_BITS) - 1) << PT32_LEVEL_SHIFT(level))
+#define PT32_LVL_OFFSET_MASK(level) \
+       (PT32_BASE_ADDR_MASK & ((1ULL << (PAGE_SHIFT + (((level) - 1) \
+                                               * PT32_LEVEL_BITS))) - 1))
 
 #define PT32_INDEX(address, level)\
        (((address) >> PT32_LEVEL_SHIFT(level)) & ((1 << PT32_LEVEL_BITS) - 1))
@@ -116,23 +122,23 @@ module_param(oos_shadow, bool, 0644);
 #define PT64_BASE_ADDR_MASK (((1ULL << 52) - 1) & ~(u64)(PAGE_SIZE-1))
 #define PT64_DIR_BASE_ADDR_MASK \
        (PT64_BASE_ADDR_MASK & ~((1ULL << (PAGE_SHIFT + PT64_LEVEL_BITS)) - 1))
+#define PT64_LVL_ADDR_MASK(level) \
+       (PT64_BASE_ADDR_MASK & ~((1ULL << (PAGE_SHIFT + (((level) - 1) \
+                                               * PT64_LEVEL_BITS))) - 1))
+#define PT64_LVL_OFFSET_MASK(level) \
+       (PT64_BASE_ADDR_MASK & ((1ULL << (PAGE_SHIFT + (((level) - 1) \
+                                               * PT64_LEVEL_BITS))) - 1))
 
 #define PT32_BASE_ADDR_MASK PAGE_MASK
 #define PT32_DIR_BASE_ADDR_MASK \
        (PAGE_MASK & ~((1ULL << (PAGE_SHIFT + PT32_LEVEL_BITS)) - 1))
+#define PT32_LVL_ADDR_MASK(level) \
+       (PAGE_MASK & ~((1ULL << (PAGE_SHIFT + (((level) - 1) \
+                                           * PT32_LEVEL_BITS))) - 1))
 
 #define PT64_PERM_MASK (PT_PRESENT_MASK | PT_WRITABLE_MASK | PT_USER_MASK \
                        | PT64_NX_MASK)
 
-#define PFERR_PRESENT_MASK (1U << 0)
-#define PFERR_WRITE_MASK (1U << 1)
-#define PFERR_USER_MASK (1U << 2)
-#define PFERR_RSVD_MASK (1U << 3)
-#define PFERR_FETCH_MASK (1U << 4)
-
-#define PT_DIRECTORY_LEVEL 2
-#define PT_PAGE_TABLE_LEVEL 1
-
 #define RMAP_EXT 4
 
 #define ACC_EXEC_MASK    1
@@ -140,9 +146,13 @@ module_param(oos_shadow, bool, 0644);
 #define ACC_USER_MASK    PT_USER_MASK
 #define ACC_ALL          (ACC_EXEC_MASK | ACC_WRITE_MASK | ACC_USER_MASK)
 
+#include <trace/events/kvm.h>
+
 #define CREATE_TRACE_POINTS
 #include "mmutrace.h"
 
+#define SPTE_HOST_WRITEABLE (1ULL << PT_FIRST_AVAIL_BITS_SHIFT)
+
 #define SHADOW_PT_INDEX(addr, level) PT64_INDEX(addr, level)
 
 struct kvm_rmap_desc {
@@ -163,12 +173,7 @@ struct kvm_shadow_walk_iterator {
             shadow_walk_okay(&(_walker));                      \
             shadow_walk_next(&(_walker)))
 
-
-struct kvm_unsync_walk {
-       int (*entry) (struct kvm_mmu_page *sp, struct kvm_unsync_walk *walk);
-};
-
-typedef int (*mmu_parent_walk_fn) (struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp);
+typedef int (*mmu_parent_walk_fn) (struct kvm_mmu_page *sp);
 
 static struct kmem_cache *pte_chain_cache;
 static struct kmem_cache *rmap_desc_cache;
@@ -212,9 +217,9 @@ void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mask_ptes);
 
-static int is_write_protection(struct kvm_vcpu *vcpu)
+static bool is_write_protection(struct kvm_vcpu *vcpu)
 {
-       return vcpu->arch.cr0 & X86_CR0_WP;
+       return kvm_read_cr0_bits(vcpu, X86_CR0_WP);
 }
 
 static int is_cpuid_PSE36(void)
@@ -224,7 +229,7 @@ static int is_cpuid_PSE36(void)
 
 static int is_nx(struct kvm_vcpu *vcpu)
 {
-       return vcpu->arch.shadow_efer & EFER_NX;
+       return vcpu->arch.efer & EFER_NX;
 }
 
 static int is_shadow_present_pte(u64 pte)
@@ -238,7 +243,7 @@ static int is_large_pte(u64 pte)
        return pte & PT_PAGE_SIZE_MASK;
 }
 
-static int is_writeble_pte(unsigned long pte)
+static int is_writable_pte(unsigned long pte)
 {
        return pte & PT_WRITABLE_MASK;
 }
@@ -257,7 +262,7 @@ static int is_last_spte(u64 pte, int level)
 {
        if (level == PT_PAGE_TABLE_LEVEL)
                return 1;
-       if (level == PT_DIRECTORY_LEVEL && is_large_pte(pte))
+       if (is_large_pte(pte))
                return 1;
        return 0;
 }
@@ -316,7 +321,6 @@ static int mmu_topup_memory_cache_page(struct kvm_mmu_memory_cache *cache,
                page = alloc_page(GFP_KERNEL);
                if (!page)
                        return -ENOMEM;
-               set_page_private(page, 0);
                cache->objects[cache->nobjs++] = page_address(page);
        }
        return 0;
@@ -393,37 +397,52 @@ static void mmu_free_rmap_desc(struct kvm_rmap_desc *rd)
  * Return the pointer to the largepage write count for a given
  * gfn, handling slots that are not large page aligned.
  */
-static int *slot_largepage_idx(gfn_t gfn, struct kvm_memory_slot *slot)
+static int *slot_largepage_idx(gfn_t gfn,
+                              struct kvm_memory_slot *slot,
+                              int level)
 {
        unsigned long idx;
 
-       idx = (gfn / KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL)) -
-             (slot->base_gfn / KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL));
-       return &slot->lpage_info[0][idx].write_count;
+       idx = (gfn / KVM_PAGES_PER_HPAGE(level)) -
+             (slot->base_gfn / KVM_PAGES_PER_HPAGE(level));
+       return &slot->lpage_info[level - 2][idx].write_count;
 }
 
 static void account_shadowed(struct kvm *kvm, gfn_t gfn)
 {
+       struct kvm_memory_slot *slot;
        int *write_count;
+       int i;
 
        gfn = unalias_gfn(kvm, gfn);
-       write_count = slot_largepage_idx(gfn,
-                                        gfn_to_memslot_unaliased(kvm, gfn));
-       *write_count += 1;
+
+       slot = gfn_to_memslot_unaliased(kvm, gfn);
+       for (i = PT_DIRECTORY_LEVEL;
+            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+               write_count   = slot_largepage_idx(gfn, slot, i);
+               *write_count += 1;
+       }
 }
 
 static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
 {
+       struct kvm_memory_slot *slot;
        int *write_count;
+       int i;
 
        gfn = unalias_gfn(kvm, gfn);
-       write_count = slot_largepage_idx(gfn,
-                                        gfn_to_memslot_unaliased(kvm, gfn));
-       *write_count -= 1;
-       WARN_ON(*write_count < 0);
+       slot = gfn_to_memslot_unaliased(kvm, gfn);
+       for (i = PT_DIRECTORY_LEVEL;
+            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+               write_count   = slot_largepage_idx(gfn, slot, i);
+               *write_count -= 1;
+               WARN_ON(*write_count < 0);
+       }
 }
 
-static int has_wrprotected_page(struct kvm *kvm, gfn_t gfn)
+static int has_wrprotected_page(struct kvm *kvm,
+                               gfn_t gfn,
+                               int level)
 {
        struct kvm_memory_slot *slot;
        int *largepage_idx;
@@ -431,47 +450,53 @@ static int has_wrprotected_page(struct kvm *kvm, gfn_t gfn)
        gfn = unalias_gfn(kvm, gfn);
        slot = gfn_to_memslot_unaliased(kvm, gfn);
        if (slot) {
-               largepage_idx = slot_largepage_idx(gfn, slot);
+               largepage_idx = slot_largepage_idx(gfn, slot, level);
                return *largepage_idx;
        }
 
        return 1;
 }
 
-static int host_largepage_backed(struct kvm *kvm, gfn_t gfn)
+static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
 {
-       struct vm_area_struct *vma;
-       unsigned long addr;
-       int ret = 0;
+       unsigned long page_size;
+       int i, ret = 0;
 
-       addr = gfn_to_hva(kvm, gfn);
-       if (kvm_is_error_hva(addr))
-               return ret;
+       page_size = kvm_host_page_size(kvm, gfn);
 
-       down_read(&current->mm->mmap_sem);
-       vma = find_vma(current->mm, addr);
-       if (vma && is_vm_hugetlb_page(vma))
-               ret = 1;
-       up_read(&current->mm->mmap_sem);
+       for (i = PT_PAGE_TABLE_LEVEL;
+            i < (PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES); ++i) {
+               if (page_size >= KVM_HPAGE_SIZE(i))
+                       ret = i;
+               else
+                       break;
+       }
 
        return ret;
 }
 
-static int is_largepage_backed(struct kvm_vcpu *vcpu, gfn_t large_gfn)
+static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn)
 {
        struct kvm_memory_slot *slot;
-
-       if (has_wrprotected_page(vcpu->kvm, large_gfn))
-               return 0;
-
-       if (!host_largepage_backed(vcpu->kvm, large_gfn))
-               return 0;
+       int host_level, level, max_level;
 
        slot = gfn_to_memslot(vcpu->kvm, large_gfn);
        if (slot && slot->dirty_bitmap)
-               return 0;
+               return PT_PAGE_TABLE_LEVEL;
 
-       return 1;
+       host_level = host_mapping_level(vcpu->kvm, large_gfn);
+
+       if (host_level == PT_PAGE_TABLE_LEVEL)
+               return host_level;
+
+       max_level = kvm_x86_ops->get_lpage_level() < host_level ?
+               kvm_x86_ops->get_lpage_level() : host_level;
+
+       for (level = PT_DIRECTORY_LEVEL; level <= max_level; ++level)
+               if (has_wrprotected_page(vcpu->kvm, large_gfn, level))
+                       break;
+
+       return level - 1;
 }
 
 /*
@@ -585,10 +610,8 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
        pfn = spte_to_pfn(*spte);
        if (*spte & shadow_accessed_mask)
                kvm_set_pfn_accessed(pfn);
-       if (is_writeble_pte(*spte))
-               kvm_release_pfn_dirty(pfn);
-       else
-               kvm_release_pfn_clean(pfn);
+       if (is_writable_pte(*spte))
+               kvm_set_pfn_dirty(pfn);
        rmapp = gfn_to_rmap(kvm, sp->gfns[spte - sp->spt], sp->role.level);
        if (!*rmapp) {
                printk(KERN_ERR "rmap_remove: %p %llx 0->BUG\n", spte, *spte);
@@ -616,6 +639,7 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
                        prev_desc = desc;
                        desc = desc->more;
                }
+               pr_err("rmap_remove: %p %llx many->many\n", spte, *spte);
                BUG();
        }
 }
@@ -623,7 +647,6 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
 static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
 {
        struct kvm_rmap_desc *desc;
-       struct kvm_rmap_desc *prev_desc;
        u64 *prev_spte;
        int i;
 
@@ -635,7 +658,6 @@ static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
                return NULL;
        }
        desc = (struct kvm_rmap_desc *)(*rmapp & ~1ul);
-       prev_desc = NULL;
        prev_spte = NULL;
        while (desc) {
                for (i = 0; i < RMAP_EXT && desc->sptes[i]; ++i) {
@@ -662,7 +684,7 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
                BUG_ON(!spte);
                BUG_ON(!(*spte & PT_PRESENT_MASK));
                rmap_printk("rmap_write_protect: spte %p %llx\n", spte, *spte);
-               if (is_writeble_pte(*spte)) {
+               if (is_writable_pte(*spte)) {
                        __set_spte(spte, *spte & ~PT_WRITABLE_MASK);
                        write_protected = 1;
                }
@@ -686,7 +708,7 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
                        BUG_ON(!(*spte & PT_PRESENT_MASK));
                        BUG_ON((*spte & (PT_PAGE_SIZE_MASK|PT_PRESENT_MASK)) != (PT_PAGE_SIZE_MASK|PT_PRESENT_MASK));
                        pgprintk("rmap_write_protect(large): spte %p %llx %lld\n", spte, *spte, gfn);
-                       if (is_writeble_pte(*spte)) {
+                       if (is_writable_pte(*spte)) {
                                rmap_remove(kvm, spte);
                                --kvm->stat.lpages;
                                __set_spte(spte, shadow_trap_nonpresent_pte);
@@ -700,7 +722,8 @@ static int rmap_write_protect(struct kvm *kvm, u64 gfn)
        return write_protected;
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp)
+static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
+                          unsigned long data)
 {
        u64 *spte;
        int need_tlb_flush = 0;
@@ -715,33 +738,75 @@ static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp)
        return need_tlb_flush;
 }
 
+static int kvm_set_pte_rmapp(struct kvm *kvm, unsigned long *rmapp,
+                            unsigned long data)
+{
+       int need_flush = 0;
+       u64 *spte, new_spte;
+       pte_t *ptep = (pte_t *)data;
+       pfn_t new_pfn;
+
+       WARN_ON(pte_huge(*ptep));
+       new_pfn = pte_pfn(*ptep);
+       spte = rmap_next(kvm, rmapp, NULL);
+       while (spte) {
+               BUG_ON(!is_shadow_present_pte(*spte));
+               rmap_printk("kvm_set_pte_rmapp: spte %p %llx\n", spte, *spte);
+               need_flush = 1;
+               if (pte_write(*ptep)) {
+                       rmap_remove(kvm, spte);
+                       __set_spte(spte, shadow_trap_nonpresent_pte);
+                       spte = rmap_next(kvm, rmapp, NULL);
+               } else {
+                       new_spte = *spte &~ (PT64_BASE_ADDR_MASK);
+                       new_spte |= (u64)new_pfn << PAGE_SHIFT;
+
+                       new_spte &= ~PT_WRITABLE_MASK;
+                       new_spte &= ~SPTE_HOST_WRITEABLE;
+                       if (is_writable_pte(*spte))
+                               kvm_set_pfn_dirty(spte_to_pfn(*spte));
+                       __set_spte(spte, new_spte);
+                       spte = rmap_next(kvm, rmapp, spte);
+               }
+       }
+       if (need_flush)
+               kvm_flush_remote_tlbs(kvm);
+
+       return 0;
+}
+
 static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
-                         int (*handler)(struct kvm *kvm, unsigned long *rmapp))
+                         unsigned long data,
+                         int (*handler)(struct kvm *kvm, unsigned long *rmapp,
+                                        unsigned long data))
 {
-       int i;
+       int i, j;
+       int ret;
        int retval = 0;
+       struct kvm_memslots *slots;
 
-       /*
-        * If mmap_sem isn't taken, we can look the memslots with only
-        * the mmu_lock by skipping over the slots with userspace_addr == 0.
-        */
-       for (i = 0; i < kvm->nmemslots; i++) {
-               struct kvm_memory_slot *memslot = &kvm->memslots[i];
+       slots = kvm_memslots(kvm);
+
+       for (i = 0; i < slots->nmemslots; i++) {
+               struct kvm_memory_slot *memslot = &slots->memslots[i];
                unsigned long start = memslot->userspace_addr;
                unsigned long end;
 
-               /* mmu_lock protects userspace_addr */
-               if (!start)
-                       continue;
-
                end = start + (memslot->npages << PAGE_SHIFT);
                if (hva >= start && hva < end) {
                        gfn_t gfn_offset = (hva - start) >> PAGE_SHIFT;
-                       int idx = gfn_offset /
-                                 KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL);
-                       retval |= handler(kvm, &memslot->rmap[gfn_offset]);
-                       retval |= handler(kvm,
-                                       &memslot->lpage_info[0][idx].rmap_pde);
+
+                       ret = handler(kvm, &memslot->rmap[gfn_offset], data);
+
+                       for (j = 0; j < KVM_NR_PAGE_SIZES - 1; ++j) {
+                               int idx = gfn_offset;
+                               idx /= KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL + j);
+                               ret |= handler(kvm,
+                                       &memslot->lpage_info[j][idx].rmap_pde,
+                                       data);
+                       }
+                       trace_kvm_age_page(hva, memslot, ret);
+                       retval |= ret;
                }
        }
 
@@ -750,17 +815,29 @@ static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
 
 int kvm_unmap_hva(struct kvm *kvm, unsigned long hva)
 {
-       return kvm_handle_hva(kvm, hva, kvm_unmap_rmapp);
+       return kvm_handle_hva(kvm, hva, 0, kvm_unmap_rmapp);
+}
+
+void kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+{
+       kvm_handle_hva(kvm, hva, (unsigned long)&pte, kvm_set_pte_rmapp);
 }
 
-static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp)
+static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
+                        unsigned long data)
 {
        u64 *spte;
        int young = 0;
 
-       /* always return old for EPT */
+       /*
+        * Emulate the accessed bit for EPT, by checking if this page has
+        * an EPT mapping, and clearing it if it does. On the next access,
+        * a new EPT mapping will be established.
+        * This has some overhead, but not as much as the cost of swapping
+        * out actively used pages or breaking up actively used hugepages.
+        */
        if (!shadow_accessed_mask)
-               return 0;
+               return kvm_unmap_rmapp(kvm, rmapp, data);
 
        spte = rmap_next(kvm, rmapp, NULL);
        while (spte) {
@@ -779,20 +856,23 @@ static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp)
 
 #define RMAP_RECYCLE_THRESHOLD 1000
 
-static void rmap_recycle(struct kvm_vcpu *vcpu, gfn_t gfn, int lpage)
+static void rmap_recycle(struct kvm_vcpu *vcpu, u64 *spte, gfn_t gfn)
 {
        unsigned long *rmapp;
+       struct kvm_mmu_page *sp;
+
+       sp = page_header(__pa(spte));
 
        gfn = unalias_gfn(vcpu->kvm, gfn);
-       rmapp = gfn_to_rmap(vcpu->kvm, gfn, lpage);
+       rmapp = gfn_to_rmap(vcpu->kvm, gfn, sp->role.level);
 
-       kvm_unmap_rmapp(vcpu->kvm, rmapp);
+       kvm_unmap_rmapp(vcpu->kvm, rmapp, 0);
        kvm_flush_remote_tlbs(vcpu->kvm);
 }
 
 int kvm_age_hva(struct kvm *kvm, unsigned long hva)
 {
-       return kvm_handle_hva(kvm, hva, kvm_age_rmapp);
+       return kvm_handle_hva(kvm, hva, 0, kvm_age_rmapp);
 }
 
 #ifdef MMU_DEBUG
@@ -836,7 +916,6 @@ static struct kvm_mmu_page *kvm_mmu_alloc_page(struct kvm_vcpu *vcpu,
        sp->gfns = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_cache, PAGE_SIZE);
        set_page_private(virt_to_page(sp->spt), (unsigned long)sp);
        list_add(&sp->link, &vcpu->kvm->arch.active_mmu_pages);
-       INIT_LIST_HEAD(&sp->oos_link);
        bitmap_zero(sp->slot_bitmap, KVM_MEMORY_SLOTS + KVM_PRIVATE_MEM_SLOTS);
        sp->multimapped = 0;
        sp->parent_pte = parent_pte;
@@ -920,8 +999,7 @@ static void mmu_page_remove_parent_pte(struct kvm_mmu_page *sp,
 }
 
 
-static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                           mmu_parent_walk_fn fn)
+static void mmu_parent_walk(struct kvm_mmu_page *sp, mmu_parent_walk_fn fn)
 {
        struct kvm_pte_chain *pte_chain;
        struct hlist_node *node;
@@ -930,8 +1008,8 @@ static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
 
        if (!sp->multimapped && sp->parent_pte) {
                parent_sp = page_header(__pa(sp->parent_pte));
-               fn(vcpu, parent_sp);
-               mmu_parent_walk(vcpu, parent_sp, fn);
+               fn(parent_sp);
+               mmu_parent_walk(parent_sp, fn);
                return;
        }
        hlist_for_each_entry(pte_chain, node, &sp->parent_ptes, link)
@@ -939,8 +1017,8 @@ static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                        if (!pte_chain->parent_ptes[i])
                                break;
                        parent_sp = page_header(__pa(pte_chain->parent_ptes[i]));
-                       fn(vcpu, parent_sp);
-                       mmu_parent_walk(vcpu, parent_sp, fn);
+                       fn(parent_sp);
+                       mmu_parent_walk(parent_sp, fn);
                }
 }
 
@@ -977,16 +1055,15 @@ static void kvm_mmu_update_parents_unsync(struct kvm_mmu_page *sp)
                }
 }
 
-static int unsync_walk_fn(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+static int unsync_walk_fn(struct kvm_mmu_page *sp)
 {
        kvm_mmu_update_parents_unsync(sp);
        return 1;
 }
 
-static void kvm_mmu_mark_parents_unsync(struct kvm_vcpu *vcpu,
-                                       struct kvm_mmu_page *sp)
+static void kvm_mmu_mark_parents_unsync(struct kvm_mmu_page *sp)
 {
-       mmu_parent_walk(vcpu, sp, unsync_walk_fn);
+       mmu_parent_walk(sp, unsync_walk_fn);
        kvm_mmu_update_parents_unsync(sp);
 }
 
@@ -1112,6 +1189,7 @@ static struct kvm_mmu_page *kvm_mmu_lookup_page(struct kvm *kvm, gfn_t gfn)
 static void kvm_unlink_unsync_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        WARN_ON(!sp->unsync);
+       trace_kvm_mmu_sync_page(sp);
        sp->unsync = 0;
        --kvm->stat.mmu_unsync;
 }
@@ -1120,12 +1198,11 @@ static int kvm_mmu_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp);
 
 static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
-       if (sp->role.glevels != vcpu->arch.mmu.root_level) {
+       if (sp->role.cr4_pae != !!is_pae(vcpu)) {
                kvm_mmu_zap_page(vcpu->kvm, sp);
                return 1;
        }
 
-       trace_kvm_mmu_sync_page(sp);
        if (rmap_write_protect(vcpu->kvm, sp->gfn))
                kvm_flush_remote_tlbs(vcpu->kvm);
        kvm_unlink_unsync_page(vcpu->kvm, sp);
@@ -1242,6 +1319,8 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        role = vcpu->arch.mmu.base_role;
        role.level = level;
        role.direct = direct;
+       if (role.direct)
+               role.cr4_pae = 0;
        role.access = access;
        if (vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
@@ -1262,7 +1341,7 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                        mmu_page_add_parent_pte(vcpu, sp, parent_pte);
                        if (sp->unsync_children) {
                                set_bit(KVM_REQ_MMU_SYNC, &vcpu->requests);
-                               kvm_mmu_mark_parents_unsync(vcpu, sp);
+                               kvm_mmu_mark_parents_unsync(sp);
                        }
                        trace_kvm_mmu_get_page(sp, false);
                        return sp;
@@ -1401,8 +1480,8 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
                for_each_sp(pages, sp, parents, i) {
                        kvm_mmu_zap_page(kvm, sp);
                        mmu_pages_clear_parents(&parents);
+                       zapped++;
                }
-               zapped += pages.nr;
                kvm_mmu_pages_init(parent, &parents, &pages);
        }
 
@@ -1453,14 +1532,16 @@ void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
         */
 
        if (used_pages > kvm_nr_mmu_pages) {
-               while (used_pages > kvm_nr_mmu_pages) {
+               while (used_pages > kvm_nr_mmu_pages &&
+                       !list_empty(&kvm->arch.active_mmu_pages)) {
                        struct kvm_mmu_page *page;
 
                        page = container_of(kvm->arch.active_mmu_pages.prev,
                                            struct kvm_mmu_page, link);
-                       kvm_mmu_zap_page(kvm, page);
+                       used_pages -= kvm_mmu_zap_page(kvm, page);
                        used_pages--;
                }
+               kvm_nr_mmu_pages = used_pages;
                kvm->arch.n_free_mmu_pages = 0;
        }
        else
@@ -1482,13 +1563,14 @@ static int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
        r = 0;
        index = kvm_page_table_hashfn(gfn);
        bucket = &kvm->arch.mmu_page_hash[index];
+restart:
        hlist_for_each_entry_safe(sp, node, n, bucket, hash_link)
                if (sp->gfn == gfn && !sp->role.direct) {
                        pgprintk("%s: gfn %lx role %x\n", __func__, gfn,
                                 sp->role.word);
                        r = 1;
                        if (kvm_mmu_zap_page(kvm, sp))
-                               n = bucket->first;
+                               goto restart;
                }
        return r;
 }
@@ -1502,19 +1584,21 @@ static void mmu_unshadow(struct kvm *kvm, gfn_t gfn)
 
        index = kvm_page_table_hashfn(gfn);
        bucket = &kvm->arch.mmu_page_hash[index];
+restart:
        hlist_for_each_entry_safe(sp, node, nn, bucket, hash_link) {
                if (sp->gfn == gfn && !sp->role.direct
                    && !sp->role.invalid) {
                        pgprintk("%s: zap %lx %x\n",
                                 __func__, gfn, sp->role.word);
-                       kvm_mmu_zap_page(kvm, sp);
+                       if (kvm_mmu_zap_page(kvm, sp))
+                               goto restart;
                }
        }
 }
 
 static void page_header_update_slot(struct kvm *kvm, void *pte, gfn_t gfn)
 {
-       int slot = memslot_id(kvm, gfn_to_memslot(kvm, gfn));
+       int slot = memslot_id(kvm, gfn);
        struct kvm_mmu_page *sp = page_header(__pa(pte));
 
        __set_bit(slot, sp->slot_bitmap);
@@ -1534,20 +1618,6 @@ static void mmu_convert_notrap(struct kvm_mmu_page *sp)
        }
 }
 
-struct page *gva_to_page(struct kvm_vcpu *vcpu, gva_t gva)
-{
-       struct page *page;
-
-       gpa_t gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, gva);
-
-       if (gpa == UNMAPPED_GVA)
-               return NULL;
-
-       page = gfn_to_page(vcpu->kvm, gpa >> PAGE_SHIFT);
-
-       return page;
-}
-
 /*
  * The function is based on mtrr_type_lookup() in
  * arch/x86/kernel/cpu/mtrr/generic.c
@@ -1660,7 +1730,6 @@ static int kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        struct kvm_mmu_page *s;
        struct hlist_node *node, *n;
 
-       trace_kvm_mmu_unsync_page(sp);
        index = kvm_page_table_hashfn(sp->gfn);
        bucket = &vcpu->kvm->arch.mmu_page_hash[index];
        /* don't unsync if pagetable is shadowed with multiple roles */
@@ -1670,10 +1739,11 @@ static int kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
                if (s->role.word != sp->role.word)
                        return 1;
        }
+       trace_kvm_mmu_unsync_page(sp);
        ++vcpu->kvm->stat.mmu_unsync;
        sp->unsync = 1;
 
-       kvm_mmu_mark_parents_unsync(vcpu, sp);
+       kvm_mmu_mark_parents_unsync(sp);
 
        mmu_convert_notrap(sp);
        return 0;
@@ -1699,9 +1769,9 @@ static int mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
 
 static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                    unsigned pte_access, int user_fault,
-                   int write_fault, int dirty, int largepage,
+                   int write_fault, int dirty, int level,
                    gfn_t gfn, pfn_t pfn, bool speculative,
-                   bool can_unsync)
+                   bool can_unsync, bool reset_host_protection)
 {
        u64 spte;
        int ret = 0;
@@ -1722,18 +1792,22 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                spte |= shadow_nx_mask;
        if (pte_access & ACC_USER_MASK)
                spte |= shadow_user_mask;
-       if (largepage)
+       if (level > PT_PAGE_TABLE_LEVEL)
                spte |= PT_PAGE_SIZE_MASK;
        if (tdp_enabled)
                spte |= kvm_x86_ops->get_mt_mask(vcpu, gfn,
                        kvm_is_mmio_pfn(pfn));
 
+       if (reset_host_protection)
+               spte |= SPTE_HOST_WRITEABLE;
+
        spte |= (u64)pfn << PAGE_SHIFT;
 
        if ((pte_access & ACC_WRITE_MASK)
            || (write_fault && !is_write_protection(vcpu) && !user_fault)) {
 
-               if (largepage && has_wrprotected_page(vcpu->kvm, gfn)) {
+               if (level > PT_PAGE_TABLE_LEVEL &&
+                   has_wrprotected_page(vcpu->kvm, gfn, level)) {
                        ret = 1;
                        spte = shadow_trap_nonpresent_pte;
                        goto set_pte;
@@ -1741,13 +1815,16 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 
                spte |= PT_WRITABLE_MASK;
 
+               if (!tdp_enabled && !(pte_access & ACC_WRITE_MASK))
+                       spte &= ~PT_USER_MASK;
+
                /*
                 * Optimization: for pte sync, if spte was writable the hash
                 * lookup is unnecessary (and expensive). Write protection
                 * is responsibility of mmu_get_page / kvm_sync_page.
                 * Same reasoning can be applied to dirty page accounting.
                 */
-               if (!can_unsync && is_writeble_pte(*sptep))
+               if (!can_unsync && is_writable_pte(*sptep))
                        goto set_pte;
 
                if (mmu_need_write_protect(vcpu, gfn, can_unsync)) {
@@ -1755,7 +1832,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                                 __func__, gfn);
                        ret = 1;
                        pte_access &= ~ACC_WRITE_MASK;
-                       if (is_writeble_pte(spte))
+                       if (is_writable_pte(spte))
                                spte &= ~PT_WRITABLE_MASK;
                }
        }
@@ -1771,11 +1848,12 @@ set_pte:
 static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                         unsigned pt_access, unsigned pte_access,
                         int user_fault, int write_fault, int dirty,
-                        int *ptwrite, int largepage, gfn_t gfn,
-                        pfn_t pfn, bool speculative)
+                        int *ptwrite, int level, gfn_t gfn,
+                        pfn_t pfn, bool speculative,
+                        bool reset_host_protection)
 {
        int was_rmapped = 0;
-       int was_writeble = is_writeble_pte(*sptep);
+       int was_writable = is_writable_pte(*sptep);
        int rmap_count;
 
        pgprintk("%s: spte %llx access %x write_fault %d"
@@ -1788,12 +1866,15 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                 * If we overwrite a PTE page pointer with a 2MB PMD, unlink
                 * the parent of the now unreachable PTE.
                 */
-               if (largepage && !is_large_pte(*sptep)) {
+               if (level > PT_PAGE_TABLE_LEVEL &&
+                   !is_large_pte(*sptep)) {
                        struct kvm_mmu_page *child;
                        u64 pte = *sptep;
 
                        child = page_header(pte & PT64_BASE_ADDR_MASK);
                        mmu_page_remove_parent_pte(child, sptep);
+                       __set_spte(sptep, shadow_trap_nonpresent_pte);
+                       kvm_flush_remote_tlbs(vcpu->kvm);
                } else if (pfn != spte_to_pfn(*sptep)) {
                        pgprintk("hfn old %lx new %lx\n",
                                 spte_to_pfn(*sptep), pfn);
@@ -1801,8 +1882,10 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                } else
                        was_rmapped = 1;
        }
+
        if (set_spte(vcpu, sptep, pte_access, user_fault, write_fault,
-                     dirty, largepage, gfn, pfn, speculative, true)) {
+                     dirty, level, gfn, pfn, speculative, true,
+                     reset_host_protection)) {
                if (write_fault)
                        *ptwrite = 1;
                kvm_x86_ops->tlb_flush(vcpu);
@@ -1819,12 +1902,11 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
        page_header_update_slot(vcpu->kvm, sptep, gfn);
        if (!was_rmapped) {
                rmap_count = rmap_add(vcpu, sptep, gfn);
-               if (!is_rmap_spte(*sptep))
-                       kvm_release_pfn_clean(pfn);
+               kvm_release_pfn_clean(pfn);
                if (rmap_count > RMAP_RECYCLE_THRESHOLD)
-                       rmap_recycle(vcpu, gfn, largepage);
+                       rmap_recycle(vcpu, sptep, gfn);
        } else {
-               if (was_writeble)
+               if (was_writable)
                        kvm_release_pfn_dirty(pfn);
                else
                        kvm_release_pfn_clean(pfn);
@@ -1840,7 +1922,7 @@ static void nonpaging_new_cr3(struct kvm_vcpu *vcpu)
 }
 
 static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
-                       int largepage, gfn_t gfn, pfn_t pfn)
+                       int level, gfn_t gfn, pfn_t pfn)
 {
        struct kvm_shadow_walk_iterator iterator;
        struct kvm_mmu_page *sp;
@@ -1848,11 +1930,10 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
        gfn_t pseudo_gfn;
 
        for_each_shadow_entry(vcpu, (u64)gfn << PAGE_SHIFT, iterator) {
-               if (iterator.level == PT_PAGE_TABLE_LEVEL
-                   || (largepage && iterator.level == PT_DIRECTORY_LEVEL)) {
+               if (iterator.level == level) {
                        mmu_set_spte(vcpu, iterator.sptep, ACC_ALL, ACC_ALL,
                                     0, write, 1, &pt_write,
-                                    largepage, gfn, pfn, false);
+                                    level, gfn, pfn, false, true);
                        ++vcpu->stat.pf_fixed;
                        break;
                }
@@ -1880,15 +1961,20 @@ static int __direct_map(struct kvm_vcpu *vcpu, gpa_t v, int write,
 static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
 {
        int r;
-       int largepage = 0;
+       int level;
        pfn_t pfn;
        unsigned long mmu_seq;
 
-       if (is_largepage_backed(vcpu, gfn &
-                       ~(KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL) - 1))) {
-               gfn &= ~(KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL) - 1);
-               largepage = 1;
-       }
+       level = mapping_level(vcpu, gfn);
+
+       /*
+        * This path builds a PAE pagetable - so we can map 2mb pages at
+        * maximum. Therefore check if the level is larger than that.
+        */
+       if (level > PT_DIRECTORY_LEVEL)
+               level = PT_DIRECTORY_LEVEL;
+
+       gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
 
        mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
@@ -1904,7 +1990,7 @@ static int nonpaging_map(struct kvm_vcpu *vcpu, gva_t v, int write, gfn_t gfn)
        if (mmu_notifier_retry(vcpu, mmu_seq))
                goto out_unlock;
        kvm_mmu_free_some_pages(vcpu);
-       r = __direct_map(vcpu, v, write, largepage, gfn, pfn);
+       r = __direct_map(vcpu, v, write, level, gfn, pfn);
        spin_unlock(&vcpu->kvm->mmu_lock);
 
 
@@ -1978,21 +2064,23 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                hpa_t root = vcpu->arch.mmu.root_hpa;
 
                ASSERT(!VALID_PAGE(root));
-               if (tdp_enabled)
-                       direct = 1;
                if (mmu_check_root(vcpu, root_gfn))
                        return 1;
+               if (tdp_enabled) {
+                       direct = 1;
+                       root_gfn = 0;
+               }
+               spin_lock(&vcpu->kvm->mmu_lock);
                sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
                                      PT64_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
+               spin_unlock(&vcpu->kvm->mmu_lock);
                vcpu->arch.mmu.root_hpa = root;
                return 0;
        }
        direct = !is_paging(vcpu);
-       if (tdp_enabled)
-               direct = 1;
        for (i = 0; i < 4; ++i) {
                hpa_t root = vcpu->arch.mmu.pae_root[i];
 
@@ -2008,11 +2096,18 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                        root_gfn = 0;
                if (mmu_check_root(vcpu, root_gfn))
                        return 1;
+               if (tdp_enabled) {
+                       direct = 1;
+                       root_gfn = i << 30;
+               }
+               spin_lock(&vcpu->kvm->mmu_lock);
                sp = kvm_mmu_get_page(vcpu, root_gfn, i << 30,
                                      PT32_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
+               spin_unlock(&vcpu->kvm->mmu_lock);
+
                vcpu->arch.mmu.pae_root[i] = root | PT_PRESENT_MASK;
        }
        vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
@@ -2050,8 +2145,11 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        spin_unlock(&vcpu->kvm->mmu_lock);
 }
 
-static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t vaddr)
+static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t vaddr,
+                                 u32 access, u32 *error)
 {
+       if (error)
+               *error = 0;
        return vaddr;
 }
 
@@ -2080,7 +2178,7 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
 {
        pfn_t pfn;
        int r;
-       int largepage = 0;
+       int level;
        gfn_t gfn = gpa >> PAGE_SHIFT;
        unsigned long mmu_seq;
 
@@ -2091,11 +2189,10 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
        if (r)
                return r;
 
-       if (is_largepage_backed(vcpu, gfn &
-                       ~(KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL) - 1))) {
-               gfn &= ~(KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL) - 1);
-               largepage = 1;
-       }
+       level = mapping_level(vcpu, gfn);
+
+       gfn &= ~(KVM_PAGES_PER_HPAGE(level) - 1);
+
        mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
        pfn = gfn_to_pfn(vcpu->kvm, gfn);
@@ -2108,7 +2205,7 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa,
                goto out_unlock;
        kvm_mmu_free_some_pages(vcpu);
        r = __direct_map(vcpu, gpa, error_code & PFERR_WRITE_MASK,
-                        largepage, gfn, pfn);
+                        level, gfn, pfn);
        spin_unlock(&vcpu->kvm->mmu_lock);
 
        return r;
@@ -2194,13 +2291,19 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                /* no rsvd bits for 2 level 4K page table entries */
                context->rsvd_bits_mask[0][1] = 0;
                context->rsvd_bits_mask[0][0] = 0;
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
+
+               if (!is_pse(vcpu)) {
+                       context->rsvd_bits_mask[1][1] = 0;
+                       break;
+               }
+
                if (is_cpuid_PSE36())
                        /* 36bits PSE 4MB page */
                        context->rsvd_bits_mask[1][1] = rsvd_bits(17, 21);
                else
                        /* 32 bits PSE 4MB page */
                        context->rsvd_bits_mask[1][1] = rsvd_bits(13, 21);
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
                break;
        case PT32E_ROOT_LEVEL:
                context->rsvd_bits_mask[0][2] =
@@ -2213,7 +2316,7 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                context->rsvd_bits_mask[1][1] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 62) |
                        rsvd_bits(13, 20);              /* large page */
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
                break;
        case PT64_ROOT_LEVEL:
                context->rsvd_bits_mask[0][3] = exb_bit_rsvd |
@@ -2225,11 +2328,13 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                context->rsvd_bits_mask[0][0] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 51);
                context->rsvd_bits_mask[1][3] = context->rsvd_bits_mask[0][3];
-               context->rsvd_bits_mask[1][2] = context->rsvd_bits_mask[0][2];
+               context->rsvd_bits_mask[1][2] = exb_bit_rsvd |
+                       rsvd_bits(maxphyaddr, 51) |
+                       rsvd_bits(13, 29);
                context->rsvd_bits_mask[1][1] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 51) |
                        rsvd_bits(13, 20);              /* large page */
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
                break;
        }
 }
@@ -2331,7 +2436,8 @@ static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
        else
                r = paging32_init_context(vcpu);
 
-       vcpu->arch.mmu.base_role.glevels = vcpu->arch.mmu.root_level;
+       vcpu->arch.mmu.base_role.cr4_pae = !!is_pae(vcpu);
+       vcpu->arch.mmu.base_role.cr0_wp = is_write_protection(vcpu);
 
        return r;
 }
@@ -2371,7 +2477,9 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
                goto out;
        spin_lock(&vcpu->kvm->mmu_lock);
        kvm_mmu_free_some_pages(vcpu);
+       spin_unlock(&vcpu->kvm->mmu_lock);
        r = mmu_alloc_roots(vcpu);
+       spin_lock(&vcpu->kvm->mmu_lock);
        mmu_sync_roots(vcpu);
        spin_unlock(&vcpu->kvm->mmu_lock);
        if (r)
@@ -2415,15 +2523,12 @@ static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
                                  const void *new)
 {
        if (sp->role.level != PT_PAGE_TABLE_LEVEL) {
-               if (!vcpu->arch.update_pte.largepage ||
-                   sp->role.glevels == PT32_ROOT_LEVEL) {
-                       ++vcpu->kvm->stat.mmu_pde_zapped;
-                       return;
-               }
+               ++vcpu->kvm->stat.mmu_pde_zapped;
+               return;
         }
 
        ++vcpu->kvm->stat.mmu_pte_updated;
-       if (sp->role.glevels == PT32_ROOT_LEVEL)
+       if (!sp->role.cr4_pae)
                paging32_update_pte(vcpu, sp, spte, new);
        else
                paging64_update_pte(vcpu, sp, spte, new);
@@ -2458,46 +2563,15 @@ static bool last_updated_pte_accessed(struct kvm_vcpu *vcpu)
 }
 
 static void mmu_guess_page_from_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-                                         const u8 *new, int bytes)
+                                         u64 gpte)
 {
        gfn_t gfn;
-       int r;
-       u64 gpte = 0;
        pfn_t pfn;
 
-       vcpu->arch.update_pte.largepage = 0;
-
-       if (bytes != 4 && bytes != 8)
-               return;
-
-       /*
-        * Assume that the pte write on a page table of the same type
-        * as the current vcpu paging mode.  This is nearly always true
-        * (might be false while changing modes).  Note it is verified later
-        * by update_pte().
-        */
-       if (is_pae(vcpu)) {
-               /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
-               if ((bytes == 4) && (gpa % 4 == 0)) {
-                       r = kvm_read_guest(vcpu->kvm, gpa & ~(u64)7, &gpte, 8);
-                       if (r)
-                               return;
-                       memcpy((void *)&gpte + (gpa % 8), new, 4);
-               } else if ((bytes == 8) && (gpa % 8 == 0)) {
-                       memcpy((void *)&gpte, new, 8);
-               }
-       } else {
-               if ((bytes == 4) && (gpa % 4 == 0))
-                       memcpy((void *)&gpte, new, 4);
-       }
        if (!is_present_gpte(gpte))
                return;
        gfn = (gpte & PT64_BASE_ADDR_MASK) >> PAGE_SHIFT;
 
-       if (is_large_pte(gpte) && is_largepage_backed(vcpu, gfn)) {
-               gfn &= ~(KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL) - 1);
-               vcpu->arch.update_pte.largepage = 1;
-       }
        vcpu->arch.update_pte.mmu_seq = vcpu->kvm->mmu_notifier_seq;
        smp_rmb();
        pfn = gfn_to_pfn(vcpu->kvm, gfn);
@@ -2542,10 +2616,46 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        int flooded = 0;
        int npte;
        int r;
+       int invlpg_counter;
 
        pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
-       mmu_guess_page_from_pte_write(vcpu, gpa, new, bytes);
+
+       invlpg_counter = atomic_read(&vcpu->kvm->arch.invlpg_counter);
+
+       /*
+        * Assume that the pte write on a page table of the same type
+        * as the current vcpu paging mode.  This is nearly always true
+        * (might be false while changing modes).  Note it is verified later
+        * by update_pte().
+        */
+       if ((is_pae(vcpu) && bytes == 4) || !new) {
+               /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
+               if (is_pae(vcpu)) {
+                       gpa &= ~(gpa_t)7;
+                       bytes = 8;
+               }
+               r = kvm_read_guest(vcpu->kvm, gpa, &gentry, min(bytes, 8));
+               if (r)
+                       gentry = 0;
+               new = (const u8 *)&gentry;
+       }
+
+       switch (bytes) {
+       case 4:
+               gentry = *(const u32 *)new;
+               break;
+       case 8:
+               gentry = *(const u64 *)new;
+               break;
+       default:
+               gentry = 0;
+               break;
+       }
+
+       mmu_guess_page_from_pte_write(vcpu, gpa, gentry);
        spin_lock(&vcpu->kvm->mmu_lock);
+       if (atomic_read(&vcpu->kvm->arch.invlpg_counter) != invlpg_counter)
+               gentry = 0;
        kvm_mmu_access_page(vcpu, gfn);
        kvm_mmu_free_some_pages(vcpu);
        ++vcpu->kvm->stat.mmu_pte_write;
@@ -2564,10 +2674,12 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        }
        index = kvm_page_table_hashfn(gfn);
        bucket = &vcpu->kvm->arch.mmu_page_hash[index];
+
+restart:
        hlist_for_each_entry_safe(sp, node, n, bucket, hash_link) {
                if (sp->gfn != gfn || sp->role.direct || sp->role.invalid)
                        continue;
-               pte_size = sp->role.glevels == PT32_ROOT_LEVEL ? 4 : 8;
+               pte_size = sp->role.cr4_pae ? 8 : 4;
                misaligned = (offset ^ (offset + bytes - 1)) & ~(pte_size - 1);
                misaligned |= bytes < 4;
                if (misaligned || flooded) {
@@ -2584,14 +2696,14 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                        pgprintk("misaligned: gpa %llx bytes %d role %x\n",
                                 gpa, bytes, sp->role.word);
                        if (kvm_mmu_zap_page(vcpu->kvm, sp))
-                               n = bucket->first;
+                               goto restart;
                        ++vcpu->kvm->stat.mmu_flooded;
                        continue;
                }
                page_offset = offset;
                level = sp->role.level;
                npte = 1;
-               if (sp->role.glevels == PT32_ROOT_LEVEL) {
+               if (!sp->role.cr4_pae) {
                        page_offset <<= 1;      /* 32->64 */
                        /*
                         * A 32-bit pde maps 4MB while the shadow pdes map
@@ -2609,20 +2721,11 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                                continue;
                }
                spte = &sp->spt[page_offset / sizeof(*spte)];
-               if ((gpa & (pte_size - 1)) || (bytes < pte_size)) {
-                       gentry = 0;
-                       r = kvm_read_guest_atomic(vcpu->kvm,
-                                                 gpa & ~(u64)(pte_size - 1),
-                                                 &gentry, pte_size);
-                       new = (const void *)&gentry;
-                       if (r < 0)
-                               new = NULL;
-               }
                while (npte--) {
                        entry = *spte;
                        mmu_pte_write_zap_pte(vcpu, sp, spte);
-                       if (new)
-                               mmu_pte_write_new_pte(vcpu, sp, spte, new);
+                       if (gentry)
+                               mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
                        mmu_pte_write_flush_tlb(vcpu, entry, *spte);
                        ++spte;
                }
@@ -2640,7 +2743,10 @@ int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva)
        gpa_t gpa;
        int r;
 
-       gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, gva);
+       if (tdp_enabled)
+               return 0;
+
+       gpa = kvm_mmu_gva_to_gpa_read(vcpu, gva, NULL);
 
        spin_lock(&vcpu->kvm->mmu_lock);
        r = kvm_mmu_unprotect_page(vcpu->kvm, gpa >> PAGE_SHIFT);
@@ -2651,7 +2757,8 @@ EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page_virt);
 
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu)
 {
-       while (vcpu->kvm->arch.n_free_mmu_pages < KVM_REFILL_PAGES) {
+       while (vcpu->kvm->arch.n_free_mmu_pages < KVM_REFILL_PAGES &&
+              !list_empty(&vcpu->kvm->arch.active_mmu_pages)) {
                struct kvm_mmu_page *sp;
 
                sp = container_of(vcpu->kvm->arch.active_mmu_pages.prev,
@@ -2679,7 +2786,7 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
        if (r)
                goto out;
 
-       er = emulate_instruction(vcpu, vcpu->run, cr2, error_code, 0);
+       er = emulate_instruction(vcpu, cr2, error_code, 0);
 
        switch (er) {
        case EMULATE_DONE:
@@ -2690,6 +2797,7 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code)
        case EMULATE_FAIL:
                vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
                vcpu->run->internal.suberror = KVM_INTERNAL_ERROR_EMULATION;
+               vcpu->run->internal.ndata = 0;
                return 0;
        default:
                BUG();
@@ -2731,14 +2839,6 @@ static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
 
        ASSERT(vcpu);
 
-       spin_lock(&vcpu->kvm->mmu_lock);
-       if (vcpu->kvm->arch.n_requested_mmu_pages)
-               vcpu->kvm->arch.n_free_mmu_pages =
-                                       vcpu->kvm->arch.n_requested_mmu_pages;
-       else
-               vcpu->kvm->arch.n_free_mmu_pages =
-                                       vcpu->kvm->arch.n_alloc_mmu_pages;
-       spin_unlock(&vcpu->kvm->mmu_lock);
        /*
         * When emulating 32-bit mode, cr3 is only 32 bits even on x86_64.
         * Therefore we need to allocate shadow page tables in the first
@@ -2746,16 +2846,13 @@ static int alloc_mmu_pages(struct kvm_vcpu *vcpu)
         */
        page = alloc_page(GFP_KERNEL | __GFP_DMA32);
        if (!page)
-               goto error_1;
+               return -ENOMEM;
+
        vcpu->arch.mmu.pae_root = page_address(page);
        for (i = 0; i < 4; ++i)
                vcpu->arch.mmu.pae_root[i] = INVALID_PAGE;
 
        return 0;
-
-error_1:
-       free_mmu_pages(vcpu);
-       return -ENOMEM;
 }
 
 int kvm_mmu_create(struct kvm_vcpu *vcpu)
@@ -2808,22 +2905,23 @@ void kvm_mmu_zap_all(struct kvm *kvm)
        struct kvm_mmu_page *sp, *node;
 
        spin_lock(&kvm->mmu_lock);
+restart:
        list_for_each_entry_safe(sp, node, &kvm->arch.active_mmu_pages, link)
                if (kvm_mmu_zap_page(kvm, sp))
-                       node = container_of(kvm->arch.active_mmu_pages.next,
-                                           struct kvm_mmu_page, link);
+                       goto restart;
+
        spin_unlock(&kvm->mmu_lock);
 
        kvm_flush_remote_tlbs(kvm);
 }
 
-static void kvm_mmu_remove_one_alloc_mmu_page(struct kvm *kvm)
+static int kvm_mmu_remove_some_alloc_mmu_pages(struct kvm *kvm)
 {
        struct kvm_mmu_page *page;
 
        page = container_of(kvm->arch.active_mmu_pages.prev,
                            struct kvm_mmu_page, link);
-       kvm_mmu_zap_page(kvm, page);
+       return kvm_mmu_zap_page(kvm, page) + 1;
 }
 
 static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
@@ -2835,23 +2933,22 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
        spin_lock(&kvm_lock);
 
        list_for_each_entry(kvm, &vm_list, vm_list) {
-               int npages;
+               int npages, idx, freed_pages;
 
-               if (!down_read_trylock(&kvm->slots_lock))
-                       continue;
+               idx = srcu_read_lock(&kvm->srcu);
                spin_lock(&kvm->mmu_lock);
                npages = kvm->arch.n_alloc_mmu_pages -
                         kvm->arch.n_free_mmu_pages;
                cache_count += npages;
                if (!kvm_freed && nr_to_scan > 0 && npages > 0) {
-                       kvm_mmu_remove_one_alloc_mmu_page(kvm);
-                       cache_count--;
+                       freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm);
+                       cache_count -= freed_pages;
                        kvm_freed = kvm;
                }
                nr_to_scan--;
 
                spin_unlock(&kvm->mmu_lock);
-               up_read(&kvm->slots_lock);
+               srcu_read_unlock(&kvm->srcu, idx);
        }
        if (kvm_freed)
                list_move_tail(&kvm_freed->vm_list, &vm_list);
@@ -2918,9 +3015,12 @@ unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm)
        int i;
        unsigned int nr_mmu_pages;
        unsigned int  nr_pages = 0;
+       struct kvm_memslots *slots;
+
+       slots = kvm_memslots(kvm);
 
-       for (i = 0; i < kvm->nmemslots; i++)
-               nr_pages += kvm->memslots[i].npages;
+       for (i = 0; i < slots->nmemslots; i++)
+               nr_pages += slots->memslots[i].npages;
 
        nr_mmu_pages = nr_pages * KVM_PERMILLE_MMU_PAGES / 1000;
        nr_mmu_pages = max(nr_mmu_pages,
@@ -3081,8 +3181,7 @@ static gva_t canonicalize(gva_t gva)
 }
 
 
-typedef void (*inspect_spte_fn) (struct kvm *kvm, struct kvm_mmu_page *sp,
-                                u64 *sptep);
+typedef void (*inspect_spte_fn) (struct kvm *kvm, u64 *sptep);
 
 static void __mmu_spte_walk(struct kvm *kvm, struct kvm_mmu_page *sp,
                            inspect_spte_fn fn)
@@ -3098,7 +3197,7 @@ static void __mmu_spte_walk(struct kvm *kvm, struct kvm_mmu_page *sp,
                                child = page_header(ent & PT64_BASE_ADDR_MASK);
                                __mmu_spte_walk(kvm, child, fn);
                        } else
-                               fn(kvm, sp, &sp->spt[i]);
+                               fn(kvm, &sp->spt[i]);
                }
        }
 }
@@ -3145,7 +3244,7 @@ static void audit_mappings_page(struct kvm_vcpu *vcpu, u64 page_pte,
                if (is_shadow_present_pte(ent) && !is_last_spte(ent, level))
                        audit_mappings_page(vcpu, ent, va, level - 1);
                else {
-                       gpa_t gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, va);
+                       gpa_t gpa = kvm_mmu_gva_to_gpa_read(vcpu, va, NULL);
                        gfn_t gfn = gpa >> PAGE_SHIFT;
                        pfn_t pfn = gfn_to_pfn(vcpu->kvm, gfn);
                        hpa_t hpa = (hpa_t)pfn << PAGE_SHIFT;
@@ -3189,11 +3288,15 @@ static void audit_mappings(struct kvm_vcpu *vcpu)
 
 static int count_rmaps(struct kvm_vcpu *vcpu)
 {
+       struct kvm *kvm = vcpu->kvm;
+       struct kvm_memslots *slots;
        int nmaps = 0;
-       int i, j, k;
+       int i, j, k, idx;
 
+       idx = srcu_read_lock(&kvm->srcu);
+       slots = kvm_memslots(kvm);
        for (i = 0; i < KVM_MEMORY_SLOTS; ++i) {
-               struct kvm_memory_slot *m = &vcpu->kvm->memslots[i];
+               struct kvm_memory_slot *m = &slots->memslots[i];
                struct kvm_rmap_desc *d;
 
                for (j = 0; j < m->npages; ++j) {
@@ -3216,10 +3319,11 @@ static int count_rmaps(struct kvm_vcpu *vcpu)
                        }
                }
        }
+       srcu_read_unlock(&kvm->srcu, idx);
        return nmaps;
 }
 
-void inspect_spte_has_rmap(struct kvm *kvm, struct kvm_mmu_page *sp, u64 *sptep)
+void inspect_spte_has_rmap(struct kvm *kvm, u64 *sptep)
 {
        unsigned long *rmapp;
        struct kvm_mmu_page *rev_sp;
@@ -3235,14 +3339,14 @@ void inspect_spte_has_rmap(struct kvm *kvm, struct kvm_mmu_page *sp, u64 *sptep)
                        printk(KERN_ERR "%s: no memslot for gfn %ld\n",
                                         audit_msg, gfn);
                        printk(KERN_ERR "%s: index %ld of sp (gfn=%lx)\n",
-                                       audit_msg, sptep - rev_sp->spt,
+                              audit_msg, (long int)(sptep - rev_sp->spt),
                                        rev_sp->gfn);
                        dump_stack();
                        return;
                }
 
                rmapp = gfn_to_rmap(kvm, rev_sp->gfns[sptep - rev_sp->spt],
-                                   is_large_pte(*sptep));
+                                   rev_sp->role.level);
                if (!*rmapp) {
                        if (!printk_ratelimit())
                                return;
@@ -3277,7 +3381,7 @@ static void check_writable_mappings_rmap(struct kvm_vcpu *vcpu)
                                continue;
                        if (!(ent & PT_WRITABLE_MASK))
                                continue;
-                       inspect_spte_has_rmap(vcpu->kvm, sp, &pt[i]);
+                       inspect_spte_has_rmap(vcpu->kvm, &pt[i]);
                }
        }
        return;