diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi
index a75d598a..5c1e5722 100644
--- a/python/rcs/_core/sim.pyi
+++ b/python/rcs/_core/sim.pyi
@@ -11,6 +11,8 @@ import rcs._core.common
__all__: list[str] = [
"CameraType",
+ "DynamicJointSchema",
+ "DynamicJointState",
"FrameSet",
"GuiClient",
"Sim",
@@ -32,6 +34,7 @@ __all__: list[str] = [
"tracking",
]
M = typing.TypeVar("M", bound=int)
+N = typing.TypeVar("N", bound=int)
class CameraType:
"""
@@ -68,6 +71,18 @@ class CameraType:
@property
def value(self) -> int: ...
+class DynamicJointSchema:
+ joint_names: list[str]
+ joint_types: list[int]
+ qpos_sizes: list[int]
+ qvel_sizes: list[int]
+ def __init__(self) -> None: ...
+
+class DynamicJointState:
+ qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]
+ qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]
+ def __init__(self) -> None: ...
+
class FrameSet:
def __init__(
self,
@@ -93,20 +108,28 @@ class Sim:
def _start_gui_server(self, id: str) -> None: ...
def _stop_gui_server(self) -> None: ...
def get_config(self) -> SimConfig: ...
+ def get_dynamic_joint_schema(self) -> DynamicJointSchema: ...
+ def get_dynamic_joint_state(self) -> DynamicJointState: ...
def is_converged(self) -> bool: ...
def reset(self) -> None: ...
def set_config(self, cfg: SimConfig) -> bool: ...
+ def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ...
def step(self, k: int) -> None: ...
def step_until_convergence(self) -> None: ...
+ def sync_gui(self) -> None: ...
class SimCameraConfig(rcs._core.common.BaseCameraConfig):
type: CameraType
+ def __copy__(self) -> SimCameraConfig: ...
+ def __deepcopy__(self, arg0: dict) -> SimCameraConfig: ...
def __init__(
self, identifier: str, frame_rate: int, resolution_width: int, resolution_height: int, type: CameraType = ...
) -> None: ...
class SimCameraSet:
- def __init__(self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True) -> None: ...
+ def __init__(
+ self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True, max_buffer_frames: int = 100
+ ) -> None: ...
def buffer_size(self) -> int: ...
def clear_buffer(self) -> None: ...
def get_latest_frameset(self) -> FrameSet | None: ...
@@ -116,7 +139,7 @@ class SimCameraSet:
class SimConfig:
async_control: bool
- frequency: int
+ frequency: float
max_convergence_steps: int
realtime: bool
def __copy__(self) -> SimConfig: ...
@@ -125,7 +148,7 @@ class SimConfig:
self,
async_control: bool = False,
realtime: bool = False,
- frequency: float = 30,
+ frequency: float = 30.0,
max_convergence_steps: int = 500,
) -> None: ...
@@ -142,6 +165,7 @@ class SimGripperConfig(rcs._core.common.GripperConfig):
collision_geoms_fingers: list[str]
epsilon_inner: float
epsilon_outer: float
+ gripper_type: rcs._core.common.GripperType
ignored_collision_geoms: list[str]
joints: list[str]
max_actuator_width: float
@@ -165,8 +189,9 @@ class SimGripperConfig(rcs._core.common.GripperConfig):
actuator: str = "actuator8",
max_actuator_width: float = 255.0,
min_actuator_width: float = 0.0,
+ gripper_type: rcs._core.common.GripperType = ...,
) -> None: ...
- def add_postfix(self, id: str) -> None: ...
+ def add_prefix(self, id: str) -> None: ...
class SimGripperState(rcs._core.common.GripperState):
def __init__(self) -> None: ...
@@ -193,9 +218,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
actuators: list[str]
arm_collision_geoms: list[str]
base: str
+ dof: int
+ joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]
joint_rotational_tolerance: float
joints: list[str]
- mjcf_scene_path: str
seconds_between_callbacks: float
trajectory_trace: bool
def __copy__(self) -> SimRobotConfig: ...
@@ -208,7 +234,6 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
kinematic_model_path: str = "assets/scenes/fr3_empty_world/robot.xml",
joint_rotational_tolerance: float = 0.0008726646259971648,
seconds_between_callbacks: float = 0.1,
- mjcf_scene_path: str = "assets/scenes/fr3_empty_world/scene.xml",
trajectory_trace: bool = False,
arm_collision_geoms: list[str] = [
"fr3_link0_collision",
@@ -229,6 +254,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
"fr3_joint6",
"fr3_joint7",
],
+ q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None = None,
actuators: list[str] = [
"fr3_joint1",
"fr3_joint2",
@@ -239,8 +265,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
"fr3_joint7",
],
base: str = "base",
+ dof: int = 7,
+ joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ...,
) -> None: ...
- def add_postfix(self, id: str) -> None: ...
+ def add_prefix(self, id: str) -> None: ...
class SimRobotState(rcs._core.common.RobotState):
def __init__(self) -> None: ...
@@ -323,7 +351,7 @@ class SimTilburgHandConfig(rcs._core.common.HandConfig):
max_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ...,
min_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ...,
) -> None: ...
- def add_postfix(self, id: str) -> None: ...
+ def add_prefix(self, id: str) -> None: ...
class SimTilburgHandState(rcs._core.common.HandState):
def __init__(self) -> None: ...
diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py
index 53fb60c3..287e5b90 100644
--- a/python/rcs/envs/sim.py
+++ b/python/rcs/envs/sim.py
@@ -52,15 +52,26 @@ def __init__(self, env):
super().__init__(env)
assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation."
self.sim = cast(sim.Sim, self.get_wrapper_attr("sim"))
+ self._state_spec = self.sim.get_state_spec()
+ self._include_state_spec_in_next_step = True
def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
observation = dict(observation)
sim_state = self.sim.get_state()
observation[self.STATE_KEY] = sim_state
- observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec()
observation[self.STATE_SIZE_KEY] = sim_state.shape[0]
+ if self._include_state_spec_in_next_step:
+ observation[self.STATE_SPEC_KEY] = self._state_spec
+ self._include_state_spec_in_next_step = False
return observation, info
+ def reset(
+ self, *, seed: int | None = None, options: dict[str, Any] | None = None
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ obs, info = super().reset(seed=seed, options=options)
+ self._include_state_spec_in_next_step = True
+ return obs, info
+
class GripperWrapperSim(ActObsInfoWrapper):
def __init__(self, env):
diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py
index 0d46773c..081fa716 100644
--- a/python/rcs/sim/sim.py
+++ b/python/rcs/sim/sim.py
@@ -11,6 +11,8 @@
import mujoco as mj
import mujoco.viewer
import numpy as np
+from rcs._core.sim import DynamicJointSchema as _DynamicJointSchema
+from rcs._core.sim import DynamicJointState as _DynamicJointState
from rcs._core.sim import GuiClient as _GuiClient
from rcs._core.sim import Sim as _Sim
from rcs.sim import SimConfig, egl_bootstrap
@@ -43,8 +45,6 @@ def gui_loop(gui_uuid: str, close_event):
class Sim(_Sim):
- STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION
-
def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None):
mjmdl = Path(mjmdl)
if mjmdl.suffix == ".xml":
@@ -64,31 +64,70 @@ def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None):
if cfg is not None:
self.set_config(cfg)
- def get_state_spec(self) -> int:
- return int(self.STATE_SPEC)
+ def get_state_spec(self) -> dict[str, list[str] | list[int]]:
+ return self.get_dynamic_joint_schema()
- def get_state_size(self, spec: int | None = None) -> int:
- state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
- return mj.mj_stateSize(self.model, state_spec)
+ def get_state_size(self, spec: dict[str, list[str] | list[int]] | None = None) -> int:
+ state_spec = self.get_state_spec() if spec is None else spec
+ qpos_size = sum(int(value) for value in state_spec["qpos_sizes"])
+ qvel_size = sum(int(value) for value in state_spec["qvel_sizes"])
+ return qpos_size + qvel_size
- def get_state(self, spec: int | None = None) -> np.ndarray:
- state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
- state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64)
- mj.mj_getState(self.model, self.data, state, state_spec)
- return state
+ def get_state(self, spec: dict[str, list[str] | list[int]] | None = None) -> np.ndarray:
+ del spec
+ dynamic_joint_state = self.get_dynamic_joint_state()
+ return np.concatenate((dynamic_joint_state["qpos"], dynamic_joint_state["qvel"]))
- def set_state(self, state: np.ndarray, spec: int | None = None):
- state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
+ def set_state(
+ self,
+ state: np.ndarray,
+ spec: dict[str, list[str] | list[int]] | None = None,
+ ):
+ state_spec = self.get_state_spec() if spec is None else spec
state_array = np.asarray(state, dtype=np.float64)
- expected_size = self.get_state_size(int(state_spec))
+ expected_size = self.get_state_size(state_spec)
if state_array.shape != (expected_size,):
- msg = (
- f"Expected MuJoCo state with shape ({expected_size},), "
- f"got {state_array.shape} for spec {int(state_spec)}."
- )
+ msg = f"Expected state with shape ({expected_size},), got {state_array.shape}."
raise ValueError(msg)
- mj.mj_setState(self.model, self.data, state_array, state_spec)
- mj.mj_forward(self.model, self.data)
+
+ qpos_size = sum(int(value) for value in state_spec["qpos_sizes"])
+ dynamic_joint_state = {
+ "qpos": state_array[:qpos_size],
+ "qvel": state_array[qpos_size:],
+ }
+ self.set_dynamic_joint_state(state_spec, dynamic_joint_state)
+
+ def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]:
+ schema = super().get_dynamic_joint_schema()
+ return {
+ "joint_names": list(schema.joint_names),
+ "joint_types": list(schema.joint_types),
+ "qpos_sizes": list(schema.qpos_sizes),
+ "qvel_sizes": list(schema.qvel_sizes),
+ }
+
+ def get_dynamic_joint_state(self) -> dict[str, np.ndarray]:
+ state = super().get_dynamic_joint_state()
+ return {
+ "qpos": np.asarray(state.qpos, dtype=np.float64),
+ "qvel": np.asarray(state.qvel, dtype=np.float64),
+ }
+
+ def set_dynamic_joint_state(
+ self,
+ schema: dict[str, list[str] | list[int]],
+ state: dict[str, np.ndarray],
+ ):
+ dynamic_joint_schema = _DynamicJointSchema()
+ dynamic_joint_schema.joint_names = list(schema["joint_names"])
+ dynamic_joint_schema.joint_types = [int(value) for value in schema["joint_types"]]
+ dynamic_joint_schema.qpos_sizes = [int(value) for value in schema["qpos_sizes"]]
+ dynamic_joint_schema.qvel_sizes = [int(value) for value in schema["qvel_sizes"]]
+
+ dynamic_joint_state = _DynamicJointState()
+ dynamic_joint_state.qpos = np.asarray(state["qpos"], dtype=np.float64)
+ dynamic_joint_state.qvel = np.asarray(state["qvel"], dtype=np.float64)
+ super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state)
def close_gui(self):
if self._stop_event is not None:
diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py
index f07eef91..0f2a29b5 100644
--- a/python/rcs/sim_state_replay.py
+++ b/python/rcs/sim_state_replay.py
@@ -43,8 +43,9 @@ def sim_state(self) -> np.ndarray:
return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64)
@property
- def sim_state_spec(self) -> int:
- return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0))
+ def sim_state_spec(self) -> dict[str, Any] | None:
+ schema = self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY)
+ return dict(schema) if schema is not None else None
class DuckDBUnavailableError(RuntimeError):
@@ -149,9 +150,18 @@ def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, pre
raise ValueError(msg)
-def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep):
+def restore_sim_step(
+ env: gym.Env,
+ recorded_step: RecordedSimStep,
+ sim_state_spec: dict[str, Any] | None = None,
+):
+ resolved_spec = sim_state_spec or recorded_step.sim_state_spec
+ if resolved_spec is None:
+ msg = "Recorded sim state is missing its schema."
+ raise ValueError(msg)
+
sim = env.get_wrapper_attr("sim")
- sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec)
+ sim.set_state(recorded_step.sim_state, spec=resolved_spec)
def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]:
@@ -190,9 +200,13 @@ def replay_trajectory(
msg = "No recorded sim states found in the requested trajectory."
raise ValueError(msg)
+ sim_state_spec = next(
+ (recorded_step.sim_state_spec for recorded_step in recorded_steps if recorded_step.sim_state_spec), None
+ )
+
env.reset()
for recorded_step in recorded_steps:
- restore_sim_step(env, recorded_step)
+ restore_sim_step(env, recorded_step, sim_state_spec=sim_state_spec)
if output_dir is not None:
save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env))
if sleep_s > 0:
diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py
index 1e974812..3cf29db6 100644
--- a/python/tests/test_sim_state_record_replay.py
+++ b/python/tests/test_sim_state_record_replay.py
@@ -2,6 +2,7 @@
import importlib.util
import sys
+import xml.etree.ElementTree as ET
from dataclasses import dataclass
from pathlib import Path
@@ -10,8 +11,12 @@
import numpy as np
import pyarrow.dataset as ds
from rcs._core.common import RobotPlatform
+from rcs._core.sim import SimConfig
from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet
+from rcs.envs.base import ControlMode, JointsDictType
+from rcs.envs.creators import SimMultiEnvCreator
from rcs.envs.storage_wrapper import StorageWrapper
+from rcs.envs.utils import default_sim_gripper_cfg, default_sim_robot_cfg
import rcs
@@ -52,7 +57,7 @@ def _load_local_module(module_name: str, relative_path: str):
-
+
@@ -154,6 +159,7 @@ def test_record_and_replay_sim_state(tmp_path: Path):
recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True)
assert len(recorded_steps) == 1
+ assert recorded_steps[0].sim_state_spec is not None
assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY]))
replay_sim = Sim(model_path)
@@ -162,7 +168,7 @@ def test_record_and_replay_sim_state(tmp_path: Path):
render_dir = tmp_path / "rendered"
replay_env.reset()
- restore_sim_step(replay_env, recorded_steps[0])
+ restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=recorded_steps[0].sim_state_spec)
assert np.allclose(
replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0
)
@@ -174,3 +180,129 @@ def test_record_and_replay_sim_state(tmp_path: Path):
rendered_files = sorted(path.name for path in render_dir.glob("*.png"))
assert rendered_files == ["step-000000-main.png"]
+
+
+def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path):
+ tree = ET.parse(src)
+ root = tree.getroot()
+ for include in root.findall("include"):
+ include_file = include.get("file")
+ if include_file is not None and not Path(include_file).is_absolute():
+ include.set("file", str((src.parent / include_file).resolve()))
+
+ worldbody = root.find("worldbody")
+ assert worldbody is not None
+
+ worldbody.append(
+ ET.Element(
+ "camera",
+ {
+ "name": "replay_extra_cam",
+ "pos": "1.4 0.0 0.9",
+ "xyaxes": "0 1 0 -0.3 0 1",
+ },
+ )
+ )
+ body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"})
+ ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"})
+ tree.write(dst)
+
+
+def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list, dict[str, object]]:
+ record_env: gym.Env = DummySimEnv(Sim(model_path))
+ record_env = SimStateObservationWrapper(record_env)
+ record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True)
+ record_env.reset()
+ record_env.step({"delta": np.array([0.125], dtype=np.float64)})
+ record_env.close()
+
+ table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")])
+ rows = table.to_pylist()
+ recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True)
+ return recorded_steps, rows[0]["obs"]
+
+
+def test_sim_state_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path):
+ base_model_path = tmp_path / "base.xml"
+ base_model_path.write_text(XML)
+ modified_model_path = tmp_path / "modified.xml"
+ _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path)
+
+ for record_model_path, replay_model_path in (
+ (base_model_path, modified_model_path),
+ (modified_model_path, base_model_path),
+ ):
+ dataset_path = tmp_path / f"dataset-{record_model_path.stem}-to-{replay_model_path.stem}"
+ recorded_steps, recorded_obs = _record_dummy_trajectory(dataset_path, record_model_path)
+
+ replay_sim = Sim(replay_model_path)
+ replay_env: gym.Env = DummySimEnv(replay_sim)
+ replay_env = SimStateObservationWrapper(replay_env)
+ replay_env.reset()
+ sim_state_spec = next(step.sim_state_spec for step in recorded_steps if step.sim_state_spec is not None)
+ restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=sim_state_spec)
+
+ assert np.allclose(
+ replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0
+ )
+ assert np.allclose(
+ replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0
+ )
+
+
+DUAL_ARM_ROBOT2ID = {"left": "0", "right": "1"}
+
+
+def _create_dual_arm_env(scene_name: str):
+ robot_cfg = default_sim_robot_cfg(scene_name, idx="")
+ sim_cfg = SimConfig()
+ sim_cfg.async_control = False
+ return SimMultiEnvCreator()(
+ name2id=DUAL_ARM_ROBOT2ID,
+ robot_cfg=robot_cfg,
+ control_mode=ControlMode.JOINTS,
+ gripper_cfg=default_sim_gripper_cfg(idx=""),
+ sim_cfg=sim_cfg,
+ max_relative_movement=None,
+ )
+
+
+def test_sim_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path):
+ source_scene_path = REPO_ROOT / "assets/scenes/fr3_dual_arm/scene.xml"
+ source_robot_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.xml"
+ source_urdf_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.urdf"
+ modified_scene_path = source_scene_path.parent / "scene_sim_state_test.xml"
+ _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path)
+
+ base_scene_name = "fr3_dual_arm_sim_state_base_test"
+ test_scene_name = "fr3_dual_arm_sim_state_test"
+ scene_kwargs = {
+ "mjcf_robot": str(source_robot_path),
+ "urdf": str(source_urdf_path),
+ "robot_type": rcs.scenes["fr3_dual_arm"].robot_type,
+ "mjb": None,
+ }
+ rcs.scenes[base_scene_name] = rcs.Scene(mjcf_scene=str(source_scene_path), **scene_kwargs)
+ rcs.scenes[test_scene_name] = rcs.Scene(mjcf_scene=str(modified_scene_path), **scene_kwargs)
+
+ base_env = _create_dual_arm_env(base_scene_name)
+ modified_env = _create_dual_arm_env(test_scene_name)
+ try:
+ base_env.reset()
+ base_sim = base_env.get_wrapper_attr("sim")
+ sim_state_spec = base_sim.get_state_spec()
+ sim_state = base_sim.get_state()
+
+ modified_env.reset()
+ modified_sim = modified_env.get_wrapper_attr("sim")
+ modified_sim.set_state(sim_state, sim_state_spec)
+ restored_sim_state = modified_sim.get_state()
+
+ assert sim_state_spec == modified_sim.get_state_spec()
+ assert np.allclose(restored_sim_state, sim_state, atol=1e-9, rtol=0)
+ finally:
+ base_env.close()
+ modified_env.close()
+ del rcs.scenes[test_scene_name]
+ del rcs.scenes[base_scene_name]
+ modified_scene_path.unlink(missing_ok=True)
diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp
index 007ee475..54222b32 100644
--- a/src/pybind/rcs.cpp
+++ b/src/pybind/rcs.cpp
@@ -703,6 +703,18 @@ PYBIND11_MODULE(_core, m) {
return rcs::sim::SimConfig(self);
});
+ py::class_(sim, "DynamicJointSchema")
+ .def(py::init<>())
+ .def_readwrite("joint_names", &rcs::sim::DynamicJointSchema::joint_names)
+ .def_readwrite("joint_types", &rcs::sim::DynamicJointSchema::joint_types)
+ .def_readwrite("qpos_sizes", &rcs::sim::DynamicJointSchema::qpos_sizes)
+ .def_readwrite("qvel_sizes", &rcs::sim::DynamicJointSchema::qvel_sizes);
+
+ py::class_(sim, "DynamicJointState")
+ .def(py::init<>())
+ .def_readwrite("qpos", &rcs::sim::DynamicJointState::qpos)
+ .def_readwrite("qvel", &rcs::sim::DynamicJointState::qvel);
+
py::class_>(sim, "Sim")
.def(py::init([](long m, long d) {
return std::make_shared((mjModel*)m, (mjData*)d);
@@ -715,6 +727,10 @@ PYBIND11_MODULE(_core, m) {
.def("get_config", &rcs::sim::Sim::get_config)
.def("step", &rcs::sim::Sim::step, py::arg("k"))
.def("reset", &rcs::sim::Sim::reset)
+ .def("get_dynamic_joint_schema", &rcs::sim::Sim::get_dynamic_joint_schema)
+ .def("get_dynamic_joint_state", &rcs::sim::Sim::get_dynamic_joint_state)
+ .def("set_dynamic_joint_state", &rcs::sim::Sim::set_dynamic_joint_state,
+ py::arg("schema"), py::arg("state"))
.def("_start_gui_server", &rcs::sim::Sim::start_gui_server, py::arg("id"))
.def("_stop_gui_server", &rcs::sim::Sim::stop_gui_server);
diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp
index 164ad2c8..1ce1f175 100644
--- a/src/sim/sim.cpp
+++ b/src/sim/sim.cpp
@@ -5,6 +5,9 @@
#include
#include
#include
+#include
+#include
+#include
#include
namespace rcs {
@@ -25,7 +28,65 @@ bool get_last_return_value(ConditionCallback cb) {
return cb.last_return_value;
}
-Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {};
+int Sim::get_joint_qpos_size(int joint_type) {
+ switch (joint_type) {
+ case mjJNT_FREE:
+ return 7;
+ case mjJNT_BALL:
+ return 4;
+ case mjJNT_SLIDE:
+ case mjJNT_HINGE:
+ return 1;
+ default:
+ throw std::runtime_error("Unsupported MuJoCo joint type for qpos size.");
+ }
+}
+
+int Sim::get_joint_qvel_size(int joint_type) {
+ switch (joint_type) {
+ case mjJNT_FREE:
+ return 6;
+ case mjJNT_BALL:
+ return 3;
+ case mjJNT_SLIDE:
+ case mjJNT_HINGE:
+ return 1;
+ default:
+ throw std::runtime_error("Unsupported MuJoCo joint type for qvel size.");
+ }
+}
+
+void Sim::init_dynamic_joint_specs() {
+ this->dynamic_joint_specs.clear();
+ this->dynamic_joint_name_to_index.clear();
+
+ for (int joint_id = 0; joint_id < this->m->njnt; ++joint_id) {
+ const char* joint_name = mj_id2name(this->m, mjOBJ_JOINT, joint_id);
+ if (joint_name == nullptr || joint_name[0] == '\0') {
+ std::ostringstream msg;
+ msg << "Dynamic joint state requires all joints to be named. Joint id "
+ << joint_id << " is unnamed.";
+ throw std::runtime_error(msg.str());
+ }
+
+ DynamicJointSpec spec{
+ .name = joint_name,
+ .type = this->m->jnt_type[joint_id],
+ .qpos_adr = this->m->jnt_qposadr[joint_id],
+ .qvel_adr = this->m->jnt_dofadr[joint_id],
+ .qpos_size = get_joint_qpos_size(this->m->jnt_type[joint_id]),
+ .qvel_size = get_joint_qvel_size(this->m->jnt_type[joint_id]),
+ };
+
+ this->dynamic_joint_name_to_index[spec.name] =
+ this->dynamic_joint_specs.size();
+ this->dynamic_joint_specs.push_back(spec);
+ }
+}
+
+Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {
+ this->init_dynamic_joint_specs();
+};
bool Sim::set_config(const SimConfig& cfg) {
this->cfg = cfg;
@@ -118,6 +179,134 @@ void Sim::reset() {
this->reset_callbacks();
}
+DynamicJointSchema Sim::get_dynamic_joint_schema() const {
+ DynamicJointSchema schema;
+ schema.joint_names.reserve(this->dynamic_joint_specs.size());
+ schema.joint_types.reserve(this->dynamic_joint_specs.size());
+ schema.qpos_sizes.reserve(this->dynamic_joint_specs.size());
+ schema.qvel_sizes.reserve(this->dynamic_joint_specs.size());
+
+ for (const DynamicJointSpec& spec : this->dynamic_joint_specs) {
+ schema.joint_names.push_back(spec.name);
+ schema.joint_types.push_back(spec.type);
+ schema.qpos_sizes.push_back(spec.qpos_size);
+ schema.qvel_sizes.push_back(spec.qvel_size);
+ }
+ return schema;
+}
+
+DynamicJointState Sim::get_dynamic_joint_state() const {
+ DynamicJointState state;
+ int total_qpos = 0;
+ int total_qvel = 0;
+ for (const DynamicJointSpec& spec : this->dynamic_joint_specs) {
+ total_qpos += spec.qpos_size;
+ total_qvel += spec.qvel_size;
+ }
+
+ state.qpos = rcs::common::VectorXd(total_qpos);
+ state.qvel = rcs::common::VectorXd(total_qvel);
+
+ int qpos_offset = 0;
+ int qvel_offset = 0;
+ for (const DynamicJointSpec& spec : this->dynamic_joint_specs) {
+ for (int i = 0; i < spec.qpos_size; ++i) {
+ state.qpos[qpos_offset + i] = this->d->qpos[spec.qpos_adr + i];
+ }
+ for (int i = 0; i < spec.qvel_size; ++i) {
+ state.qvel[qvel_offset + i] = this->d->qvel[spec.qvel_adr + i];
+ }
+ qpos_offset += spec.qpos_size;
+ qvel_offset += spec.qvel_size;
+ }
+
+ return state;
+}
+
+void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema,
+ const DynamicJointState& state) {
+ size_t joint_count = schema.joint_names.size();
+ if (schema.joint_types.size() != joint_count ||
+ schema.qpos_sizes.size() != joint_count ||
+ schema.qvel_sizes.size() != joint_count) {
+ throw std::invalid_argument(
+ "Dynamic joint schema fields must all have the same length.");
+ }
+
+ int expected_qpos_size =
+ std::accumulate(schema.qpos_sizes.begin(), schema.qpos_sizes.end(), 0);
+ int expected_qvel_size =
+ std::accumulate(schema.qvel_sizes.begin(), schema.qvel_sizes.end(), 0);
+ if (state.qpos.size() != expected_qpos_size) {
+ std::ostringstream msg;
+ msg << "Dynamic joint qpos size mismatch. Expected " << expected_qpos_size
+ << ", got " << state.qpos.size() << ".";
+ throw std::invalid_argument(msg.str());
+ }
+ if (state.qvel.size() != expected_qvel_size) {
+ std::ostringstream msg;
+ msg << "Dynamic joint qvel size mismatch. Expected " << expected_qvel_size
+ << ", got " << state.qvel.size() << ".";
+ throw std::invalid_argument(msg.str());
+ }
+
+ std::vector matched_target_joints(this->dynamic_joint_specs.size(),
+ false);
+ int qpos_offset = 0;
+ int qvel_offset = 0;
+ for (size_t i = 0; i < joint_count; ++i) {
+ auto spec_iter =
+ this->dynamic_joint_name_to_index.find(schema.joint_names[i]);
+ if (spec_iter == this->dynamic_joint_name_to_index.end()) {
+ std::cerr << "WARNING: Recorded dynamic joint '" << schema.joint_names[i]
+ << "' is missing in the replay model. Skipping it."
+ << std::endl;
+ qpos_offset += schema.qpos_sizes[i];
+ qvel_offset += schema.qvel_sizes[i];
+ continue;
+ }
+
+ const DynamicJointSpec& target_spec =
+ this->dynamic_joint_specs[spec_iter->second];
+ matched_target_joints[spec_iter->second] = true;
+ if (target_spec.type != schema.joint_types[i] ||
+ target_spec.qpos_size != schema.qpos_sizes[i] ||
+ target_spec.qvel_size != schema.qvel_sizes[i]) {
+ std::ostringstream msg;
+ msg << "Dynamic joint schema mismatch for joint '"
+ << schema.joint_names[i] << "': expected type=" << target_spec.type
+ << ", qpos_size=" << target_spec.qpos_size
+ << ", qvel_size=" << target_spec.qvel_size
+ << " but got type=" << schema.joint_types[i]
+ << ", qpos_size=" << schema.qpos_sizes[i]
+ << ", qvel_size=" << schema.qvel_sizes[i] << ".";
+ throw std::invalid_argument(msg.str());
+ }
+
+ for (int j = 0; j < target_spec.qpos_size; ++j) {
+ this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j];
+ }
+ for (int j = 0; j < target_spec.qvel_size; ++j) {
+ this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j];
+ }
+
+ qpos_offset += schema.qpos_sizes[i];
+ qvel_offset += schema.qvel_sizes[i];
+ }
+
+ for (size_t i = 0; i < this->dynamic_joint_specs.size(); ++i) {
+ if (!matched_target_joints[i]) {
+ std::cerr << "WARNING: Replay model dynamic joint '"
+ << this->dynamic_joint_specs[i].name
+ << "' is missing in the recorded schema. Leaving it at its "
+ "current value."
+ << std::endl;
+ }
+ }
+
+ mj_forward(this->m, this->d);
+}
+
void Sim::reset_callbacks() {
for (size_t i = 0; i < std::size(this->callbacks); ++i) {
this->callbacks[i].last_call_timestamp = 0;
diff --git a/src/sim/sim.h b/src/sim/sim.h
index 62a2fbbf..b96a9aa7 100644
--- a/src/sim/sim.h
+++ b/src/sim/sim.h
@@ -5,10 +5,13 @@
#include
#include
#include
+#include
+#include
#include "boost/interprocess/managed_shared_memory.hpp"
#include "gui.h"
#include "mujoco/mujoco.h"
+#include "rcs/utils.h"
namespace rcs {
namespace sim {
@@ -55,16 +58,42 @@ struct RenderingCallback {
mjtNum last_call_timestamp; // in seconds
};
+struct DynamicJointSchema {
+ std::vector joint_names;
+ std::vector joint_types;
+ std::vector qpos_sizes;
+ std::vector qvel_sizes;
+};
+
+struct DynamicJointState {
+ rcs::common::VectorXd qpos;
+ rcs::common::VectorXd qvel;
+};
+
class Sim {
private:
+ struct DynamicJointSpec {
+ std::string name;
+ int type;
+ int qpos_adr;
+ int qvel_adr;
+ int qpos_size;
+ int qvel_size;
+ };
+
SimConfig cfg;
std::vector callbacks;
std::vector any_callbacks;
std::vector all_callbacks;
std::vector rendering_callbacks;
+ std::vector dynamic_joint_specs;
+ std::unordered_map dynamic_joint_name_to_index;
void invoke_callbacks();
bool invoke_condition_callbacks();
void invoke_rendering_callbacks();
+ void init_dynamic_joint_specs();
+ static int get_joint_qpos_size(int joint_type);
+ static int get_joint_qvel_size(int joint_type);
size_t convergence_steps = 0;
bool converged = true;
std::optional gui;
@@ -82,6 +111,10 @@ class Sim {
void step(size_t k);
void reset_callbacks();
void reset();
+ DynamicJointSchema get_dynamic_joint_schema() const;
+ DynamicJointState get_dynamic_joint_state() const;
+ void set_dynamic_joint_state(const DynamicJointSchema& schema,
+ const DynamicJointState& state);
/* NOTE: IMPORTANT, the callback is not necessarily called at exactly the
* the requested interval. We invoke a callback if the elapsed simulation time
* since the last call of the callback is greater than the requested time.