KVM: MMU: invalidate and flush on spte small->large page size change
[safe/jmp/linux-2.6] / arch / x86 / kvm / mmu.c
index dc4d954..6fbcb48 100644 (file)
@@ -31,6 +31,7 @@
 #include <linux/hugetlb.h>
 #include <linux/compiler.h>
 #include <linux/srcu.h>
+#include <linux/slab.h>
 
 #include <asm/page.h>
 #include <asm/cmpxchg.h>
@@ -138,12 +139,6 @@ module_param(oos_shadow, bool, 0644);
 #define PT64_PERM_MASK (PT_PRESENT_MASK | PT_WRITABLE_MASK | PT_USER_MASK \
                        | PT64_NX_MASK)
 
-#define PFERR_PRESENT_MASK (1U << 0)
-#define PFERR_WRITE_MASK (1U << 1)
-#define PFERR_USER_MASK (1U << 2)
-#define PFERR_RSVD_MASK (1U << 3)
-#define PFERR_FETCH_MASK (1U << 4)
-
 #define RMAP_EXT 4
 
 #define ACC_EXEC_MASK    1
@@ -151,6 +146,8 @@ module_param(oos_shadow, bool, 0644);
 #define ACC_USER_MASK    PT_USER_MASK
 #define ACC_ALL          (ACC_EXEC_MASK | ACC_WRITE_MASK | ACC_USER_MASK)
 
+#include <trace/events/kvm.h>
+
 #define CREATE_TRACE_POINTS
 #include "mmutrace.h"
 
@@ -176,12 +173,7 @@ struct kvm_shadow_walk_iterator {
             shadow_walk_okay(&(_walker));                      \
             shadow_walk_next(&(_walker)))
 
-
-struct kvm_unsync_walk {
-       int (*entry) (struct kvm_mmu_page *sp, struct kvm_unsync_walk *walk);
-};
-
-typedef int (*mmu_parent_walk_fn) (struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp);
+typedef int (*mmu_parent_walk_fn) (struct kvm_mmu_page *sp);
 
 static struct kmem_cache *pte_chain_cache;
 static struct kmem_cache *rmap_desc_cache;
@@ -225,7 +217,7 @@ void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mask_ptes);
 
-static int is_write_protection(struct kvm_vcpu *vcpu)
+static bool is_write_protection(struct kvm_vcpu *vcpu)
 {
        return kvm_read_cr0_bits(vcpu, X86_CR0_WP);
 }
@@ -329,7 +321,6 @@ static int mmu_topup_memory_cache_page(struct kvm_mmu_memory_cache *cache,
                page = alloc_page(GFP_KERNEL);
                if (!page)
                        return -ENOMEM;
-               set_page_private(page, 0);
                cache->objects[cache->nobjs++] = page_address(page);
        }
        return 0;
@@ -440,9 +431,9 @@ static void unaccount_shadowed(struct kvm *kvm, gfn_t gfn)
        int i;
 
        gfn = unalias_gfn(kvm, gfn);
+       slot = gfn_to_memslot_unaliased(kvm, gfn);
        for (i = PT_DIRECTORY_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               slot          = gfn_to_memslot_unaliased(kvm, gfn);
                write_count   = slot_largepage_idx(gfn, slot, i);
                *write_count -= 1;
                WARN_ON(*write_count < 0);
@@ -468,24 +459,10 @@ static int has_wrprotected_page(struct kvm *kvm,
 
 static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
 {
-       unsigned long page_size = PAGE_SIZE;
-       struct vm_area_struct *vma;
-       unsigned long addr;
+       unsigned long page_size;
        int i, ret = 0;
 
-       addr = gfn_to_hva(kvm, gfn);
-       if (kvm_is_error_hva(addr))
-               return PT_PAGE_TABLE_LEVEL;
-
-       down_read(&current->mm->mmap_sem);
-       vma = find_vma(current->mm, addr);
-       if (!vma)
-               goto out;
-
-       page_size = vma_kernel_pagesize(vma);
-
-out:
-       up_read(&current->mm->mmap_sem);
+       page_size = kvm_host_page_size(kvm, gfn);
 
        for (i = PT_PAGE_TABLE_LEVEL;
             i < (PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES); ++i) {
@@ -670,7 +647,6 @@ static void rmap_remove(struct kvm *kvm, u64 *spte)
 static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
 {
        struct kvm_rmap_desc *desc;
-       struct kvm_rmap_desc *prev_desc;
        u64 *prev_spte;
        int i;
 
@@ -682,7 +658,6 @@ static u64 *rmap_next(struct kvm *kvm, unsigned long *rmapp, u64 *spte)
                return NULL;
        }
        desc = (struct kvm_rmap_desc *)(*rmapp & ~1ul);
-       prev_desc = NULL;
        prev_spte = NULL;
        while (desc) {
                for (i = 0; i < RMAP_EXT && desc->sptes[i]; ++i) {
@@ -806,10 +781,11 @@ static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
                                         unsigned long data))
 {
        int i, j;
+       int ret;
        int retval = 0;
        struct kvm_memslots *slots;
 
-       slots = rcu_dereference(kvm->memslots);
+       slots = kvm_memslots(kvm);
 
        for (i = 0; i < slots->nmemslots; i++) {
                struct kvm_memory_slot *memslot = &slots->memslots[i];
@@ -820,16 +796,17 @@ static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
                if (hva >= start && hva < end) {
                        gfn_t gfn_offset = (hva - start) >> PAGE_SHIFT;
 
-                       retval |= handler(kvm, &memslot->rmap[gfn_offset],
-                                         data);
+                       ret = handler(kvm, &memslot->rmap[gfn_offset], data);
 
                        for (j = 0; j < KVM_NR_PAGE_SIZES - 1; ++j) {
                                int idx = gfn_offset;
                                idx /= KVM_PAGES_PER_HPAGE(PT_DIRECTORY_LEVEL + j);
-                               retval |= handler(kvm,
+                               ret |= handler(kvm,
                                        &memslot->lpage_info[j][idx].rmap_pde,
                                        data);
                        }
+                       trace_kvm_age_page(hva, memslot, ret);
+                       retval |= ret;
                }
        }
 
@@ -852,9 +829,15 @@ static int kvm_age_rmapp(struct kvm *kvm, unsigned long *rmapp,
        u64 *spte;
        int young = 0;
 
-       /* always return old for EPT */
+       /*
+        * Emulate the accessed bit for EPT, by checking if this page has
+        * an EPT mapping, and clearing it if it does. On the next access,
+        * a new EPT mapping will be established.
+        * This has some overhead, but not as much as the cost of swapping
+        * out actively used pages or breaking up actively used hugepages.
+        */
        if (!shadow_accessed_mask)
-               return 0;
+               return kvm_unmap_rmapp(kvm, rmapp, data);
 
        spte = rmap_next(kvm, rmapp, NULL);
        while (spte) {
@@ -933,7 +916,6 @@ static struct kvm_mmu_page *kvm_mmu_alloc_page(struct kvm_vcpu *vcpu,
        sp->gfns = mmu_memory_cache_alloc(&vcpu->arch.mmu_page_cache, PAGE_SIZE);
        set_page_private(virt_to_page(sp->spt), (unsigned long)sp);
        list_add(&sp->link, &vcpu->kvm->arch.active_mmu_pages);
-       INIT_LIST_HEAD(&sp->oos_link);
        bitmap_zero(sp->slot_bitmap, KVM_MEMORY_SLOTS + KVM_PRIVATE_MEM_SLOTS);
        sp->multimapped = 0;
        sp->parent_pte = parent_pte;
@@ -1017,8 +999,7 @@ static void mmu_page_remove_parent_pte(struct kvm_mmu_page *sp,
 }
 
 
-static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                           mmu_parent_walk_fn fn)
+static void mmu_parent_walk(struct kvm_mmu_page *sp, mmu_parent_walk_fn fn)
 {
        struct kvm_pte_chain *pte_chain;
        struct hlist_node *node;
@@ -1027,8 +1008,8 @@ static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
 
        if (!sp->multimapped && sp->parent_pte) {
                parent_sp = page_header(__pa(sp->parent_pte));
-               fn(vcpu, parent_sp);
-               mmu_parent_walk(vcpu, parent_sp, fn);
+               fn(parent_sp);
+               mmu_parent_walk(parent_sp, fn);
                return;
        }
        hlist_for_each_entry(pte_chain, node, &sp->parent_ptes, link)
@@ -1036,8 +1017,8 @@ static void mmu_parent_walk(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                        if (!pte_chain->parent_ptes[i])
                                break;
                        parent_sp = page_header(__pa(pte_chain->parent_ptes[i]));
-                       fn(vcpu, parent_sp);
-                       mmu_parent_walk(vcpu, parent_sp, fn);
+                       fn(parent_sp);
+                       mmu_parent_walk(parent_sp, fn);
                }
 }
 
@@ -1074,16 +1055,15 @@ static void kvm_mmu_update_parents_unsync(struct kvm_mmu_page *sp)
                }
 }
 
-static int unsync_walk_fn(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+static int unsync_walk_fn(struct kvm_mmu_page *sp)
 {
        kvm_mmu_update_parents_unsync(sp);
        return 1;
 }
 
-static void kvm_mmu_mark_parents_unsync(struct kvm_vcpu *vcpu,
-                                       struct kvm_mmu_page *sp)
+static void kvm_mmu_mark_parents_unsync(struct kvm_mmu_page *sp)
 {
-       mmu_parent_walk(vcpu, sp, unsync_walk_fn);
+       mmu_parent_walk(sp, unsync_walk_fn);
        kvm_mmu_update_parents_unsync(sp);
 }
 
@@ -1209,6 +1189,7 @@ static struct kvm_mmu_page *kvm_mmu_lookup_page(struct kvm *kvm, gfn_t gfn)
 static void kvm_unlink_unsync_page(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        WARN_ON(!sp->unsync);
+       trace_kvm_mmu_sync_page(sp);
        sp->unsync = 0;
        --kvm->stat.mmu_unsync;
 }
@@ -1217,12 +1198,11 @@ static int kvm_mmu_zap_page(struct kvm *kvm, struct kvm_mmu_page *sp);
 
 static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
-       if (sp->role.glevels != vcpu->arch.mmu.root_level) {
+       if (sp->role.cr4_pae != !!is_pae(vcpu)) {
                kvm_mmu_zap_page(vcpu->kvm, sp);
                return 1;
        }
 
-       trace_kvm_mmu_sync_page(sp);
        if (rmap_write_protect(vcpu->kvm, sp->gfn))
                kvm_flush_remote_tlbs(vcpu->kvm);
        kvm_unlink_unsync_page(vcpu->kvm, sp);
@@ -1339,6 +1319,8 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        role = vcpu->arch.mmu.base_role;
        role.level = level;
        role.direct = direct;
+       if (role.direct)
+               role.cr4_pae = 0;
        role.access = access;
        if (vcpu->arch.mmu.root_level <= PT32_ROOT_LEVEL) {
                quadrant = gaddr >> (PAGE_SHIFT + (PT64_PT_BITS * level));
@@ -1359,7 +1341,7 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                        mmu_page_add_parent_pte(vcpu, sp, parent_pte);
                        if (sp->unsync_children) {
                                set_bit(KVM_REQ_MMU_SYNC, &vcpu->requests);
-                               kvm_mmu_mark_parents_unsync(vcpu, sp);
+                               kvm_mmu_mark_parents_unsync(sp);
                        }
                        trace_kvm_mmu_get_page(sp, false);
                        return sp;
@@ -1498,8 +1480,8 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
                for_each_sp(pages, sp, parents, i) {
                        kvm_mmu_zap_page(kvm, sp);
                        mmu_pages_clear_parents(&parents);
+                       zapped++;
                }
-               zapped += pages.nr;
                kvm_mmu_pages_init(parent, &parents, &pages);
        }
 
@@ -1550,14 +1532,16 @@ void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages)
         */
 
        if (used_pages > kvm_nr_mmu_pages) {
-               while (used_pages > kvm_nr_mmu_pages) {
+               while (used_pages > kvm_nr_mmu_pages &&
+                       !list_empty(&kvm->arch.active_mmu_pages)) {
                        struct kvm_mmu_page *page;
 
                        page = container_of(kvm->arch.active_mmu_pages.prev,
                                            struct kvm_mmu_page, link);
-                       kvm_mmu_zap_page(kvm, page);
+                       used_pages -= kvm_mmu_zap_page(kvm, page);
                        used_pages--;
                }
+               kvm_nr_mmu_pages = used_pages;
                kvm->arch.n_free_mmu_pages = 0;
        }
        else
@@ -1579,13 +1563,14 @@ static int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
        r = 0;
        index = kvm_page_table_hashfn(gfn);
        bucket = &kvm->arch.mmu_page_hash[index];
+restart:
        hlist_for_each_entry_safe(sp, node, n, bucket, hash_link)
                if (sp->gfn == gfn && !sp->role.direct) {
                        pgprintk("%s: gfn %lx role %x\n", __func__, gfn,
                                 sp->role.word);
                        r = 1;
                        if (kvm_mmu_zap_page(kvm, sp))
-                               n = bucket->first;
+                               goto restart;
                }
        return r;
 }
@@ -1599,12 +1584,14 @@ static void mmu_unshadow(struct kvm *kvm, gfn_t gfn)
 
        index = kvm_page_table_hashfn(gfn);
        bucket = &kvm->arch.mmu_page_hash[index];
+restart:
        hlist_for_each_entry_safe(sp, node, nn, bucket, hash_link) {
                if (sp->gfn == gfn && !sp->role.direct
                    && !sp->role.invalid) {
                        pgprintk("%s: zap %lx %x\n",
                                 __func__, gfn, sp->role.word);
-                       kvm_mmu_zap_page(kvm, sp);
+                       if (kvm_mmu_zap_page(kvm, sp))
+                               goto restart;
                }
        }
 }
@@ -1631,20 +1618,6 @@ static void mmu_convert_notrap(struct kvm_mmu_page *sp)
        }
 }
 
-struct page *gva_to_page(struct kvm_vcpu *vcpu, gva_t gva)
-{
-       struct page *page;
-
-       gpa_t gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, gva);
-
-       if (gpa == UNMAPPED_GVA)
-               return NULL;
-
-       page = gfn_to_page(vcpu->kvm, gpa >> PAGE_SHIFT);
-
-       return page;
-}
-
 /*
  * The function is based on mtrr_type_lookup() in
  * arch/x86/kernel/cpu/mtrr/generic.c
@@ -1757,7 +1730,6 @@ static int kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        struct kvm_mmu_page *s;
        struct hlist_node *node, *n;
 
-       trace_kvm_mmu_unsync_page(sp);
        index = kvm_page_table_hashfn(sp->gfn);
        bucket = &vcpu->kvm->arch.mmu_page_hash[index];
        /* don't unsync if pagetable is shadowed with multiple roles */
@@ -1767,10 +1739,11 @@ static int kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
                if (s->role.word != sp->role.word)
                        return 1;
        }
+       trace_kvm_mmu_unsync_page(sp);
        ++vcpu->kvm->stat.mmu_unsync;
        sp->unsync = 1;
 
-       kvm_mmu_mark_parents_unsync(vcpu, sp);
+       kvm_mmu_mark_parents_unsync(sp);
 
        mmu_convert_notrap(sp);
        return 0;
@@ -1897,6 +1870,8 @@ static void mmu_set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 
                        child = page_header(pte & PT64_BASE_ADDR_MASK);
                        mmu_page_remove_parent_pte(child, sptep);
+                       __set_spte(sptep, shadow_trap_nonpresent_pte);
+                       kvm_flush_remote_tlbs(vcpu->kvm);
                } else if (pfn != spte_to_pfn(*sptep)) {
                        pgprintk("hfn old %lx new %lx\n",
                                 spte_to_pfn(*sptep), pfn);
@@ -2086,21 +2061,23 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                hpa_t root = vcpu->arch.mmu.root_hpa;
 
                ASSERT(!VALID_PAGE(root));
-               if (tdp_enabled)
-                       direct = 1;
                if (mmu_check_root(vcpu, root_gfn))
                        return 1;
+               if (tdp_enabled) {
+                       direct = 1;
+                       root_gfn = 0;
+               }
+               spin_lock(&vcpu->kvm->mmu_lock);
                sp = kvm_mmu_get_page(vcpu, root_gfn, 0,
                                      PT64_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
+               spin_unlock(&vcpu->kvm->mmu_lock);
                vcpu->arch.mmu.root_hpa = root;
                return 0;
        }
        direct = !is_paging(vcpu);
-       if (tdp_enabled)
-               direct = 1;
        for (i = 0; i < 4; ++i) {
                hpa_t root = vcpu->arch.mmu.pae_root[i];
 
@@ -2116,11 +2093,18 @@ static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
                        root_gfn = 0;
                if (mmu_check_root(vcpu, root_gfn))
                        return 1;
+               if (tdp_enabled) {
+                       direct = 1;
+                       root_gfn = i << 30;
+               }
+               spin_lock(&vcpu->kvm->mmu_lock);
                sp = kvm_mmu_get_page(vcpu, root_gfn, i << 30,
                                      PT32_ROOT_LEVEL, direct,
                                      ACC_ALL, NULL);
                root = __pa(sp->spt);
                ++sp->root_count;
+               spin_unlock(&vcpu->kvm->mmu_lock);
+
                vcpu->arch.mmu.pae_root[i] = root | PT_PRESENT_MASK;
        }
        vcpu->arch.mmu.root_hpa = __pa(vcpu->arch.mmu.pae_root);
@@ -2158,8 +2142,11 @@ void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
        spin_unlock(&vcpu->kvm->mmu_lock);
 }
 
-static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t vaddr)
+static gpa_t nonpaging_gva_to_gpa(struct kvm_vcpu *vcpu, gva_t vaddr,
+                                 u32 access, u32 *error)
 {
+       if (error)
+               *error = 0;
        return vaddr;
 }
 
@@ -2301,13 +2288,19 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                /* no rsvd bits for 2 level 4K page table entries */
                context->rsvd_bits_mask[0][1] = 0;
                context->rsvd_bits_mask[0][0] = 0;
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
+
+               if (!is_pse(vcpu)) {
+                       context->rsvd_bits_mask[1][1] = 0;
+                       break;
+               }
+
                if (is_cpuid_PSE36())
                        /* 36bits PSE 4MB page */
                        context->rsvd_bits_mask[1][1] = rsvd_bits(17, 21);
                else
                        /* 32 bits PSE 4MB page */
                        context->rsvd_bits_mask[1][1] = rsvd_bits(13, 21);
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
                break;
        case PT32E_ROOT_LEVEL:
                context->rsvd_bits_mask[0][2] =
@@ -2320,7 +2313,7 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                context->rsvd_bits_mask[1][1] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 62) |
                        rsvd_bits(13, 20);              /* large page */
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
                break;
        case PT64_ROOT_LEVEL:
                context->rsvd_bits_mask[0][3] = exb_bit_rsvd |
@@ -2338,7 +2331,7 @@ static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
                context->rsvd_bits_mask[1][1] = exb_bit_rsvd |
                        rsvd_bits(maxphyaddr, 51) |
                        rsvd_bits(13, 20);              /* large page */
-               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[1][0];
+               context->rsvd_bits_mask[1][0] = context->rsvd_bits_mask[0][0];
                break;
        }
 }
@@ -2440,7 +2433,8 @@ static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
        else
                r = paging32_init_context(vcpu);
 
-       vcpu->arch.mmu.base_role.glevels = vcpu->arch.mmu.root_level;
+       vcpu->arch.mmu.base_role.cr4_pae = !!is_pae(vcpu);
+       vcpu->arch.mmu.base_role.cr0_wp = is_write_protection(vcpu);
 
        return r;
 }
@@ -2480,7 +2474,9 @@ int kvm_mmu_load(struct kvm_vcpu *vcpu)
                goto out;
        spin_lock(&vcpu->kvm->mmu_lock);
        kvm_mmu_free_some_pages(vcpu);
+       spin_unlock(&vcpu->kvm->mmu_lock);
        r = mmu_alloc_roots(vcpu);
+       spin_lock(&vcpu->kvm->mmu_lock);
        mmu_sync_roots(vcpu);
        spin_unlock(&vcpu->kvm->mmu_lock);
        if (r)
@@ -2529,7 +2525,7 @@ static void mmu_pte_write_new_pte(struct kvm_vcpu *vcpu,
         }
 
        ++vcpu->kvm->stat.mmu_pte_updated;
-       if (sp->role.glevels == PT32_ROOT_LEVEL)
+       if (!sp->role.cr4_pae)
                paging32_update_pte(vcpu, sp, spte, new);
        else
                paging64_update_pte(vcpu, sp, spte, new);
@@ -2564,36 +2560,11 @@ static bool last_updated_pte_accessed(struct kvm_vcpu *vcpu)
 }
 
 static void mmu_guess_page_from_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-                                         const u8 *new, int bytes)
+                                         u64 gpte)
 {
        gfn_t gfn;
-       int r;
-       u64 gpte = 0;
        pfn_t pfn;
 
-       if (bytes != 4 && bytes != 8)
-               return;
-
-       /*
-        * Assume that the pte write on a page table of the same type
-        * as the current vcpu paging mode.  This is nearly always true
-        * (might be false while changing modes).  Note it is verified later
-        * by update_pte().
-        */
-       if (is_pae(vcpu)) {
-               /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
-               if ((bytes == 4) && (gpa % 4 == 0)) {
-                       r = kvm_read_guest(vcpu->kvm, gpa & ~(u64)7, &gpte, 8);
-                       if (r)
-                               return;
-                       memcpy((void *)&gpte + (gpa % 8), new, 4);
-               } else if ((bytes == 8) && (gpa % 8 == 0)) {
-                       memcpy((void *)&gpte, new, 8);
-               }
-       } else {
-               if ((bytes == 4) && (gpa % 4 == 0))
-                       memcpy((void *)&gpte, new, 4);
-       }
        if (!is_present_gpte(gpte))
                return;
        gfn = (gpte & PT64_BASE_ADDR_MASK) >> PAGE_SHIFT;
@@ -2642,10 +2613,46 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        int flooded = 0;
        int npte;
        int r;
+       int invlpg_counter;
 
        pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
-       mmu_guess_page_from_pte_write(vcpu, gpa, new, bytes);
+
+       invlpg_counter = atomic_read(&vcpu->kvm->arch.invlpg_counter);
+
+       /*
+        * Assume that the pte write on a page table of the same type
+        * as the current vcpu paging mode.  This is nearly always true
+        * (might be false while changing modes).  Note it is verified later
+        * by update_pte().
+        */
+       if ((is_pae(vcpu) && bytes == 4) || !new) {
+               /* Handle a 32-bit guest writing two halves of a 64-bit gpte */
+               if (is_pae(vcpu)) {
+                       gpa &= ~(gpa_t)7;
+                       bytes = 8;
+               }
+               r = kvm_read_guest(vcpu->kvm, gpa, &gentry, min(bytes, 8));
+               if (r)
+                       gentry = 0;
+               new = (const u8 *)&gentry;
+       }
+
+       switch (bytes) {
+       case 4:
+               gentry = *(const u32 *)new;
+               break;
+       case 8:
+               gentry = *(const u64 *)new;
+               break;
+       default:
+               gentry = 0;
+               break;
+       }
+
+       mmu_guess_page_from_pte_write(vcpu, gpa, gentry);
        spin_lock(&vcpu->kvm->mmu_lock);
+       if (atomic_read(&vcpu->kvm->arch.invlpg_counter) != invlpg_counter)
+               gentry = 0;
        kvm_mmu_access_page(vcpu, gfn);
        kvm_mmu_free_some_pages(vcpu);
        ++vcpu->kvm->stat.mmu_pte_write;
@@ -2664,10 +2671,12 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        }
        index = kvm_page_table_hashfn(gfn);
        bucket = &vcpu->kvm->arch.mmu_page_hash[index];
+
+restart:
        hlist_for_each_entry_safe(sp, node, n, bucket, hash_link) {
                if (sp->gfn != gfn || sp->role.direct || sp->role.invalid)
                        continue;
-               pte_size = sp->role.glevels == PT32_ROOT_LEVEL ? 4 : 8;
+               pte_size = sp->role.cr4_pae ? 8 : 4;
                misaligned = (offset ^ (offset + bytes - 1)) & ~(pte_size - 1);
                misaligned |= bytes < 4;
                if (misaligned || flooded) {
@@ -2684,14 +2693,14 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                        pgprintk("misaligned: gpa %llx bytes %d role %x\n",
                                 gpa, bytes, sp->role.word);
                        if (kvm_mmu_zap_page(vcpu->kvm, sp))
-                               n = bucket->first;
+                               goto restart;
                        ++vcpu->kvm->stat.mmu_flooded;
                        continue;
                }
                page_offset = offset;
                level = sp->role.level;
                npte = 1;
-               if (sp->role.glevels == PT32_ROOT_LEVEL) {
+               if (!sp->role.cr4_pae) {
                        page_offset <<= 1;      /* 32->64 */
                        /*
                         * A 32-bit pde maps 4MB while the shadow pdes map
@@ -2709,20 +2718,11 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                                continue;
                }
                spte = &sp->spt[page_offset / sizeof(*spte)];
-               if ((gpa & (pte_size - 1)) || (bytes < pte_size)) {
-                       gentry = 0;
-                       r = kvm_read_guest_atomic(vcpu->kvm,
-                                                 gpa & ~(u64)(pte_size - 1),
-                                                 &gentry, pte_size);
-                       new = (const void *)&gentry;
-                       if (r < 0)
-                               new = NULL;
-               }
                while (npte--) {
                        entry = *spte;
                        mmu_pte_write_zap_pte(vcpu, sp, spte);
-                       if (new)
-                               mmu_pte_write_new_pte(vcpu, sp, spte, new);
+                       if (gentry)
+                               mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
                        mmu_pte_write_flush_tlb(vcpu, entry, *spte);
                        ++spte;
                }
@@ -2743,7 +2743,7 @@ int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva)
        if (tdp_enabled)
                return 0;
 
-       gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, gva);
+       gpa = kvm_mmu_gva_to_gpa_read(vcpu, gva, NULL);
 
        spin_lock(&vcpu->kvm->mmu_lock);
        r = kvm_mmu_unprotect_page(vcpu->kvm, gpa >> PAGE_SHIFT);
@@ -2902,22 +2902,23 @@ void kvm_mmu_zap_all(struct kvm *kvm)
        struct kvm_mmu_page *sp, *node;
 
        spin_lock(&kvm->mmu_lock);
+restart:
        list_for_each_entry_safe(sp, node, &kvm->arch.active_mmu_pages, link)
                if (kvm_mmu_zap_page(kvm, sp))
-                       node = container_of(kvm->arch.active_mmu_pages.next,
-                                           struct kvm_mmu_page, link);
+                       goto restart;
+
        spin_unlock(&kvm->mmu_lock);
 
        kvm_flush_remote_tlbs(kvm);
 }
 
-static void kvm_mmu_remove_one_alloc_mmu_page(struct kvm *kvm)
+static int kvm_mmu_remove_some_alloc_mmu_pages(struct kvm *kvm)
 {
        struct kvm_mmu_page *page;
 
        page = container_of(kvm->arch.active_mmu_pages.prev,
                            struct kvm_mmu_page, link);
-       kvm_mmu_zap_page(kvm, page);
+       return kvm_mmu_zap_page(kvm, page) + 1;
 }
 
 static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
@@ -2929,7 +2930,7 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
        spin_lock(&kvm_lock);
 
        list_for_each_entry(kvm, &vm_list, vm_list) {
-               int npages, idx;
+               int npages, idx, freed_pages;
 
                idx = srcu_read_lock(&kvm->srcu);
                spin_lock(&kvm->mmu_lock);
@@ -2937,8 +2938,8 @@ static int mmu_shrink(int nr_to_scan, gfp_t gfp_mask)
                         kvm->arch.n_free_mmu_pages;
                cache_count += npages;
                if (!kvm_freed && nr_to_scan > 0 && npages > 0) {
-                       kvm_mmu_remove_one_alloc_mmu_page(kvm);
-                       cache_count--;
+                       freed_pages = kvm_mmu_remove_some_alloc_mmu_pages(kvm);
+                       cache_count -= freed_pages;
                        kvm_freed = kvm;
                }
                nr_to_scan--;
@@ -3013,7 +3014,8 @@ unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm)
        unsigned int  nr_pages = 0;
        struct kvm_memslots *slots;
 
-       slots = rcu_dereference(kvm->memslots);
+       slots = kvm_memslots(kvm);
+
        for (i = 0; i < slots->nmemslots; i++)
                nr_pages += slots->memslots[i].npages;
 
@@ -3176,8 +3178,7 @@ static gva_t canonicalize(gva_t gva)
 }
 
 
-typedef void (*inspect_spte_fn) (struct kvm *kvm, struct kvm_mmu_page *sp,
-                                u64 *sptep);
+typedef void (*inspect_spte_fn) (struct kvm *kvm, u64 *sptep);
 
 static void __mmu_spte_walk(struct kvm *kvm, struct kvm_mmu_page *sp,
                            inspect_spte_fn fn)
@@ -3193,7 +3194,7 @@ static void __mmu_spte_walk(struct kvm *kvm, struct kvm_mmu_page *sp,
                                child = page_header(ent & PT64_BASE_ADDR_MASK);
                                __mmu_spte_walk(kvm, child, fn);
                        } else
-                               fn(kvm, sp, &sp->spt[i]);
+                               fn(kvm, &sp->spt[i]);
                }
        }
 }
@@ -3240,7 +3241,7 @@ static void audit_mappings_page(struct kvm_vcpu *vcpu, u64 page_pte,
                if (is_shadow_present_pte(ent) && !is_last_spte(ent, level))
                        audit_mappings_page(vcpu, ent, va, level - 1);
                else {
-                       gpa_t gpa = vcpu->arch.mmu.gva_to_gpa(vcpu, va);
+                       gpa_t gpa = kvm_mmu_gva_to_gpa_read(vcpu, va, NULL);
                        gfn_t gfn = gpa >> PAGE_SHIFT;
                        pfn_t pfn = gfn_to_pfn(vcpu->kvm, gfn);
                        hpa_t hpa = (hpa_t)pfn << PAGE_SHIFT;
@@ -3284,11 +3285,13 @@ static void audit_mappings(struct kvm_vcpu *vcpu)
 
 static int count_rmaps(struct kvm_vcpu *vcpu)
 {
+       struct kvm *kvm = vcpu->kvm;
+       struct kvm_memslots *slots;
        int nmaps = 0;
        int i, j, k, idx;
 
        idx = srcu_read_lock(&kvm->srcu);
-       slots = rcu_dereference(kvm->memslots);
+       slots = kvm_memslots(kvm);
        for (i = 0; i < KVM_MEMORY_SLOTS; ++i) {
                struct kvm_memory_slot *m = &slots->memslots[i];
                struct kvm_rmap_desc *d;
@@ -3317,7 +3320,7 @@ static int count_rmaps(struct kvm_vcpu *vcpu)
        return nmaps;
 }
 
-void inspect_spte_has_rmap(struct kvm *kvm, struct kvm_mmu_page *sp, u64 *sptep)
+void inspect_spte_has_rmap(struct kvm *kvm, u64 *sptep)
 {
        unsigned long *rmapp;
        struct kvm_mmu_page *rev_sp;
@@ -3333,14 +3336,14 @@ void inspect_spte_has_rmap(struct kvm *kvm, struct kvm_mmu_page *sp, u64 *sptep)
                        printk(KERN_ERR "%s: no memslot for gfn %ld\n",
                                         audit_msg, gfn);
                        printk(KERN_ERR "%s: index %ld of sp (gfn=%lx)\n",
-                                       audit_msg, sptep - rev_sp->spt,
+                              audit_msg, (long int)(sptep - rev_sp->spt),
                                        rev_sp->gfn);
                        dump_stack();
                        return;
                }
 
                rmapp = gfn_to_rmap(kvm, rev_sp->gfns[sptep - rev_sp->spt],
-                                   is_large_pte(*sptep));
+                                   rev_sp->role.level);
                if (!*rmapp) {
                        if (!printk_ratelimit())
                                return;
@@ -3375,7 +3378,7 @@ static void check_writable_mappings_rmap(struct kvm_vcpu *vcpu)
                                continue;
                        if (!(ent & PT_WRITABLE_MASK))
                                continue;
-                       inspect_spte_has_rmap(vcpu->kvm, sp, &pt[i]);
+                       inspect_spte_has_rmap(vcpu->kvm, &pt[i]);
                }
        }
        return;