diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 992bc1058170..51ce635494b0 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -256,7 +256,7 @@ struct kvm_mmu_memory_cache {
  * @nxe, @cr0_wp, @smep_andnot_wp and @smap_andnot_wp.
  */
 union kvm_mmu_page_role {
-	unsigned word;
+	u32 word;
 	struct {
 		unsigned level:4;
 		unsigned cr4_pae:1;
@@ -282,6 +282,18 @@ union kvm_mmu_page_role {
 	};
 };
 
+union kvm_mmu_extended_role {
+	u32 word;
+};
+
+union kvm_mmu_role {
+	u64 as_u64;
+	struct {
+		union kvm_mmu_page_role base;
+		union kvm_mmu_extended_role ext;
+	};
+};
+
 struct kvm_rmap_head {
 	unsigned long val;
 };
@@ -369,7 +381,7 @@ struct kvm_mmu {
 	void (*update_pte)(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
 			   u64 *spte, const void *pte);
 	hpa_t root_hpa;
-	union kvm_mmu_page_role base_role;
+	union kvm_mmu_role mmu_role;
 	u8 root_level;
 	u8 shadow_root_level;
 	u8 ept_ad;
diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index dd147fa5b3f4..3301973527aa 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -2375,7 +2375,7 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
 	int collisions = 0;
 	LIST_HEAD(invalid_list);
 
-	role = vcpu->arch.mmu->base_role;
+	role = vcpu->arch.mmu->mmu_role.base;
 	role.level = level;
 	role.direct = direct;
 	if (role.direct)
@@ -4423,7 +4423,8 @@ static void reset_rsvds_bits_mask_ept(struct kvm_vcpu *vcpu,
 void
 reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
 {
-	bool uses_nx = context->nx || context->base_role.smep_andnot_wp;
+	bool uses_nx = context->nx ||
+		context->mmu_role.base.smep_andnot_wp;
 	struct rsvd_bits_validate *shadow_zero_check;
 	int i;
 
@@ -4742,7 +4743,7 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 {
 	struct kvm_mmu *context = vcpu->arch.mmu;
 
-	context->base_role.word = mmu_base_role_mask.word &
+	context->mmu_role.base.word = mmu_base_role_mask.word &
 				  kvm_calc_tdp_mmu_root_page_role(vcpu).word;
 	context->page_fault = tdp_page_fault;
 	context->sync_page = nonpaging_sync_page;
@@ -4823,7 +4824,7 @@ void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu)
 	else
 		paging32_init_context(vcpu, context);
 
-	context->base_role.word = mmu_base_role_mask.word &
+	context->mmu_role.base.word = mmu_base_role_mask.word &
 				  kvm_calc_shadow_mmu_root_page_role(vcpu).word;
 	reset_shadow_zero_bits_mask(vcpu, context);
 }
@@ -4865,7 +4866,8 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 	context->update_pte = ept_update_pte;
 	context->root_level = PT64_ROOT_4LEVEL;
 	context->direct_map = false;
-	context->base_role.word = root_page_role.word & mmu_base_role_mask.word;
+	context->mmu_role.base.word =
+		root_page_role.word & mmu_base_role_mask.word;
 
 	update_permission_bitmask(vcpu, context, true);
 	update_pkru_bitmask(vcpu, context, true);
@@ -5179,10 +5181,12 @@ static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
 
 		local_flush = true;
 		while (npte--) {
+			u32 base_role = vcpu->arch.mmu->mmu_role.base.word;
+
 			entry = *spte;
 			mmu_page_zap_pte(vcpu->kvm, sp, spte);
 			if (gentry &&
-			      !((sp->role.word ^ vcpu->arch.mmu->base_role.word)
+			      !((sp->role.word ^ base_role)
 			      & mmu_base_role_mask.word) && rmap_can_add(vcpu))
 				mmu_pte_write_new_pte(vcpu, sp, spte, &gentry);
 			if (need_remote_flush(entry, *spte))
@@ -5879,6 +5883,16 @@ int kvm_mmu_module_init(void)
 {
 	int ret = -ENOMEM;
 
+	/*
+	 * MMU roles use union aliasing which is, generally speaking, an
+	 * undefined behavior. However, we supposedly know how compilers behave
+	 * and the current status quo is unlikely to change. Guardians below are
+	 * supposed to let us know if the assumption becomes false.
+	 */
+	BUILD_BUG_ON(sizeof(union kvm_mmu_page_role) != sizeof(u32));
+	BUILD_BUG_ON(sizeof(union kvm_mmu_extended_role) != sizeof(u32));
+	BUILD_BUG_ON(sizeof(union kvm_mmu_role) != sizeof(u64));
+
 	kvm_mmu_reset_all_pte_masks();
 
 	pte_list_desc_cache = kmem_cache_create("pte_list_desc",
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index 02888031d038..6f44d3a63434 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -9263,7 +9263,7 @@ static int nested_vmx_eptp_switching(struct kvm_vcpu *vcpu,
 
 		kvm_mmu_unload(vcpu);
 		mmu->ept_ad = accessed_dirty;
-		mmu->base_role.ad_disabled = !accessed_dirty;
+		mmu->mmu_role.base.ad_disabled = !accessed_dirty;
 		vmcs12->ept_pointer = address;
 		/*
 		 * TODO: Check what's the correct approach in case