From b6cbe231077cada2db2dc2d1f74c2a89e5105bf8 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sat, 2 May 2026 10:08:38 +0200 Subject: [PATCH 1/3] Add dynamic joint replay state --- python/rcs/_core/sim.pyi | 15 ++ python/rcs/envs/sim.py | 22 ++- python/rcs/sim/sim.py | 34 ++++ python/rcs/sim_state_replay.py | 38 ++++- python/tests/test_sim_state_record_replay.py | 156 +++++++++++++++-- src/pybind/rcs.cpp | 16 ++ src/sim/sim.cpp | 166 ++++++++++++++++++- src/sim/sim.h | 33 ++++ 8 files changed, 462 insertions(+), 18 deletions(-) diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index a75d598a..646a9e35 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -88,14 +88,29 @@ class GuiClient: def set_model_and_data(self, arg0: int, arg1: int) -> None: ... def sync(self) -> None: ... +class DynamicJointSchema: + joint_names: list[str] + joint_types: list[int] + qpos_sizes: list[int] + qvel_sizes: list[int] + def __init__(self) -> None: ... + +class DynamicJointState: + qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + def __init__(self) -> None: ... + class Sim: def __init__(self, mjmdl: int, mjdata: int) -> None: ... def _start_gui_server(self, id: str) -> None: ... def _stop_gui_server(self) -> None: ... def get_config(self) -> SimConfig: ... + def get_dynamic_joint_schema(self) -> DynamicJointSchema: ... + def get_dynamic_joint_state(self) -> DynamicJointState: ... def is_converged(self) -> bool: ... def reset(self) -> None: ... def set_config(self, cfg: SimConfig) -> bool: ... + def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 53fb60c3..c2e5424f 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -47,20 +47,34 @@ class SimStateObservationWrapper(ActObsInfoWrapper): STATE_KEY = "sim_state" STATE_SPEC_KEY = "sim_state_spec" STATE_SIZE_KEY = "sim_state_size" + DYNAMIC_JOINT_SCHEMA_KEY = "dynamic_joint_schema" + DYNAMIC_JOINT_QPOS_KEY = "dynamic_joint_qpos" + DYNAMIC_JOINT_QVEL_KEY = "dynamic_joint_qvel" def __init__(self, env): super().__init__(env) assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) + self._dynamic_joint_schema = self.sim.get_dynamic_joint_schema() + self._include_schema_in_next_step = True def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: observation = dict(observation) - sim_state = self.sim.get_state() - observation[self.STATE_KEY] = sim_state - observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec() - observation[self.STATE_SIZE_KEY] = sim_state.shape[0] + dynamic_joint_state = self.sim.get_dynamic_joint_state() + observation[self.DYNAMIC_JOINT_QPOS_KEY] = dynamic_joint_state["qpos"] + observation[self.DYNAMIC_JOINT_QVEL_KEY] = dynamic_joint_state["qvel"] + if self._include_schema_in_next_step: + observation[self.DYNAMIC_JOINT_SCHEMA_KEY] = self._dynamic_joint_schema + self._include_schema_in_next_step = False return observation, info + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + self._include_schema_in_next_step = True + return obs, info + class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env): diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 0d46773c..ffb17075 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -11,6 +11,8 @@ import mujoco as mj import mujoco.viewer import numpy as np +from rcs._core.sim import DynamicJointSchema as _DynamicJointSchema +from rcs._core.sim import DynamicJointState as _DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -90,6 +92,38 @@ def set_state(self, state: np.ndarray, spec: int | None = None): mj.mj_setState(self.model, self.data, state_array, state_spec) mj.mj_forward(self.model, self.data) + def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]: + schema = super().get_dynamic_joint_schema() + return { + "joint_names": list(schema.joint_names), + "joint_types": list(schema.joint_types), + "qpos_sizes": list(schema.qpos_sizes), + "qvel_sizes": list(schema.qvel_sizes), + } + + def get_dynamic_joint_state(self) -> dict[str, np.ndarray]: + state = super().get_dynamic_joint_state() + return { + "qpos": np.asarray(state.qpos, dtype=np.float64), + "qvel": np.asarray(state.qvel, dtype=np.float64), + } + + def set_dynamic_joint_state( + self, + schema: dict[str, list[str] | list[int]], + state: dict[str, np.ndarray], + ): + dynamic_joint_schema = _DynamicJointSchema() + dynamic_joint_schema.joint_names = list(schema["joint_names"]) + dynamic_joint_schema.joint_types = [int(value) for value in schema["joint_types"]] + dynamic_joint_schema.qpos_sizes = [int(value) for value in schema["qpos_sizes"]] + dynamic_joint_schema.qvel_sizes = [int(value) for value in schema["qvel_sizes"]] + + dynamic_joint_state = _DynamicJointState() + dynamic_joint_state.qpos = np.asarray(state["qpos"], dtype=np.float64) + dynamic_joint_state.qvel = np.asarray(state["qvel"], dtype=np.float64) + super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + def close_gui(self): if self._stop_event is not None: self._stop_event.set() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index f07eef91..9b98e09f 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -38,6 +38,23 @@ class RecordedSimStep: timestamp: float | None observation: dict[str, Any] + @property + def dynamic_joint_schema(self) -> dict[str, Any] | None: + schema = self.observation.get(SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY) + return dict(schema) if schema is not None else None + + @property + def dynamic_joint_state(self) -> dict[str, np.ndarray] | None: + if ( + SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY not in self.observation + or SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY not in self.observation + ): + return None + return { + "qpos": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY], dtype=np.float64), + "qvel": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY], dtype=np.float64), + } + @property def sim_state(self) -> np.ndarray: return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) @@ -149,8 +166,20 @@ def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, pre raise ValueError(msg) -def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): +def restore_sim_step( + env: gym.Env, + recorded_step: RecordedSimStep, + dynamic_joint_schema: dict[str, Any] | None = None, +): sim = env.get_wrapper_attr("sim") + dynamic_joint_state = recorded_step.dynamic_joint_state + if dynamic_joint_state is not None: + resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema + if resolved_schema is None: + msg = "Recorded dynamic joint state is missing its schema." + raise ValueError(msg) + sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) + return sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) @@ -190,9 +219,14 @@ def replay_trajectory( msg = "No recorded sim states found in the requested trajectory." raise ValueError(msg) + dynamic_joint_schema = next( + (recorded_step.dynamic_joint_schema for recorded_step in recorded_steps if recorded_step.dynamic_joint_schema), + None, + ) + env.reset() for recorded_step in recorded_steps: - restore_sim_step(env, recorded_step) + restore_sim_step(env, recorded_step, dynamic_joint_schema=dynamic_joint_schema) if output_dir is not None: save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) if sleep_s > 0: diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 1e974812..2befdc91 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -2,6 +2,7 @@ import importlib.util import sys +import xml.etree.ElementTree as ET from dataclasses import dataclass from pathlib import Path @@ -10,8 +11,12 @@ import numpy as np import pyarrow.dataset as ds from rcs._core.common import RobotPlatform +from rcs._core.sim import SimConfig from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet +from rcs.envs.base import ControlMode, JointsDictType +from rcs.envs.creators import SimMultiEnvCreator from rcs.envs.storage_wrapper import StorageWrapper +from rcs.envs.utils import default_sim_gripper_cfg, default_sim_robot_cfg import rcs @@ -52,7 +57,7 @@ def _load_local_module(module_name: str, relative_path: str): - + @@ -134,7 +139,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) obs, _ = record_env.reset() - assert SimStateObservationWrapper.STATE_KEY in obs + assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in obs record_env.step({"delta": np.array([0.125], dtype=np.float64)}) record_env.close() @@ -144,17 +149,17 @@ def test_record_and_replay_sim_state(tmp_path: Path): assert len(rows) == 1 recorded_obs = rows[0]["obs"] - assert SimStateObservationWrapper.STATE_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs - assert ( - len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) - == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] - ) + assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in recorded_obs + assert SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY in recorded_obs + assert SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY in recorded_obs recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) assert len(recorded_steps) == 1 - assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) + assert recorded_steps[0].dynamic_joint_schema is not None + assert np.allclose( + recorded_steps[0].dynamic_joint_state["qpos"], # type: ignore[index] + np.asarray(recorded_obs[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY]), + ) replay_sim = Sim(model_path) replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) @@ -162,7 +167,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): render_dir = tmp_path / "rendered" replay_env.reset() - restore_sim_step(replay_env, recorded_steps[0]) + restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=recorded_steps[0].dynamic_joint_schema) assert np.allclose( replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 ) @@ -174,3 +179,132 @@ def test_record_and_replay_sim_state(tmp_path: Path): rendered_files = sorted(path.name for path in render_dir.glob("*.png")) assert rendered_files == ["step-000000-main.png"] + + +def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): + tree = ET.parse(src) + root = tree.getroot() + for include in root.findall("include"): + include_file = include.get("file") + if include_file is not None and not Path(include_file).is_absolute(): + include.set("file", str((src.parent / include_file).resolve())) + + worldbody = root.find("worldbody") + assert worldbody is not None + + worldbody.append( + ET.Element( + "camera", + { + "name": "replay_extra_cam", + "pos": "1.4 0.0 0.9", + "xyaxes": "0 1 0 -0.3 0 1", + }, + ) + ) + body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) + ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}) + tree.write(dst) + + +def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list, dict[str, object]]: + record_env: gym.Env = DummySimEnv(Sim(model_path)) + record_env = SimStateObservationWrapper(record_env) + record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) + record_env.reset() + record_env.step({"delta": np.array([0.125], dtype=np.float64)}) + record_env.close() + + table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) + rows = table.to_pylist() + recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) + return recorded_steps, rows[0]["obs"] + + +def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): + base_model_path = tmp_path / "base.xml" + base_model_path.write_text(XML) + modified_model_path = tmp_path / "modified.xml" + _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path) + + for record_model_path, replay_model_path in ( + (base_model_path, modified_model_path), + (modified_model_path, base_model_path), + ): + dataset_path = tmp_path / f"dataset-{record_model_path.stem}-to-{replay_model_path.stem}" + recorded_steps, recorded_obs = _record_dummy_trajectory(dataset_path, record_model_path) + + replay_sim = Sim(replay_model_path) + replay_env: gym.Env = DummySimEnv(replay_sim) + replay_env = SimStateObservationWrapper(replay_env) + replay_env.reset() + dynamic_joint_schema = next( + step.dynamic_joint_schema for step in recorded_steps if step.dynamic_joint_schema is not None + ) + restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=dynamic_joint_schema) + + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 + ) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 + ) + + +DUAL_ARM_ROBOT2ID = {"left": "0", "right": "1"} + + +def _create_dual_arm_env(scene_name: str): + robot_cfg = default_sim_robot_cfg(scene_name, idx="") + sim_cfg = SimConfig() + sim_cfg.async_control = False + return SimMultiEnvCreator()( + name2id=DUAL_ARM_ROBOT2ID, + robot_cfg=robot_cfg, + control_mode=ControlMode.JOINTS, + gripper_cfg=default_sim_gripper_cfg(idx=""), + sim_cfg=sim_cfg, + max_relative_movement=None, + ) + + +def test_dynamic_joint_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): + source_scene_path = REPO_ROOT / "assets/scenes/fr3_dual_arm/scene.xml" + source_robot_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.xml" + source_urdf_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.urdf" + modified_scene_path = source_scene_path.parent / "scene_dynamic_joint_test.xml" + _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) + + base_scene_name = "fr3_dual_arm_dynamic_joint_base_test" + test_scene_name = "fr3_dual_arm_dynamic_joint_test" + scene_kwargs = { + "mjcf_robot": str(source_robot_path), + "urdf": str(source_urdf_path), + "robot_type": rcs.scenes["fr3_dual_arm"].robot_type, + "mjb": None, + } + rcs.scenes[base_scene_name] = rcs.Scene(mjcf_scene=str(source_scene_path), **scene_kwargs) + rcs.scenes[test_scene_name] = rcs.Scene(mjcf_scene=str(modified_scene_path), **scene_kwargs) + + base_env = _create_dual_arm_env(base_scene_name) + modified_env = _create_dual_arm_env(test_scene_name) + try: + base_env.reset() + base_sim = base_env.get_wrapper_attr("sim") + dynamic_joint_schema = base_sim.get_dynamic_joint_schema() + dynamic_joint_state = base_sim.get_dynamic_joint_state() + + modified_env.reset() + modified_sim = modified_env.get_wrapper_attr("sim") + modified_sim.set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + restored_dynamic_joint_state = modified_sim.get_dynamic_joint_state() + + assert dynamic_joint_schema == modified_sim.get_dynamic_joint_schema() + assert np.allclose(restored_dynamic_joint_state["qpos"], dynamic_joint_state["qpos"], atol=1e-9, rtol=0) + assert np.allclose(restored_dynamic_joint_state["qvel"], dynamic_joint_state["qvel"], atol=1e-9, rtol=0) + finally: + base_env.close() + modified_env.close() + del rcs.scenes[test_scene_name] + del rcs.scenes[base_scene_name] + modified_scene_path.unlink(missing_ok=True) diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp index 007ee475..54222b32 100644 --- a/src/pybind/rcs.cpp +++ b/src/pybind/rcs.cpp @@ -703,6 +703,18 @@ PYBIND11_MODULE(_core, m) { return rcs::sim::SimConfig(self); }); + py::class_(sim, "DynamicJointSchema") + .def(py::init<>()) + .def_readwrite("joint_names", &rcs::sim::DynamicJointSchema::joint_names) + .def_readwrite("joint_types", &rcs::sim::DynamicJointSchema::joint_types) + .def_readwrite("qpos_sizes", &rcs::sim::DynamicJointSchema::qpos_sizes) + .def_readwrite("qvel_sizes", &rcs::sim::DynamicJointSchema::qvel_sizes); + + py::class_(sim, "DynamicJointState") + .def(py::init<>()) + .def_readwrite("qpos", &rcs::sim::DynamicJointState::qpos) + .def_readwrite("qvel", &rcs::sim::DynamicJointState::qvel); + py::class_>(sim, "Sim") .def(py::init([](long m, long d) { return std::make_shared((mjModel*)m, (mjData*)d); @@ -715,6 +727,10 @@ PYBIND11_MODULE(_core, m) { .def("get_config", &rcs::sim::Sim::get_config) .def("step", &rcs::sim::Sim::step, py::arg("k")) .def("reset", &rcs::sim::Sim::reset) + .def("get_dynamic_joint_schema", &rcs::sim::Sim::get_dynamic_joint_schema) + .def("get_dynamic_joint_state", &rcs::sim::Sim::get_dynamic_joint_state) + .def("set_dynamic_joint_state", &rcs::sim::Sim::set_dynamic_joint_state, + py::arg("schema"), py::arg("state")) .def("_start_gui_server", &rcs::sim::Sim::start_gui_server, py::arg("id")) .def("_stop_gui_server", &rcs::sim::Sim::stop_gui_server); diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 164ad2c8..7f7d1e0a 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include namespace rcs { @@ -25,7 +28,65 @@ bool get_last_return_value(ConditionCallback cb) { return cb.last_return_value; } -Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {}; +int Sim::get_joint_qpos_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 7; + case mjJNT_BALL: + return 4; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qpos size."); + } +} + +int Sim::get_joint_qvel_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 6; + case mjJNT_BALL: + return 3; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qvel size."); + } +} + +void Sim::init_dynamic_joint_specs() { + this->dynamic_joint_specs.clear(); + this->dynamic_joint_name_to_index.clear(); + + for (int joint_id = 0; joint_id < this->m->njnt; ++joint_id) { + const char* joint_name = mj_id2name(this->m, mjOBJ_JOINT, joint_id); + if (joint_name == nullptr || joint_name[0] == '\0') { + std::ostringstream msg; + msg << "Dynamic joint state requires all joints to be named. Joint id " + << joint_id << " is unnamed."; + throw std::runtime_error(msg.str()); + } + + DynamicJointSpec spec{ + .name = joint_name, + .type = this->m->jnt_type[joint_id], + .qpos_adr = this->m->jnt_qposadr[joint_id], + .qvel_adr = this->m->jnt_dofadr[joint_id], + .qpos_size = get_joint_qpos_size(this->m->jnt_type[joint_id]), + .qvel_size = get_joint_qvel_size(this->m->jnt_type[joint_id]), + }; + + this->dynamic_joint_name_to_index[spec.name] = + this->dynamic_joint_specs.size(); + this->dynamic_joint_specs.push_back(spec); + } +} + +Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) { + this->init_dynamic_joint_specs(); +}; bool Sim::set_config(const SimConfig& cfg) { this->cfg = cfg; @@ -118,6 +179,109 @@ void Sim::reset() { this->reset_callbacks(); } +DynamicJointSchema Sim::get_dynamic_joint_schema() const { + DynamicJointSchema schema; + schema.joint_names.reserve(this->dynamic_joint_specs.size()); + schema.joint_types.reserve(this->dynamic_joint_specs.size()); + schema.qpos_sizes.reserve(this->dynamic_joint_specs.size()); + schema.qvel_sizes.reserve(this->dynamic_joint_specs.size()); + + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + schema.joint_names.push_back(spec.name); + schema.joint_types.push_back(spec.type); + schema.qpos_sizes.push_back(spec.qpos_size); + schema.qvel_sizes.push_back(spec.qvel_size); + } + return schema; +} + +DynamicJointState Sim::get_dynamic_joint_state() const { + DynamicJointState state; + int total_qpos = 0; + int total_qvel = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + total_qpos += spec.qpos_size; + total_qvel += spec.qvel_size; + } + + state.qpos = rcs::common::VectorXd(total_qpos); + state.qvel = rcs::common::VectorXd(total_qvel); + + int qpos_offset = 0; + int qvel_offset = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + for (int i = 0; i < spec.qpos_size; ++i) { + state.qpos[qpos_offset + i] = this->d->qpos[spec.qpos_adr + i]; + } + for (int i = 0; i < spec.qvel_size; ++i) { + state.qvel[qvel_offset + i] = this->d->qvel[spec.qvel_adr + i]; + } + qpos_offset += spec.qpos_size; + qvel_offset += spec.qvel_size; + } + + return state; +} + +void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state) { + size_t joint_count = schema.joint_names.size(); + if (schema.joint_types.size() != joint_count || + schema.qpos_sizes.size() != joint_count || + schema.qvel_sizes.size() != joint_count) { + throw std::invalid_argument( + "Dynamic joint schema fields must all have the same length."); + } + + int expected_qpos_size = std::accumulate(schema.qpos_sizes.begin(), + schema.qpos_sizes.end(), 0); + int expected_qvel_size = std::accumulate(schema.qvel_sizes.begin(), + schema.qvel_sizes.end(), 0); + if (state.qpos.size() != expected_qpos_size) { + std::ostringstream msg; + msg << "Dynamic joint qpos size mismatch. Expected " + << expected_qpos_size << ", got " << state.qpos.size() << "."; + throw std::invalid_argument(msg.str()); + } + if (state.qvel.size() != expected_qvel_size) { + std::ostringstream msg; + msg << "Dynamic joint qvel size mismatch. Expected " + << expected_qvel_size << ", got " << state.qvel.size() << "."; + throw std::invalid_argument(msg.str()); + } + + int qpos_offset = 0; + int qvel_offset = 0; + for (size_t i = 0; i < joint_count; ++i) { + auto spec_iter = + this->dynamic_joint_name_to_index.find(schema.joint_names[i]); + if (spec_iter != this->dynamic_joint_name_to_index.end()) { + const DynamicJointSpec& target_spec = + this->dynamic_joint_specs[spec_iter->second]; + if (target_spec.type != schema.joint_types[i] || + target_spec.qpos_size != schema.qpos_sizes[i] || + target_spec.qvel_size != schema.qvel_sizes[i]) { + std::ostringstream msg; + msg << "Dynamic joint schema mismatch for joint '" + << schema.joint_names[i] << "'."; + throw std::invalid_argument(msg.str()); + } + + for (int j = 0; j < target_spec.qpos_size; ++j) { + this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; + } + for (int j = 0; j < target_spec.qvel_size; ++j) { + this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; + } + } + + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + } + + mj_forward(this->m, this->d); +} + void Sim::reset_callbacks() { for (size_t i = 0; i < std::size(this->callbacks); ++i) { this->callbacks[i].last_call_timestamp = 0; diff --git a/src/sim/sim.h b/src/sim/sim.h index 62a2fbbf..b96a9aa7 100644 --- a/src/sim/sim.h +++ b/src/sim/sim.h @@ -5,10 +5,13 @@ #include #include #include +#include +#include #include "boost/interprocess/managed_shared_memory.hpp" #include "gui.h" #include "mujoco/mujoco.h" +#include "rcs/utils.h" namespace rcs { namespace sim { @@ -55,16 +58,42 @@ struct RenderingCallback { mjtNum last_call_timestamp; // in seconds }; +struct DynamicJointSchema { + std::vector joint_names; + std::vector joint_types; + std::vector qpos_sizes; + std::vector qvel_sizes; +}; + +struct DynamicJointState { + rcs::common::VectorXd qpos; + rcs::common::VectorXd qvel; +}; + class Sim { private: + struct DynamicJointSpec { + std::string name; + int type; + int qpos_adr; + int qvel_adr; + int qpos_size; + int qvel_size; + }; + SimConfig cfg; std::vector callbacks; std::vector any_callbacks; std::vector all_callbacks; std::vector rendering_callbacks; + std::vector dynamic_joint_specs; + std::unordered_map dynamic_joint_name_to_index; void invoke_callbacks(); bool invoke_condition_callbacks(); void invoke_rendering_callbacks(); + void init_dynamic_joint_specs(); + static int get_joint_qpos_size(int joint_type); + static int get_joint_qvel_size(int joint_type); size_t convergence_steps = 0; bool converged = true; std::optional gui; @@ -82,6 +111,10 @@ class Sim { void step(size_t k); void reset_callbacks(); void reset(); + DynamicJointSchema get_dynamic_joint_schema() const; + DynamicJointState get_dynamic_joint_state() const; + void set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state); /* NOTE: IMPORTANT, the callback is not necessarily called at exactly the * the requested interval. We invoke a callback if the elapsed simulation time * since the last call of the callback is greater than the requested time. From 5467832fbf7c523ab80448354ab61752e22bb7c6 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sat, 2 May 2026 15:39:27 +0200 Subject: [PATCH 2/3] Remove legacy replay fallback --- python/rcs/envs/sim.py | 4 +-- python/rcs/sim_state_replay.py | 27 ++++++--------- src/sim/sim.cpp | 60 ++++++++++++++++++++++++---------- 3 files changed, 53 insertions(+), 38 deletions(-) diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index c2e5424f..b5348455 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -44,9 +44,6 @@ def reset( class SimStateObservationWrapper(ActObsInfoWrapper): - STATE_KEY = "sim_state" - STATE_SPEC_KEY = "sim_state_spec" - STATE_SIZE_KEY = "sim_state_size" DYNAMIC_JOINT_SCHEMA_KEY = "dynamic_joint_schema" DYNAMIC_JOINT_QPOS_KEY = "dynamic_joint_qpos" DYNAMIC_JOINT_QVEL_KEY = "dynamic_joint_qvel" @@ -72,6 +69,7 @@ def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: obs, info = super().reset(seed=seed, options=options) + # Re-emit the schema on the first recorded step after each reset. self._include_schema_in_next_step = True return obs, info diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index 9b98e09f..4a0c7c22 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -55,14 +55,6 @@ def dynamic_joint_state(self) -> dict[str, np.ndarray] | None: "qvel": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY], dtype=np.float64), } - @property - def sim_state(self) -> np.ndarray: - return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) - - @property - def sim_state_spec(self) -> int: - return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0)) - class DuckDBUnavailableError(RuntimeError): pass @@ -173,14 +165,15 @@ def restore_sim_step( ): sim = env.get_wrapper_attr("sim") dynamic_joint_state = recorded_step.dynamic_joint_state - if dynamic_joint_state is not None: - resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema - if resolved_schema is None: - msg = "Recorded dynamic joint state is missing its schema." - raise ValueError(msg) - sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) - return - sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + if dynamic_joint_state is None: + msg = "Recorded step is missing dynamic joint state data." + raise ValueError(msg) + + resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema + if resolved_schema is None: + msg = "Recorded dynamic joint state is missing its schema." + raise ValueError(msg) + sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: @@ -216,7 +209,7 @@ def replay_trajectory( output_dir: Path | None = None, ): if not recorded_steps: - msg = "No recorded sim states found in the requested trajectory." + msg = "No recorded dynamic joint states found in the requested trajectory." raise ValueError(msg) dynamic_joint_schema = next( diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 7f7d1e0a..3566ac15 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -250,35 +250,59 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, throw std::invalid_argument(msg.str()); } + std::vector matched_target_joints(this->dynamic_joint_specs.size(), + false); int qpos_offset = 0; int qvel_offset = 0; for (size_t i = 0; i < joint_count; ++i) { auto spec_iter = this->dynamic_joint_name_to_index.find(schema.joint_names[i]); - if (spec_iter != this->dynamic_joint_name_to_index.end()) { - const DynamicJointSpec& target_spec = - this->dynamic_joint_specs[spec_iter->second]; - if (target_spec.type != schema.joint_types[i] || - target_spec.qpos_size != schema.qpos_sizes[i] || - target_spec.qvel_size != schema.qvel_sizes[i]) { - std::ostringstream msg; - msg << "Dynamic joint schema mismatch for joint '" - << schema.joint_names[i] << "'."; - throw std::invalid_argument(msg.str()); - } - - for (int j = 0; j < target_spec.qpos_size; ++j) { - this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; - } - for (int j = 0; j < target_spec.qvel_size; ++j) { - this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; - } + if (spec_iter == this->dynamic_joint_name_to_index.end()) { + std::cerr << "WARNING: Recorded dynamic joint '" << schema.joint_names[i] + << "' is missing in the replay model. Skipping it." + << std::endl; + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + continue; + } + + const DynamicJointSpec& target_spec = + this->dynamic_joint_specs[spec_iter->second]; + matched_target_joints[spec_iter->second] = true; + if (target_spec.type != schema.joint_types[i] || + target_spec.qpos_size != schema.qpos_sizes[i] || + target_spec.qvel_size != schema.qvel_sizes[i]) { + std::ostringstream msg; + msg << "Dynamic joint schema mismatch for joint '" + << schema.joint_names[i] << "': expected type=" << target_spec.type + << ", qpos_size=" << target_spec.qpos_size + << ", qvel_size=" << target_spec.qvel_size << " but got type=" + << schema.joint_types[i] << ", qpos_size=" << schema.qpos_sizes[i] + << ", qvel_size=" << schema.qvel_sizes[i] << "."; + throw std::invalid_argument(msg.str()); + } + + for (int j = 0; j < target_spec.qpos_size; ++j) { + this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; + } + for (int j = 0; j < target_spec.qvel_size; ++j) { + this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; } qpos_offset += schema.qpos_sizes[i]; qvel_offset += schema.qvel_sizes[i]; } + for (size_t i = 0; i < this->dynamic_joint_specs.size(); ++i) { + if (!matched_target_joints[i]) { + std::cerr << "WARNING: Replay model dynamic joint '" + << this->dynamic_joint_specs[i].name + << "' is missing in the recorded schema. Leaving it at its " + "current value." + << std::endl; + } + } + mj_forward(this->m, this->d); } From 8c5f4a7a904136ee179cc8e3fc81120a415ee2b5 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sun, 3 May 2026 02:21:48 +0200 Subject: [PATCH 3/3] Refactor replay onto existing sim_state flow --- python/rcs/_core/sim.pyi | 53 ++++++++++++-------- python/rcs/envs/sim.py | 25 +++++---- python/rcs/sim/sim.py | 47 +++++++++-------- python/rcs/sim_state_replay.py | 43 ++++++---------- python/tests/test_sim_state_record_replay.py | 50 +++++++++--------- src/sim/sim.cpp | 21 ++++---- 6 files changed, 121 insertions(+), 118 deletions(-) diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index 646a9e35..5c1e5722 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -11,6 +11,8 @@ import rcs._core.common __all__: list[str] = [ "CameraType", + "DynamicJointSchema", + "DynamicJointState", "FrameSet", "GuiClient", "Sim", @@ -32,6 +34,7 @@ __all__: list[str] = [ "tracking", ] M = typing.TypeVar("M", bound=int) +N = typing.TypeVar("N", bound=int) class CameraType: """ @@ -68,6 +71,18 @@ class CameraType: @property def value(self) -> int: ... +class DynamicJointSchema: + joint_names: list[str] + joint_types: list[int] + qpos_sizes: list[int] + qvel_sizes: list[int] + def __init__(self) -> None: ... + +class DynamicJointState: + qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + def __init__(self) -> None: ... + class FrameSet: def __init__( self, @@ -88,18 +103,6 @@ class GuiClient: def set_model_and_data(self, arg0: int, arg1: int) -> None: ... def sync(self) -> None: ... -class DynamicJointSchema: - joint_names: list[str] - joint_types: list[int] - qpos_sizes: list[int] - qvel_sizes: list[int] - def __init__(self) -> None: ... - -class DynamicJointState: - qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] - qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] - def __init__(self) -> None: ... - class Sim: def __init__(self, mjmdl: int, mjdata: int) -> None: ... def _start_gui_server(self, id: str) -> None: ... @@ -113,15 +116,20 @@ class Sim: def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... + def sync_gui(self) -> None: ... class SimCameraConfig(rcs._core.common.BaseCameraConfig): type: CameraType + def __copy__(self) -> SimCameraConfig: ... + def __deepcopy__(self, arg0: dict) -> SimCameraConfig: ... def __init__( self, identifier: str, frame_rate: int, resolution_width: int, resolution_height: int, type: CameraType = ... ) -> None: ... class SimCameraSet: - def __init__(self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True) -> None: ... + def __init__( + self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True, max_buffer_frames: int = 100 + ) -> None: ... def buffer_size(self) -> int: ... def clear_buffer(self) -> None: ... def get_latest_frameset(self) -> FrameSet | None: ... @@ -131,7 +139,7 @@ class SimCameraSet: class SimConfig: async_control: bool - frequency: int + frequency: float max_convergence_steps: int realtime: bool def __copy__(self) -> SimConfig: ... @@ -140,7 +148,7 @@ class SimConfig: self, async_control: bool = False, realtime: bool = False, - frequency: float = 30, + frequency: float = 30.0, max_convergence_steps: int = 500, ) -> None: ... @@ -157,6 +165,7 @@ class SimGripperConfig(rcs._core.common.GripperConfig): collision_geoms_fingers: list[str] epsilon_inner: float epsilon_outer: float + gripper_type: rcs._core.common.GripperType ignored_collision_geoms: list[str] joints: list[str] max_actuator_width: float @@ -180,8 +189,9 @@ class SimGripperConfig(rcs._core.common.GripperConfig): actuator: str = "actuator8", max_actuator_width: float = 255.0, min_actuator_width: float = 0.0, + gripper_type: rcs._core.common.GripperType = ..., ) -> None: ... - def add_postfix(self, id: str) -> None: ... + def add_prefix(self, id: str) -> None: ... class SimGripperState(rcs._core.common.GripperState): def __init__(self) -> None: ... @@ -208,9 +218,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig): actuators: list[str] arm_collision_geoms: list[str] base: str + dof: int + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] joint_rotational_tolerance: float joints: list[str] - mjcf_scene_path: str seconds_between_callbacks: float trajectory_trace: bool def __copy__(self) -> SimRobotConfig: ... @@ -223,7 +234,6 @@ class SimRobotConfig(rcs._core.common.RobotConfig): kinematic_model_path: str = "assets/scenes/fr3_empty_world/robot.xml", joint_rotational_tolerance: float = 0.0008726646259971648, seconds_between_callbacks: float = 0.1, - mjcf_scene_path: str = "assets/scenes/fr3_empty_world/scene.xml", trajectory_trace: bool = False, arm_collision_geoms: list[str] = [ "fr3_link0_collision", @@ -244,6 +254,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig): "fr3_joint6", "fr3_joint7", ], + q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None = None, actuators: list[str] = [ "fr3_joint1", "fr3_joint2", @@ -254,8 +265,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig): "fr3_joint7", ], base: str = "base", + dof: int = 7, + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ..., ) -> None: ... - def add_postfix(self, id: str) -> None: ... + def add_prefix(self, id: str) -> None: ... class SimRobotState(rcs._core.common.RobotState): def __init__(self) -> None: ... @@ -338,7 +351,7 @@ class SimTilburgHandConfig(rcs._core.common.HandConfig): max_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ..., min_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ..., ) -> None: ... - def add_postfix(self, id: str) -> None: ... + def add_prefix(self, id: str) -> None: ... class SimTilburgHandState(rcs._core.common.HandState): def __init__(self) -> None: ... diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index b5348455..287e5b90 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -44,33 +44,32 @@ def reset( class SimStateObservationWrapper(ActObsInfoWrapper): - DYNAMIC_JOINT_SCHEMA_KEY = "dynamic_joint_schema" - DYNAMIC_JOINT_QPOS_KEY = "dynamic_joint_qpos" - DYNAMIC_JOINT_QVEL_KEY = "dynamic_joint_qvel" + STATE_KEY = "sim_state" + STATE_SPEC_KEY = "sim_state_spec" + STATE_SIZE_KEY = "sim_state_size" def __init__(self, env): super().__init__(env) assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) - self._dynamic_joint_schema = self.sim.get_dynamic_joint_schema() - self._include_schema_in_next_step = True + self._state_spec = self.sim.get_state_spec() + self._include_state_spec_in_next_step = True def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: observation = dict(observation) - dynamic_joint_state = self.sim.get_dynamic_joint_state() - observation[self.DYNAMIC_JOINT_QPOS_KEY] = dynamic_joint_state["qpos"] - observation[self.DYNAMIC_JOINT_QVEL_KEY] = dynamic_joint_state["qvel"] - if self._include_schema_in_next_step: - observation[self.DYNAMIC_JOINT_SCHEMA_KEY] = self._dynamic_joint_schema - self._include_schema_in_next_step = False + sim_state = self.sim.get_state() + observation[self.STATE_KEY] = sim_state + observation[self.STATE_SIZE_KEY] = sim_state.shape[0] + if self._include_state_spec_in_next_step: + observation[self.STATE_SPEC_KEY] = self._state_spec + self._include_state_spec_in_next_step = False return observation, info def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: obs, info = super().reset(seed=seed, options=options) - # Re-emit the schema on the first recorded step after each reset. - self._include_schema_in_next_step = True + self._include_state_spec_in_next_step = True return obs, info diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index ffb17075..081fa716 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -45,8 +45,6 @@ def gui_loop(gui_uuid: str, close_event): class Sim(_Sim): - STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION - def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None): mjmdl = Path(mjmdl) if mjmdl.suffix == ".xml": @@ -66,31 +64,38 @@ def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None): if cfg is not None: self.set_config(cfg) - def get_state_spec(self) -> int: - return int(self.STATE_SPEC) + def get_state_spec(self) -> dict[str, list[str] | list[int]]: + return self.get_dynamic_joint_schema() - def get_state_size(self, spec: int | None = None) -> int: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - return mj.mj_stateSize(self.model, state_spec) + def get_state_size(self, spec: dict[str, list[str] | list[int]] | None = None) -> int: + state_spec = self.get_state_spec() if spec is None else spec + qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) + qvel_size = sum(int(value) for value in state_spec["qvel_sizes"]) + return qpos_size + qvel_size - def get_state(self, spec: int | None = None) -> np.ndarray: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64) - mj.mj_getState(self.model, self.data, state, state_spec) - return state + def get_state(self, spec: dict[str, list[str] | list[int]] | None = None) -> np.ndarray: + del spec + dynamic_joint_state = self.get_dynamic_joint_state() + return np.concatenate((dynamic_joint_state["qpos"], dynamic_joint_state["qvel"])) - def set_state(self, state: np.ndarray, spec: int | None = None): - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + def set_state( + self, + state: np.ndarray, + spec: dict[str, list[str] | list[int]] | None = None, + ): + state_spec = self.get_state_spec() if spec is None else spec state_array = np.asarray(state, dtype=np.float64) - expected_size = self.get_state_size(int(state_spec)) + expected_size = self.get_state_size(state_spec) if state_array.shape != (expected_size,): - msg = ( - f"Expected MuJoCo state with shape ({expected_size},), " - f"got {state_array.shape} for spec {int(state_spec)}." - ) + msg = f"Expected state with shape ({expected_size},), got {state_array.shape}." raise ValueError(msg) - mj.mj_setState(self.model, self.data, state_array, state_spec) - mj.mj_forward(self.model, self.data) + + qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) + dynamic_joint_state = { + "qpos": state_array[:qpos_size], + "qvel": state_array[qpos_size:], + } + self.set_dynamic_joint_state(state_spec, dynamic_joint_state) def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]: schema = super().get_dynamic_joint_schema() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index 4a0c7c22..0f2a29b5 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -39,21 +39,13 @@ class RecordedSimStep: observation: dict[str, Any] @property - def dynamic_joint_schema(self) -> dict[str, Any] | None: - schema = self.observation.get(SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY) - return dict(schema) if schema is not None else None + def sim_state(self) -> np.ndarray: + return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) @property - def dynamic_joint_state(self) -> dict[str, np.ndarray] | None: - if ( - SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY not in self.observation - or SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY not in self.observation - ): - return None - return { - "qpos": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY], dtype=np.float64), - "qvel": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY], dtype=np.float64), - } + def sim_state_spec(self) -> dict[str, Any] | None: + schema = self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY) + return dict(schema) if schema is not None else None class DuckDBUnavailableError(RuntimeError): @@ -161,19 +153,15 @@ def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, pre def restore_sim_step( env: gym.Env, recorded_step: RecordedSimStep, - dynamic_joint_schema: dict[str, Any] | None = None, + sim_state_spec: dict[str, Any] | None = None, ): - sim = env.get_wrapper_attr("sim") - dynamic_joint_state = recorded_step.dynamic_joint_state - if dynamic_joint_state is None: - msg = "Recorded step is missing dynamic joint state data." + resolved_spec = sim_state_spec or recorded_step.sim_state_spec + if resolved_spec is None: + msg = "Recorded sim state is missing its schema." raise ValueError(msg) - resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema - if resolved_schema is None: - msg = "Recorded dynamic joint state is missing its schema." - raise ValueError(msg) - sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) + sim = env.get_wrapper_attr("sim") + sim.set_state(recorded_step.sim_state, spec=resolved_spec) def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: @@ -209,17 +197,16 @@ def replay_trajectory( output_dir: Path | None = None, ): if not recorded_steps: - msg = "No recorded dynamic joint states found in the requested trajectory." + msg = "No recorded sim states found in the requested trajectory." raise ValueError(msg) - dynamic_joint_schema = next( - (recorded_step.dynamic_joint_schema for recorded_step in recorded_steps if recorded_step.dynamic_joint_schema), - None, + sim_state_spec = next( + (recorded_step.sim_state_spec for recorded_step in recorded_steps if recorded_step.sim_state_spec), None ) env.reset() for recorded_step in recorded_steps: - restore_sim_step(env, recorded_step, dynamic_joint_schema=dynamic_joint_schema) + restore_sim_step(env, recorded_step, sim_state_spec=sim_state_spec) if output_dir is not None: save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) if sleep_s > 0: diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 2befdc91..3cf29db6 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -139,7 +139,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) obs, _ = record_env.reset() - assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in obs + assert SimStateObservationWrapper.STATE_KEY in obs record_env.step({"delta": np.array([0.125], dtype=np.float64)}) record_env.close() @@ -149,17 +149,18 @@ def test_record_and_replay_sim_state(tmp_path: Path): assert len(rows) == 1 recorded_obs = rows[0]["obs"] - assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in recorded_obs - assert SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY in recorded_obs - assert SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs + assert ( + len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) + == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] + ) recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) assert len(recorded_steps) == 1 - assert recorded_steps[0].dynamic_joint_schema is not None - assert np.allclose( - recorded_steps[0].dynamic_joint_state["qpos"], # type: ignore[index] - np.asarray(recorded_obs[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY]), - ) + assert recorded_steps[0].sim_state_spec is not None + assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) replay_sim = Sim(model_path) replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) @@ -167,7 +168,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): render_dir = tmp_path / "rendered" replay_env.reset() - restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=recorded_steps[0].dynamic_joint_schema) + restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=recorded_steps[0].sim_state_spec) assert np.allclose( replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 ) @@ -221,7 +222,7 @@ def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list return recorded_steps, rows[0]["obs"] -def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): +def test_sim_state_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): base_model_path = tmp_path / "base.xml" base_model_path.write_text(XML) modified_model_path = tmp_path / "modified.xml" @@ -238,10 +239,8 @@ def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(t replay_env: gym.Env = DummySimEnv(replay_sim) replay_env = SimStateObservationWrapper(replay_env) replay_env.reset() - dynamic_joint_schema = next( - step.dynamic_joint_schema for step in recorded_steps if step.dynamic_joint_schema is not None - ) - restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=dynamic_joint_schema) + sim_state_spec = next(step.sim_state_spec for step in recorded_steps if step.sim_state_spec is not None) + restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=sim_state_spec) assert np.allclose( replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 @@ -268,15 +267,15 @@ def _create_dual_arm_env(scene_name: str): ) -def test_dynamic_joint_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): +def test_sim_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): source_scene_path = REPO_ROOT / "assets/scenes/fr3_dual_arm/scene.xml" source_robot_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.xml" source_urdf_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.urdf" - modified_scene_path = source_scene_path.parent / "scene_dynamic_joint_test.xml" + modified_scene_path = source_scene_path.parent / "scene_sim_state_test.xml" _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) - base_scene_name = "fr3_dual_arm_dynamic_joint_base_test" - test_scene_name = "fr3_dual_arm_dynamic_joint_test" + base_scene_name = "fr3_dual_arm_sim_state_base_test" + test_scene_name = "fr3_dual_arm_sim_state_test" scene_kwargs = { "mjcf_robot": str(source_robot_path), "urdf": str(source_urdf_path), @@ -291,17 +290,16 @@ def test_dynamic_joint_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): try: base_env.reset() base_sim = base_env.get_wrapper_attr("sim") - dynamic_joint_schema = base_sim.get_dynamic_joint_schema() - dynamic_joint_state = base_sim.get_dynamic_joint_state() + sim_state_spec = base_sim.get_state_spec() + sim_state = base_sim.get_state() modified_env.reset() modified_sim = modified_env.get_wrapper_attr("sim") - modified_sim.set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) - restored_dynamic_joint_state = modified_sim.get_dynamic_joint_state() + modified_sim.set_state(sim_state, sim_state_spec) + restored_sim_state = modified_sim.get_state() - assert dynamic_joint_schema == modified_sim.get_dynamic_joint_schema() - assert np.allclose(restored_dynamic_joint_state["qpos"], dynamic_joint_state["qpos"], atol=1e-9, rtol=0) - assert np.allclose(restored_dynamic_joint_state["qvel"], dynamic_joint_state["qvel"], atol=1e-9, rtol=0) + assert sim_state_spec == modified_sim.get_state_spec() + assert np.allclose(restored_sim_state, sim_state, atol=1e-9, rtol=0) finally: base_env.close() modified_env.close() diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 3566ac15..1ce1f175 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -233,20 +233,20 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, "Dynamic joint schema fields must all have the same length."); } - int expected_qpos_size = std::accumulate(schema.qpos_sizes.begin(), - schema.qpos_sizes.end(), 0); - int expected_qvel_size = std::accumulate(schema.qvel_sizes.begin(), - schema.qvel_sizes.end(), 0); + int expected_qpos_size = + std::accumulate(schema.qpos_sizes.begin(), schema.qpos_sizes.end(), 0); + int expected_qvel_size = + std::accumulate(schema.qvel_sizes.begin(), schema.qvel_sizes.end(), 0); if (state.qpos.size() != expected_qpos_size) { std::ostringstream msg; - msg << "Dynamic joint qpos size mismatch. Expected " - << expected_qpos_size << ", got " << state.qpos.size() << "."; + msg << "Dynamic joint qpos size mismatch. Expected " << expected_qpos_size + << ", got " << state.qpos.size() << "."; throw std::invalid_argument(msg.str()); } if (state.qvel.size() != expected_qvel_size) { std::ostringstream msg; - msg << "Dynamic joint qvel size mismatch. Expected " - << expected_qvel_size << ", got " << state.qvel.size() << "."; + msg << "Dynamic joint qvel size mismatch. Expected " << expected_qvel_size + << ", got " << state.qvel.size() << "."; throw std::invalid_argument(msg.str()); } @@ -276,8 +276,9 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, msg << "Dynamic joint schema mismatch for joint '" << schema.joint_names[i] << "': expected type=" << target_spec.type << ", qpos_size=" << target_spec.qpos_size - << ", qvel_size=" << target_spec.qvel_size << " but got type=" - << schema.joint_types[i] << ", qpos_size=" << schema.qpos_sizes[i] + << ", qvel_size=" << target_spec.qvel_size + << " but got type=" << schema.joint_types[i] + << ", qpos_size=" << schema.qpos_sizes[i] << ", qvel_size=" << schema.qvel_sizes[i] << "."; throw std::invalid_argument(msg.str()); }