diff --git a/src/hyperlight_host/src/hypervisor/virtual_machine/mshv/x86_64.rs b/src/hyperlight_host/src/hypervisor/virtual_machine/mshv/x86_64.rs index 27f024ca6..badd4f90d 100644 --- a/src/hyperlight_host/src/hypervisor/virtual_machine/mshv/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/virtual_machine/mshv/x86_64.rs @@ -26,8 +26,8 @@ use mshv_bindings::LapicState; #[cfg(gdb)] use mshv_bindings::{DebugRegisters, hv_message_type_HVMSG_X64_EXCEPTION_INTERCEPT}; use mshv_bindings::{ - FloatingPointUnit, SpecialRegisters, StandardRegisters, XSave, hv_message_type, - hv_message_type_HVMSG_GPA_INTERCEPT, hv_message_type_HVMSG_UNMAPPED_GPA, + FloatingPointUnit, HV_X64_REGISTER_CLASS_IP, SpecialRegisters, StandardRegisters, XSave, + hv_message_type, hv_message_type_HVMSG_GPA_INTERCEPT, hv_message_type_HVMSG_UNMAPPED_GPA, hv_message_type_HVMSG_X64_HALT, hv_message_type_HVMSG_X64_IO_PORT_INTERCEPT, hv_partition_property_code_HV_PARTITION_PROPERTY_SYNTHETIC_PROC_FEATURES, hv_partition_synthetic_processor_features, hv_register_assoc, @@ -36,7 +36,8 @@ use mshv_bindings::{ }; #[cfg(feature = "hw-interrupts")] use mshv_bindings::{ - hv_interrupt_type_HV_X64_INTERRUPT_TYPE_FIXED, hv_register_name_HV_X64_REGISTER_RAX, + HV_X64_REGISTER_CLASS_GENERAL, hv_interrupt_type_HV_X64_INTERRUPT_TYPE_FIXED, + hv_register_name_HV_X64_REGISTER_RAX, set_gp_regs_field_ptr, }; #[cfg(feature = "hw-interrupts")] use mshv_ioctls::InterruptRequest; @@ -219,16 +220,30 @@ impl VirtualMachine for MshvVm { let instruction_length = io_message.header.instruction_length() as u64; let is_write = io_message.header.intercept_access_type != 0; - // mshv, unlike kvm, does not automatically increment RIP - self.vcpu_fd - .set_reg(&[hv_register_assoc { - name: hv_register_name_HV_X64_REGISTER_RIP, - value: hv_register_value { - reg64: rip + instruction_length, - }, - ..Default::default() - }]) - .map_err(|e| RunVcpuError::IncrementRip(e.into()))?; + // mshv, unlike kvm, does not automatically increment RIP. + if let Some(page) = self + .vcpu_fd + .get_vp_reg_page() + .filter(|p| unsafe { (*p.0).isvalid != 0 }) + { + // SAFETY: The register page is valid (isvalid checked + // above) and populated after a vcpu run intercept. + unsafe { + (*page.0).__bindgen_anon_1.__bindgen_anon_1.rip = + rip + instruction_length; + (*page.0).dirty |= 1 << HV_X64_REGISTER_CLASS_IP; + } + } else { + self.vcpu_fd + .set_reg(&[hv_register_assoc { + name: hv_register_name_HV_X64_REGISTER_RIP, + value: hv_register_value { + reg64: rip + instruction_length, + }, + ..Default::default() + }]) + .map_err(|e| RunVcpuError::IncrementRip(e.into()))?; + } // VmAction::Halt always means "I'm done", regardless // of whether a timer is active. @@ -253,13 +268,27 @@ impl VirtualMachine for MshvVm { } else if let Some(val) = super::super::x86_64::hw_interrupts::handle_io_in(port_number) { - self.vcpu_fd - .set_reg(&[hv_register_assoc { - name: hv_register_name_HV_X64_REGISTER_RAX, - value: hv_register_value { reg64: val }, - ..Default::default() - }]) - .map_err(|e| RunVcpuError::Unknown(e.into()))?; + if let Some(page) = self + .vcpu_fd + .get_vp_reg_page() + .filter(|p| unsafe { (*p.0).isvalid != 0 }) + { + let vp_reg_page = page.0; + set_gp_regs_field_ptr!(vp_reg_page, rax, val); + // SAFETY: page is valid (isvalid checked above). + unsafe { + (*vp_reg_page).dirty |= + 1 << HV_X64_REGISTER_CLASS_GENERAL; + } + } else { + self.vcpu_fd + .set_reg(&[hv_register_assoc { + name: hv_register_name_HV_X64_REGISTER_RAX, + value: hv_register_value { reg64: val }, + ..Default::default() + }]) + .map_err(|e| RunVcpuError::Unknown(e.into()))?; + } continue; } }