diff --git a/Makefile b/Makefile index e3830ac8..9d5908b6 100644 --- a/Makefile +++ b/Makefile @@ -32,9 +32,11 @@ stubgen: find ./python -not -path "./python/rcs/_core/*" -name '*.pyi' -delete find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[typing\.Literal\[\([0-9]\+\)\], typing\.Literal\[1\]\]/tuple\[typing\.Literal[\1]\]/g' find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[\([M|N]\), typing\.Literal\[1\]\]/tuple\[\1\]/g' - sed -i 's/ q_home: numpy\.ndarray\[tuple\[M\], numpy\.dtype\[numpy\.float64\]\] | None/ q_home: numpy.ndarray | None/' python/rcs/_core/common.pyi - python -c "from pathlib import Path; p=Path('python/rcs/_core/common.pyi'); t=p.read_text(); t=t.replace('numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]', 'numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]]'); p.write_text(t)" - python -c "from pathlib import Path; p=Path('python/rcs/_core/sim.pyi'); t=p.read_text(); t=t.replace('numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]', 'numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]]'); t=t.replace(', max_buffer_frames: int = 100', ''); p.write_text(t)" + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class RobotConfig/class RobotConfig(typing.Generic[M])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class SimRobotConfig(rcs._core.common.RobotConfig)/class SimRobotConfig(rcs._core.common.RobotConfig[M])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class DynamicJointState/class DynamicJointState(typing.Generic[M])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/N = typing.TypeVar("N", bound=int)//g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/, N/, M/g' python ci_scripts/generate_common_typing.py ruff check --fix python/rcs/_core python/rcs/common_typing.py isort python/rcs/_core python/rcs/common_typing.py diff --git a/extensions/rcs_fr3/pyproject.toml b/extensions/rcs_fr3/pyproject.toml index 249ca89a..1127acf9 100644 --- a/extensions/rcs_fr3/pyproject.toml +++ b/extensions/rcs_fr3/pyproject.toml @@ -19,7 +19,7 @@ dependencies = ["rcs>=0.6.3", "frankik"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_panda/pyproject.toml b/extensions/rcs_panda/pyproject.toml index 99319869..ed122a0c 100644 --- a/extensions/rcs_panda/pyproject.toml +++ b/extensions/rcs_panda/pyproject.toml @@ -18,7 +18,7 @@ dependencies = ["rcs>=0.6.3"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_robotics_library/pyproject.toml b/extensions/rcs_robotics_library/pyproject.toml index d2fe0da7..17a2c09b 100644 --- a/extensions/rcs_robotics_library/pyproject.toml +++ b/extensions/rcs_robotics_library/pyproject.toml @@ -21,7 +21,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Pierre Krack", email = "pierre.krack@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_robotiq2f85/pyproject.toml b/extensions/rcs_robotiq2f85/pyproject.toml index 835dfe02..ddb09018 100644 --- a/extensions/rcs_robotiq2f85/pyproject.toml +++ b/extensions/rcs_robotiq2f85/pyproject.toml @@ -17,5 +17,5 @@ maintainers = [ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "AGPL-3.0-or-later" } diff --git a/extensions/rcs_so101/pyproject.toml b/extensions/rcs_so101/pyproject.toml index 06d22ba3..0dbf0130 100644 --- a/extensions/rcs_so101/pyproject.toml +++ b/extensions/rcs_so101/pyproject.toml @@ -18,7 +18,7 @@ dependencies = ["rcs>=0.6.3", "lerobot==0.3.3"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] build.verbose = true diff --git a/extensions/rcs_tacto/pyproject.toml b/extensions/rcs_tacto/pyproject.toml index f69dd2cb..7c94e292 100644 --- a/extensions/rcs_tacto/pyproject.toml +++ b/extensions/rcs_tacto/pyproject.toml @@ -17,7 +17,7 @@ maintainers = [ { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] authors = [{ name = "Seongjin Bien", email = "seongjin.bien@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_ur5e/pyproject.toml b/extensions/rcs_ur5e/pyproject.toml index 1cb50fde..84ddac07 100644 --- a/extensions/rcs_ur5e/pyproject.toml +++ b/extensions/rcs_ur5e/pyproject.toml @@ -15,7 +15,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Johannes Hechtl", email = "johannes.hechtl@siemens.com" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_usb_cam/pyproject.toml b/extensions/rcs_usb_cam/pyproject.toml index 8d42c626..4a39396b 100644 --- a/extensions/rcs_usb_cam/pyproject.toml +++ b/extensions/rcs_usb_cam/pyproject.toml @@ -12,7 +12,7 @@ maintainers = [ { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] authors = [{ name = "Seongjin Bien", email = "seongjin.bien@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_xarm7/pyproject.toml b/extensions/rcs_xarm7/pyproject.toml index a024385b..3432f3e3 100644 --- a/extensions/rcs_xarm7/pyproject.toml +++ b/extensions/rcs_xarm7/pyproject.toml @@ -16,7 +16,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Ken Nakahara", email = "knakahara@lasr.org" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/pyproject.toml b/pyproject.toml index 6a9213d3..ed210c40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ authors = [ { name = "Pierre Krack", email = "pierre.krack@utn.de" }, { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" license = { file = "LICENSE" } [dependency-groups] diff --git a/python/rcs/_core/common.pyi b/python/rcs/_core/common.pyi index cf9d1d5d..b2749517 100644 --- a/python/rcs/_core/common.pyi +++ b/python/rcs/_core/common.pyi @@ -41,7 +41,6 @@ __all__: list[str] = [ "TRIPOD_GRASP", ] M = typing.TypeVar("M", bound=int) -N = typing.TypeVar("N", bound=int) class BaseCameraConfig: frame_rate: int @@ -239,12 +238,12 @@ class Robot: def to_pose_in_robot_coordinates(self, pose_in_world_coordinates: Pose) -> Pose: ... def to_pose_in_world_coordinates(self, pose_in_robot_coordinates: Pose) -> Pose: ... -class RobotConfig: +class RobotConfig(typing.Generic[M]): attachment_site: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] kinematic_model_path: str - q_home: numpy.ndarray | None + q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None robot_platform: RobotPlatform robot_type: RobotType tcp_offset: Pose @@ -252,7 +251,7 @@ class RobotConfig: self, robot_type: RobotType = ..., dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] = ..., robot_platform: RobotPlatform = ..., tcp_offset: Pose = ..., attachment_site: str = "attachment_site", diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index fa5a8899..8ea8ae40 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -11,6 +11,8 @@ import rcs._core.common __all__: list[str] = [ "CameraType", + "DynamicJointSchema", + "DynamicJointState", "FrameSet", "GuiClient", "Sim", @@ -32,7 +34,6 @@ __all__: list[str] = [ "tracking", ] M = typing.TypeVar("M", bound=int) -N = typing.TypeVar("N", bound=int) class CameraType: """ @@ -69,6 +70,18 @@ class CameraType: @property def value(self) -> int: ... +class DynamicJointSchema: + joint_names: list[str] + joint_types: list[int] + qpos_sizes: list[int] + qvel_sizes: list[int] + def __init__(self) -> None: ... + +class DynamicJointState(typing.Generic[M]): + qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + def __init__(self) -> None: ... + class FrameSet: def __init__( self, @@ -94,9 +107,12 @@ class Sim: def _start_gui_server(self, id: str) -> None: ... def _stop_gui_server(self) -> None: ... def get_config(self) -> SimConfig: ... + def get_dynamic_joint_schema(self) -> DynamicJointSchema: ... + def get_dynamic_joint_state(self) -> DynamicJointState: ... def is_converged(self) -> bool: ... def reset(self) -> None: ... def set_config(self, cfg: SimConfig) -> bool: ... + def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... def sync_gui(self) -> None: ... @@ -110,7 +126,9 @@ class SimCameraConfig(rcs._core.common.BaseCameraConfig): ) -> None: ... class SimCameraSet: - def __init__(self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True) -> None: ... + def __init__( + self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True, max_buffer_frames: int = 100 + ) -> None: ... def buffer_size(self) -> int: ... def clear_buffer(self) -> None: ... def get_latest_frameset(self) -> FrameSet | None: ... @@ -195,12 +213,12 @@ class SimRobot(rcs._core.common.Robot): def set_config(self, cfg: SimRobotConfig) -> bool: ... def set_joints_hard(self, q: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]) -> None: ... -class SimRobotConfig(rcs._core.common.RobotConfig): +class SimRobotConfig(rcs._core.common.RobotConfig[M]): actuators: list[str] arm_collision_geoms: list[str] base: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] joint_rotational_tolerance: float joints: list[str] seconds_between_callbacks: float @@ -247,7 +265,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig): ], base: str = "base", dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] = ..., ) -> None: ... def add_prefix(self, id: str) -> None: ... diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index 8dc97d54..f339b23b 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -170,6 +170,7 @@ class ArmObsType(TQuatDictType, JointsDictType, TRPYDictType): ... CartOrJointContType: TypeAlias = TQuatDictType | JointsDictType | TRPYDictType LimitedCartOrJointContType: TypeAlias = LimitedTQuatRelDictType | LimitedJointsRelDictType | LimitedTRPYRelDictType +SimStateSchema: TypeAlias = dict[str, list[str] | list[int]] class ArmWithGripper(TQuatDictType, GripperDictType): ... @@ -204,7 +205,7 @@ class HardwareEnv(BaseEnv): class SimEnv(BaseEnv): PLATFORM = RobotPlatform.SIMULATION STATE_KEY = "sim_state" - STATE_SPEC_KEY = "sim_state_spec" + STATE_SCHEMA_KEY = "sim_state_schema" def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.sim = sim @@ -212,10 +213,10 @@ def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.frame_rate = SimpleFrameRate(cfg.frequency, "MoJoCo Simulation Loop") self.main_greenlet: greenlet | None = None self.return_state = return_state - self._replay_state: tuple[np.ndarray, int | None] | None = None + self._replay_state: tuple[np.ndarray, SimStateSchema | None] | None = None - def set_replay_state(self, state: np.ndarray, spec: int | None = None): - self._replay_state = (state, spec) + def set_replay_state(self, state: np.ndarray, schema: SimStateSchema | None = None): + self._replay_state = (state, schema) def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]: if self.main_greenlet is not None: @@ -255,7 +256,7 @@ def reset( def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: sim_state = self.sim.get_state() info[self.STATE_KEY] = sim_state - info[self.STATE_SPEC_KEY] = self.sim.get_state_spec() + info[self.STATE_SCHEMA_KEY] = self.sim.get_state_schema() return observation, info diff --git a/python/rcs/envs/configs.py b/python/rcs/envs/configs.py index afbd9f95..f29165b8 100644 --- a/python/rcs/envs/configs.py +++ b/python/rcs/envs/configs.py @@ -1,6 +1,6 @@ import copy import time -from typing import ClassVar +from typing import ClassVar, Literal import numpy as np from rcs._core.common import FrankaHandTCPOffset, GripperType, RobotType @@ -37,7 +37,7 @@ class EmptyWorldFR3(SimEnvCreator): def config(self) -> SimEnvCreatorConfig: q_home = rcs.ROBOTS[RobotType.FR3].q_home q_home[-1] = np.pi / 4 - robot_cfg = SimRobotConfig( + robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig( robot_type=RobotType.FR3, tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType.FrankaHand], attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, @@ -183,7 +183,7 @@ class EmptyWorldFR3Duo(SimEnvCreator): gripper_mesh_quaternion_offset: ClassVar[list[float]] = [0, 0, 0.7071068, 0.7071068] def config(self) -> SimEnvCreatorConfig: - robot_cfg = SimRobotConfig( + robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig( tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType("Robotiq2F85")], robot_type=RobotType.FR3, attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, @@ -224,7 +224,7 @@ def config(self) -> SimEnvCreatorConfig: joint_limits=rcs.ROBOTS[RobotType.FR3].joint_limits, q_home=rcs.HOME_POSITIONS["FR3_DUO_LEFT"], ) - robot_cfg_right = copy.deepcopy(robot_cfg) + robot_cfg_right: SimRobotConfig[Literal[7]] = copy.deepcopy(robot_cfg) robot_cfg_right.q_home = rcs.HOME_POSITIONS["FR3_DUO_RIGHT"] robot_cfgs: dict[str, SimRobotConfig] = {"left": robot_cfg, "right": robot_cfg_right} diff --git a/python/rcs/envs/scenes.py b/python/rcs/envs/scenes.py index b661204d..05e04060 100644 --- a/python/rcs/envs/scenes.py +++ b/python/rcs/envs/scenes.py @@ -49,6 +49,7 @@ def __call__(self, **kwargs) -> gym.Env: class WrapperConfig: binary_gripper: bool = True home_on_reset: bool = True + include_depth: bool = False #### SIM SPECIFIC #### @@ -239,7 +240,13 @@ def create_model(self, cfg: SimEnvCreatorConfig) -> MjModel: if cfg.root_frame_objects is not None: for object_id, (object_xml, object2root_frame) in cfg.root_frame_objects.items(): object2world = cfg.root_frame_to_world * object2root_frame - self.add_object_mujoco(composer, object_id, object_xml, object2world) + self.add_object_mujoco( + composer, + object_id, + object_xml, + object2world, + register_root_relative_replay_free_joints=True, + ) # add external objects if cfg.world_frame_objects is not None: for object_id, (object_xml, object2world) in cfg.world_frame_objects.items(): @@ -324,6 +331,11 @@ def create_env_from_model(self, cfg: SimEnvCreatorConfig, mjmodel: MjModel) -> g prefixed_cfg = self.prefixed_cfg(cfg) simulation = Sim(mjmodel, prefixed_cfg.sim_cfg) + if isinstance(mjmodel, ModelComposer): + simulation.configure_state_encodings( + root_frame_to_world=cfg.root_frame_to_world, + root_relative_free_joints=mjmodel.root_relative_replay_free_joints, + ) envs: dict[str, gym.Env] = {} env: gym.Env @@ -353,7 +365,7 @@ def create_env_from_model(self, cfg: SimEnvCreatorConfig, mjmodel: MjModel) -> g BaseCameraSet, SimCameraSet(simulation, prefixed_cfg.camera_cfgs, physical_units=True, render_on_demand=True), ) - env = CameraSetWrapper(env, camera_set, include_depth=True) + env = CameraSetWrapper(env, camera_set, include_depth=cfg.wrapper_cfg.include_depth) env = self.add_task_env(prefixed_cfg.task_cfg, env, simulation, cfg) if not prefixed_cfg.headless: env.get_wrapper_attr("sim").open_gui() @@ -373,13 +385,20 @@ def add_task_env( return env def add_object_mujoco( - self, composer: ModelComposer, object_id: str, object_xml: str, object2world: rcs.common.Pose + self, + composer: ModelComposer, + object_id: str, + object_xml: str, + object2world: rcs.common.Pose, + *, + register_root_relative_replay_free_joints: bool = False, ): """Add an object to the Mujoco scene.""" composer.add_object_world_frame( object_xml, object_prefix=object_id + "_", pose=object2world, + register_root_relative_replay_free_joints=register_root_relative_replay_free_joints, ) def add_object_robot_frame_mujoco( diff --git a/python/rcs/envs/tasks.py b/python/rcs/envs/tasks.py index fde0a7d7..5444d461 100644 --- a/python/rcs/envs/tasks.py +++ b/python/rcs/envs/tasks.py @@ -167,6 +167,7 @@ def add_task_mujoco(cfg: PickTaskConfig, composer: ModelComposer, env_cfg: SimEn cfg.object_xml, object_prefix=cfg.prefix, pose=object2world, + register_root_relative_replay_free_joints=True, ) @staticmethod diff --git a/python/rcs/sim/composer.py b/python/rcs/sim/composer.py index 27f73d48..2a6ebe6e 100644 --- a/python/rcs/sim/composer.py +++ b/python/rcs/sim/composer.py @@ -20,6 +20,7 @@ def __init__( self.spec.compiler.autolimits = True self.add_gravcomp = add_gravcomp self._gravcomp_prefixes: set[str] = set() + self._root_relative_replay_free_joints: set[str] = set() def _resolve_asset_paths(self, spec: mujoco.MjSpec, xml_path: str): """Resolves relative paths to absolute ones.""" @@ -67,6 +68,17 @@ def _apply_pose(self, body: mujoco._specs.MjsBody, pose: Pose): body.pos = list(pose.translation()) body.quat = list(pose.rotation_q_wxyz()) + def _prefixed_free_joint_names(self, spec: mujoco.MjSpec, prefix: str) -> list[str]: + free_joint_type = int(mujoco.mjtJoint.mjJNT_FREE) + return [f"{prefix}{joint.name}" for joint in spec.joints if joint.name and int(joint.type) == free_joint_type] + + def register_root_relative_replay_free_joints(self, joint_names: list[str]): + self._root_relative_replay_free_joints.update(joint_names) + + @property + def root_relative_replay_free_joints(self) -> set[str]: + return set(self._root_relative_replay_free_joints) + def add_camera( self, resolution: tuple[int, int], @@ -230,6 +242,8 @@ def add_object_robot_frame( object_prefix: str, attachment_site_name: str, pose: Pose | None = None, + *, + register_root_relative_replay_free_joints: bool = False, ) -> mujoco._specs.MjsBody: """Attaches an object to a robot attachment site with an optional local pose offset.""" if pose is None: @@ -243,6 +257,8 @@ def add_object_robot_frame( object_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(object_spec, xml_path) + if register_root_relative_replay_free_joints: + self.register_root_relative_replay_free_joints(self._prefixed_free_joint_names(object_spec, object_prefix)) object_root = object_spec.worldbody.first_body() object_root = attachment_site.attach(object_root, object_prefix, "") @@ -250,7 +266,12 @@ def add_object_robot_frame( return object_root def add_object_world_frame( - self, xml_path: str, object_prefix: str, pose: Pose | None = None + self, + xml_path: str, + object_prefix: str, + pose: Pose | None = None, + *, + register_root_relative_replay_free_joints: bool = False, ) -> mujoco._specs.MjsBody: """ Attaches a single object MJCF at a specific pose. @@ -265,6 +286,8 @@ def add_object_world_frame( # Load the child spec child_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(child_spec, xml_path) + if register_root_relative_replay_free_joints: + self.register_root_relative_replay_free_joints(self._prefixed_free_joint_names(child_spec, object_prefix)) # Attach using a frame frame = self.spec.worldbody.add_frame() diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py index f8ec6832..05a091a1 100644 --- a/python/rcs/sim/replayer.py +++ b/python/rcs/sim/replayer.py @@ -9,9 +9,21 @@ import rcs.envs.configs as env_configs import rcs.envs.tasks as env_tasks from rcs._core.sim import SimConfig -from rcs.envs.base import RelativeTo, SimEnv +from rcs.envs.base import RelativeTo, SimEnv, SimStateSchema from rcs.envs.scenes import SimEnvCreator from rcs.envs.storage_wrapper import StorageWrapper +from rcs.sim.sim import RAW_STATE_ENCODING + + +def _normalize_sim_state_schema(value: Any) -> SimStateSchema: + joint_names = [str(item) for item in value["joint_names"]] + return { + "joint_names": joint_names, + "joint_types": [int(item) for item in value["joint_types"]], + "qpos_sizes": [int(item) for item in value["qpos_sizes"]], + "qvel_sizes": [int(item) for item in value["qvel_sizes"]], + "encodings": [str(item) for item in value.get("encodings", [RAW_STATE_ENCODING] * len(joint_names))], + } @dataclass(frozen=True) @@ -38,13 +50,13 @@ def sim_state(self) -> np.ndarray: raise KeyError(msg) @property - def sim_state_spec(self) -> int | None: - if SimEnv.STATE_SPEC_KEY in self.info: - return int(self.info[SimEnv.STATE_SPEC_KEY]) + def sim_state_schema(self) -> SimStateSchema | None: + if SimEnv.STATE_SCHEMA_KEY in self.info: + return _normalize_sim_state_schema(self.info[SimEnv.STATE_SCHEMA_KEY]) for value in self.info.values(): - if isinstance(value, dict) and SimEnv.STATE_SPEC_KEY in value: - return int(value[SimEnv.STATE_SPEC_KEY]) + if isinstance(value, dict) and SimEnv.STATE_SCHEMA_KEY in value: + return _normalize_sim_state_schema(value[SimEnv.STATE_SCHEMA_KEY]) return None @@ -94,9 +106,9 @@ def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): lead_env = None if lead_env is not None: - lead_env.set_replay_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + lead_env.set_replay_state(recorded_step.sim_state, schema=recorded_step.sim_state_schema) else: - env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, schema=recorded_step.sim_state_schema) def replay_trajectory(env: gym.Env, recorded_steps: list[RecordedSimStep], headless: bool): diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 07963e4d..43a49a20 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -1,6 +1,7 @@ import atexit import contextlib import multiprocessing as mp +import typing import uuid from logging import getLogger from multiprocessing.synchronize import Event as EventClass @@ -12,6 +13,8 @@ import mujoco as mj import mujoco.viewer import numpy as np +from rcs._core import common +from rcs._core.sim import DynamicJointSchema, DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -24,6 +27,8 @@ # Target frames per second FPS = 60 +RAW_STATE_ENCODING = "raw" +ROOT_RELATIVE_FREE_STATE_ENCODING = "root_relative_free" def gui_loop(gui_uuid: str, close_event): @@ -45,8 +50,6 @@ def gui_loop(gui_uuid: str, close_event): class Sim(_Sim): - STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION - def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None = None): if isinstance(mjmdl, ModelComposer): self.model = mjmdl.get_model() @@ -68,34 +71,154 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None self._gui_process: Optional[mp.context.SpawnProcess] = None self._stop_event: Optional[EventClass] = None self._gui_atexit_registered = False + self._root_frame_to_world = common.Pose() + self._root_relative_replay_free_joints: set[str] = set() if cfg is not None: self.set_config(cfg) - - def get_state_spec(self) -> int: - return int(self.STATE_SPEC) - - def get_state_size(self, spec: int | None = None) -> int: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - return mj.mj_stateSize(self.model, state_spec) - - def get_state(self, spec: int | None = None) -> np.ndarray: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64) - mj.mj_getState(self.model, self.data, state, state_spec) - return state - - def set_state(self, state: np.ndarray, spec: int | None = None): - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + self._state_schema = self._compute_state_schema() + + def configure_state_encodings( + self, + *, + root_frame_to_world: common.Pose, + root_relative_free_joints: typing.Iterable[str] = (), + ): + self._root_frame_to_world = common.Pose(root_frame_to_world) + self._root_relative_replay_free_joints = set(root_relative_free_joints) + self._state_schema = self._compute_state_schema() + + def get_state_schema(self) -> dict[str, list[str] | list[int]]: + return self._state_schema + + def _compute_state_schema(self) -> dict[str, list[str] | list[int]]: + schema = super().get_dynamic_joint_schema() + return { + "joint_names": list(schema.joint_names), + "joint_types": list(schema.joint_types), + "qpos_sizes": list(schema.qpos_sizes), + "qvel_sizes": list(schema.qvel_sizes), + "encodings": [ + ( + ROOT_RELATIVE_FREE_STATE_ENCODING + if joint_name in self._root_relative_replay_free_joints + else RAW_STATE_ENCODING + ) + for joint_name in schema.joint_names + ], + } + + def get_state_size(self, schema: dict[str, list[str] | list[int]] | None = None) -> int: + state_schema = self.get_state_schema() if schema is None else schema + qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) + qvel_size = sum(int(value) for value in state_schema["qvel_sizes"]) + return qpos_size + qvel_size + + def get_state(self) -> np.ndarray: + state = super().get_dynamic_joint_state() + qpos, qvel = self._transform_state( + np.array(state.qpos, copy=True), + np.array(state.qvel, copy=True), + self.get_state_schema(), + encode=True, + ) + return np.concatenate((qpos, qvel)) + + def set_state( + self, + state: np.ndarray, + schema: dict[str, list[str] | list[int]] | None = None, + ): + state_schema = self.get_state_schema() if schema is None else schema state_array = np.asarray(state, dtype=np.float64) - expected_size = self.get_state_size(int(state_spec)) + expected_size = self.get_state_size(state_schema) if state_array.shape != (expected_size,): - msg = ( - f"Expected MuJoCo state with shape ({expected_size},), " - f"got {state_array.shape} for spec {int(state_spec)}." - ) + msg = f"Expected state with shape ({expected_size},), got {state_array.shape}." raise ValueError(msg) - mj.mj_setState(self.model, self.data, state_array, state_spec) - mj.mj_forward(self.model, self.data) + + qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) + qpos, qvel = self._transform_state( + np.array(state_array[:qpos_size], copy=True), + np.array(state_array[qpos_size:], copy=True), + state_schema, + encode=False, + ) + + dynamic_joint_schema = DynamicJointSchema() + dynamic_joint_schema.joint_names = typing.cast(list[str], list(state_schema["joint_names"])) + dynamic_joint_schema.joint_types = [int(value) for value in state_schema["joint_types"]] + dynamic_joint_schema.qpos_sizes = [int(value) for value in state_schema["qpos_sizes"]] + dynamic_joint_schema.qvel_sizes = [int(value) for value in state_schema["qvel_sizes"]] + + dynamic_joint_state = DynamicJointState() # type: ignore + dynamic_joint_state.qpos = qpos + dynamic_joint_state.qvel = qvel + super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + + def _transform_state( + self, + qpos: np.ndarray, + qvel: np.ndarray, + schema: dict[str, list[str] | list[int]], + *, + encode: bool, + ) -> tuple[np.ndarray, np.ndarray]: + joint_names = typing.cast(list[str], list(schema["joint_names"])) + joint_types = [int(value) for value in schema["joint_types"]] + qpos_sizes = [int(value) for value in schema["qpos_sizes"]] + qvel_sizes = [int(value) for value in schema["qvel_sizes"]] + encodings = typing.cast(list[str], list(schema.get("encodings", [RAW_STATE_ENCODING] * len(joint_names)))) + root_transform = self._root_frame_to_world.inverse() if encode else self._root_frame_to_world + root_rotation = root_transform.rotation_m() + free_joint_type = int(mj.mjtJoint.mjJNT_FREE) + + qpos_offset = 0 + qvel_offset = 0 + for joint_name, joint_type, joint_qpos_size, joint_qvel_size, encoding in zip( + joint_names, joint_types, qpos_sizes, qvel_sizes, encodings, strict=True + ): + if encoding == RAW_STATE_ENCODING: + pass + elif encoding == ROOT_RELATIVE_FREE_STATE_ENCODING: + if joint_type != free_joint_type or joint_qpos_size != 7 or joint_qvel_size != 6: + msg = ( + f"Joint '{joint_name}' uses encoding '{ROOT_RELATIVE_FREE_STATE_ENCODING}' " + "but is not a free joint." + ) + raise ValueError(msg) + qpos[qpos_offset : qpos_offset + joint_qpos_size], qvel[qvel_offset : qvel_offset + joint_qvel_size] = ( + self._transform_root_relative_free_joint( + qpos[qpos_offset : qpos_offset + joint_qpos_size], + qvel[qvel_offset : qvel_offset + joint_qvel_size], + root_transform=root_transform, + root_rotation=root_rotation, + ) + ) + else: + msg = f"Unsupported sim state encoding '{encoding}' for joint '{joint_name}'." + raise ValueError(msg) + + qpos_offset += joint_qpos_size + qvel_offset += joint_qvel_size + + return qpos, qvel + + def _transform_root_relative_free_joint( + self, + joint_qpos: np.ndarray, + joint_qvel: np.ndarray, + *, + root_transform: common.Pose, + root_rotation: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + joint_pose = common.Pose( + translation=np.asarray(joint_qpos[:3], dtype=np.float64), + quaternion=np.asarray([joint_qpos[4], joint_qpos[5], joint_qpos[6], joint_qpos[3]], dtype=np.float64), + ) + transformed_pose = root_transform * joint_pose + return ( + np.concatenate((transformed_pose.translation(), transformed_pose.rotation_q_wxyz())), + np.concatenate((root_rotation @ joint_qvel[:3], root_rotation @ joint_qvel[3:6])), + ) def close_gui(self): if self._stop_event is not None: diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py index c819c6b2..1c299f25 100644 --- a/python/tests/test_replayer.py +++ b/python/tests/test_replayer.py @@ -1,5 +1,6 @@ +import xml.etree.ElementTree as ET from pathlib import Path -from typing import Any +from typing import Any, cast import duckdb import numpy as np @@ -8,10 +9,25 @@ from rcs.envs.configs import EmptyWorldFR3Duo from rcs.envs.storage_wrapper import StorageWrapper from rcs.envs.tasks import PickTaskConfig -from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory - - -def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -> StorageWrapper: +from rcs.sim.replayer import ( + load_distinct_uuids, + load_trajectory, + replay_trajectory, + restore_sim_step, +) +from rcs.sim.sim import ROOT_RELATIVE_FREE_STATE_ENCODING + +import rcs + + +def _build_env( + output_dir: Path, + *, + with_cameras: bool, + instruction: str = "", + scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, +) -> StorageWrapper: scene = EmptyWorldFR3Duo() cfg = scene.config() cfg.sim_cfg = SimConfig(async_control=True, realtime=False, frequency=30, max_convergence_steps=500) @@ -20,6 +36,10 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") - if cfg.root_frame_objects is None: cfg.root_frame_objects = {} cfg.task_cfg = PickTaskConfig(robot_name="right") + if scene_path is not None: + cfg.scene = str(scene_path) + if root_frame_to_world is not None: + cfg.root_frame_to_world = root_frame_to_world if not with_cameras: cfg.camera_cfgs = {} else: @@ -41,8 +61,21 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") - ) -def _record_source_dataset(dataset_dir: Path, *, limit: int, instruction: str) -> None: - env = _build_env(dataset_dir, with_cameras=False, instruction=instruction) +def _record_source_dataset( + dataset_dir: Path, + *, + limit: int, + instruction: str, + scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, +) -> None: + env = _build_env( + dataset_dir, + with_cameras=False, + instruction=instruction, + scene_path=scene_path, + root_frame_to_world=root_frame_to_world, + ) try: env.reset() action = { @@ -94,9 +127,21 @@ def _replay_rows(dataset_dir: Path): connection.close() -def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: +def _replay_prefix( + output_dir: Path, + *, + with_cameras: bool, + limit: int, + scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, +) -> None: source_dir = output_dir.parent / "source" - env = _build_env(output_dir, with_cameras=with_cameras) + env = _build_env( + output_dir, + with_cameras=with_cameras, + scene_path=scene_path, + root_frame_to_world=root_frame_to_world, + ) try: uuid = load_distinct_uuids(source_dir)[0] recorded_steps = load_trajectory(source_dir, uuid)[:limit] @@ -106,6 +151,32 @@ def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: env.close() +def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): + tree = ET.parse(src) + root = tree.getroot() + for include in root.findall("include"): + include_file = include.get("file") + if include_file is not None and not Path(include_file).is_absolute(): + include.set("file", str((src.parent / include_file).resolve())) + + worldbody = root.find("worldbody") + assert worldbody is not None + + worldbody.append( + ET.Element( + "camera", + { + "name": "replay_extra_cam", + "pos": "1.4 0.0 0.9", + "xyaxes": "0 1 0 -0.3 0 1", + }, + ) + ) + body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) + ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}) + tree.write(dst) + + def _assert_nested_close(actual: Any, expected: Any, *, atol: float = 1e-6): if isinstance(expected, dict): assert isinstance(actual, dict) @@ -151,6 +222,31 @@ def _strip_frames(obs: dict[str, Any]) -> dict[str, Any]: return {key: value for key, value in obs.items() if key != "frames"} +def _tilted_root_frame_to_world() -> rcs.common.Pose: + return rcs.common.Pose( + translation=np.array([0.35, -0.2, 0.15]), + quaternion=np.array([0.0, 0.0, 0.38268343, 0.92387953]), + ) + + +def _joint_qpos_from_state(state: np.ndarray, schema: dict[str, list[str] | list[int]], joint_name: str) -> np.ndarray: + joint_names = cast(list[str], schema["joint_names"]) + joint_index = joint_names.index(joint_name) + qpos_offset = sum(int(size) for size in schema["qpos_sizes"][:joint_index]) + qpos_size = int(schema["qpos_sizes"][joint_index]) + return np.asarray(state[qpos_offset : qpos_offset + qpos_size], dtype=np.float64) + + +def _joint_qpos_in_root_frame(env: StorageWrapper, joint_name: str, root_frame_to_world: rcs.common.Pose) -> np.ndarray: + joint_qpos_world = np.asarray(env.get_wrapper_attr("sim").data.joint(joint_name).qpos, dtype=np.float64) + joint_pose_world = rcs.common.Pose( + translation=joint_qpos_world[:3], + quaternion=np.array([joint_qpos_world[4], joint_qpos_world[5], joint_qpos_world[6], joint_qpos_world[3]]), + ) + joint_pose_root = root_frame_to_world.inverse() * joint_pose_world + return np.concatenate((joint_pose_root.translation(), joint_pose_root.rotation_q_wxyz())) + + def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: Path): source_dir = tmp_path / "source" replay_dir = tmp_path / "replayed" @@ -198,6 +294,33 @@ def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: P _assert_nested_close(replay_instruction, source_instruction) +def test_replayer_restores_sim_state_across_fixed_scene_changes(tmp_path: Path): + source_scene_path = Path(EmptyWorldFR3Duo().config().scene) + modified_scene_path = tmp_path / "modified_scene.xml" + _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) + + for record_scene_path, replay_scene_path in ( + (source_scene_path, modified_scene_path), + (modified_scene_path, source_scene_path), + ): + case_dir = tmp_path / f"{record_scene_path.stem}-to-{replay_scene_path.stem}" + source_dir = case_dir / "source" + replay_dir = case_dir / "replayed" + + _record_source_dataset(source_dir, limit=3, instruction="pick up cube", scene_path=record_scene_path) + _replay_prefix(replay_dir, with_cameras=False, limit=3, scene_path=replay_scene_path) + + source_uuid = load_distinct_uuids(source_dir)[0] + replay_uuid = load_distinct_uuids(replay_dir)[0] + source_steps = load_trajectory(source_dir, source_uuid) + replay_steps = load_trajectory(replay_dir, replay_uuid) + + assert len(source_steps) == len(replay_steps) == 3 + for replay_step, source_step in zip(replay_steps, source_steps, strict=True): + assert replay_step.sim_state_schema == source_step.sim_state_schema + assert np.allclose(replay_step.sim_state, source_step.sim_state, atol=1e-5, rtol=0) + + def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Path): source_dir = tmp_path / "source" replay_dir = tmp_path / "replayed_with_cameras" @@ -246,3 +369,56 @@ def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Pat _assert_nested_close(replay_action, source_action, atol=1e-8) _assert_nested_close(replay_env_action, source_env_action, atol=1e-8) _assert_nested_close(replay_instruction, source_instruction) + + +def test_replayer_restores_root_relative_free_joint_state_across_root_frame_changes(tmp_path: Path): + source_dir = tmp_path / "source" + replay_dir = tmp_path / "replayed" + default_root = rcs.common.Pose() + shifted_root = _tilted_root_frame_to_world() + object_joint_name = "PickTask_box_joint" + + _record_source_dataset( + source_dir, + limit=3, + instruction="pick up cube", + root_frame_to_world=default_root, + ) + _replay_prefix( + replay_dir, + with_cameras=False, + limit=3, + root_frame_to_world=shifted_root, + ) + + source_uuid = load_distinct_uuids(source_dir)[0] + replay_uuid = load_distinct_uuids(replay_dir)[0] + source_steps = load_trajectory(source_dir, source_uuid) + replay_steps = load_trajectory(replay_dir, replay_uuid) + + assert len(source_steps) == len(replay_steps) == 3 + for replay_step, source_step in zip(replay_steps, source_steps, strict=True): + assert replay_step.sim_state_schema == source_step.sim_state_schema + assert replay_step.sim_state_schema is not None + schema = replay_step.sim_state_schema + joint_names = cast(list[str], schema["joint_names"]) + encodings = cast(list[str], schema["encodings"]) + object_joint_index = joint_names.index(object_joint_name) + assert encodings[object_joint_index] == ROOT_RELATIVE_FREE_STATE_ENCODING + assert np.allclose(replay_step.sim_state, source_step.sim_state, atol=1e-5, rtol=0) + + replay_env = _build_env(replay_dir / "inspection", with_cameras=False, root_frame_to_world=shifted_root) + try: + replay_env.reset() + lead_env = replay_env.get_wrapper_attr("lead_env") + for source_step in source_steps: + restore_sim_step(replay_env, source_step) + lead_env.step_sim() + assert source_step.sim_state_schema is not None + expected_joint_qpos = _joint_qpos_from_state( + source_step.sim_state, source_step.sim_state_schema, object_joint_name + ) + actual_joint_qpos = _joint_qpos_in_root_frame(replay_env, object_joint_name, shifted_root) + assert np.allclose(actual_joint_qpos, expected_joint_qpos, atol=1e-5, rtol=0) + finally: + replay_env.close() diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp index ab7772c0..9be15edf 100644 --- a/src/pybind/rcs.cpp +++ b/src/pybind/rcs.cpp @@ -730,6 +730,18 @@ PYBIND11_MODULE(_core, m) { return rcs::sim::SimConfig(self); }); + py::class_(sim, "DynamicJointSchema") + .def(py::init<>()) + .def_readwrite("joint_names", &rcs::sim::DynamicJointSchema::joint_names) + .def_readwrite("joint_types", &rcs::sim::DynamicJointSchema::joint_types) + .def_readwrite("qpos_sizes", &rcs::sim::DynamicJointSchema::qpos_sizes) + .def_readwrite("qvel_sizes", &rcs::sim::DynamicJointSchema::qvel_sizes); + + py::class_(sim, "DynamicJointState") + .def(py::init<>()) + .def_readwrite("qpos", &rcs::sim::DynamicJointState::qpos) + .def_readwrite("qvel", &rcs::sim::DynamicJointState::qvel); + py::class_>(sim, "Sim") .def(py::init([](long m, long d) { return std::make_shared((mjModel*)m, (mjData*)d); @@ -743,6 +755,10 @@ PYBIND11_MODULE(_core, m) { .def("step", &rcs::sim::Sim::step, py::arg("k")) .def("reset", &rcs::sim::Sim::reset) .def("sync_gui", &rcs::sim::Sim::sync_gui) + .def("get_dynamic_joint_schema", &rcs::sim::Sim::get_dynamic_joint_schema) + .def("get_dynamic_joint_state", &rcs::sim::Sim::get_dynamic_joint_state) + .def("set_dynamic_joint_state", &rcs::sim::Sim::set_dynamic_joint_state, + py::arg("schema"), py::arg("state")) .def("_start_gui_server", &rcs::sim::Sim::start_gui_server, py::arg("id")) .def("_stop_gui_server", &rcs::sim::Sim::stop_gui_server); diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 8facd93d..2cbafbb3 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include namespace rcs { @@ -25,7 +28,65 @@ bool get_last_return_value(ConditionCallback cb) { return cb.last_return_value; } -Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {}; +int Sim::get_joint_qpos_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 7; + case mjJNT_BALL: + return 4; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qpos size."); + } +} + +int Sim::get_joint_qvel_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 6; + case mjJNT_BALL: + return 3; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qvel size."); + } +} + +void Sim::init_dynamic_joint_specs() { + this->dynamic_joint_specs.clear(); + this->dynamic_joint_name_to_index.clear(); + + for (int joint_id = 0; joint_id < this->m->njnt; ++joint_id) { + const char* joint_name = mj_id2name(this->m, mjOBJ_JOINT, joint_id); + if (joint_name == nullptr || joint_name[0] == '\0') { + std::ostringstream msg; + msg << "Dynamic joint state requires all joints to be named. Joint id " + << joint_id << " is unnamed."; + throw std::runtime_error(msg.str()); + } + + DynamicJointSpec spec{ + .name = joint_name, + .type = this->m->jnt_type[joint_id], + .qpos_adr = this->m->jnt_qposadr[joint_id], + .qvel_adr = this->m->jnt_dofadr[joint_id], + .qpos_size = get_joint_qpos_size(this->m->jnt_type[joint_id]), + .qvel_size = get_joint_qvel_size(this->m->jnt_type[joint_id]), + }; + + this->dynamic_joint_name_to_index[spec.name] = + this->dynamic_joint_specs.size(); + this->dynamic_joint_specs.push_back(spec); + } +} + +Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) { + this->init_dynamic_joint_specs(); +}; bool Sim::set_config(const SimConfig& cfg) { this->cfg = cfg; @@ -118,6 +179,134 @@ void Sim::reset() { this->reset_callbacks(); } +DynamicJointSchema Sim::get_dynamic_joint_schema() const { + DynamicJointSchema schema; + schema.joint_names.reserve(this->dynamic_joint_specs.size()); + schema.joint_types.reserve(this->dynamic_joint_specs.size()); + schema.qpos_sizes.reserve(this->dynamic_joint_specs.size()); + schema.qvel_sizes.reserve(this->dynamic_joint_specs.size()); + + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + schema.joint_names.push_back(spec.name); + schema.joint_types.push_back(spec.type); + schema.qpos_sizes.push_back(spec.qpos_size); + schema.qvel_sizes.push_back(spec.qvel_size); + } + return schema; +} + +DynamicJointState Sim::get_dynamic_joint_state() const { + DynamicJointState state; + int total_qpos = 0; + int total_qvel = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + total_qpos += spec.qpos_size; + total_qvel += spec.qvel_size; + } + + state.qpos = rcs::common::VectorXd(total_qpos); + state.qvel = rcs::common::VectorXd(total_qvel); + + int qpos_offset = 0; + int qvel_offset = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + for (int i = 0; i < spec.qpos_size; ++i) { + state.qpos[qpos_offset + i] = this->d->qpos[spec.qpos_adr + i]; + } + for (int i = 0; i < spec.qvel_size; ++i) { + state.qvel[qvel_offset + i] = this->d->qvel[spec.qvel_adr + i]; + } + qpos_offset += spec.qpos_size; + qvel_offset += spec.qvel_size; + } + + return state; +} + +void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state) { + size_t joint_count = schema.joint_names.size(); + if (schema.joint_types.size() != joint_count || + schema.qpos_sizes.size() != joint_count || + schema.qvel_sizes.size() != joint_count) { + throw std::invalid_argument( + "Dynamic joint schema fields must all have the same length."); + } + + int expected_qpos_size = + std::accumulate(schema.qpos_sizes.begin(), schema.qpos_sizes.end(), 0); + int expected_qvel_size = + std::accumulate(schema.qvel_sizes.begin(), schema.qvel_sizes.end(), 0); + if (state.qpos.size() != expected_qpos_size) { + std::ostringstream msg; + msg << "Dynamic joint qpos size mismatch. Expected " << expected_qpos_size + << ", got " << state.qpos.size() << "."; + throw std::invalid_argument(msg.str()); + } + if (state.qvel.size() != expected_qvel_size) { + std::ostringstream msg; + msg << "Dynamic joint qvel size mismatch. Expected " << expected_qvel_size + << ", got " << state.qvel.size() << "."; + throw std::invalid_argument(msg.str()); + } + + std::vector matched_target_joints(this->dynamic_joint_specs.size(), + false); + int qpos_offset = 0; + int qvel_offset = 0; + for (size_t i = 0; i < joint_count; ++i) { + auto spec_iter = + this->dynamic_joint_name_to_index.find(schema.joint_names[i]); + if (spec_iter == this->dynamic_joint_name_to_index.end()) { + std::cerr << "WARNING: Recorded dynamic joint '" << schema.joint_names[i] + << "' is missing in the replay model. Skipping it." + << std::endl; + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + continue; + } + + const DynamicJointSpec& target_spec = + this->dynamic_joint_specs[spec_iter->second]; + matched_target_joints[spec_iter->second] = true; + if (target_spec.type != schema.joint_types[i] || + target_spec.qpos_size != schema.qpos_sizes[i] || + target_spec.qvel_size != schema.qvel_sizes[i]) { + std::ostringstream msg; + msg << "Dynamic joint schema mismatch for joint '" + << schema.joint_names[i] << "': expected type=" << target_spec.type + << ", qpos_size=" << target_spec.qpos_size + << ", qvel_size=" << target_spec.qvel_size + << " but got type=" << schema.joint_types[i] + << ", qpos_size=" << schema.qpos_sizes[i] + << ", qvel_size=" << schema.qvel_sizes[i] << "."; + throw std::invalid_argument(msg.str()); + } + + for (int j = 0; j < target_spec.qpos_size; ++j) { + this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; + } + for (int j = 0; j < target_spec.qvel_size; ++j) { + this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; + } + + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + } + + for (size_t i = 0; i < this->dynamic_joint_specs.size(); ++i) { + if (!matched_target_joints[i]) { + std::cerr << "WARNING: Replay model dynamic joint '" + << this->dynamic_joint_specs[i].name + << "' is missing in the recorded schema. Leaving it at its " + "current value." + << std::endl; + } + } + + mj_forward(this->m, this->d); +} + void Sim::reset_callbacks() { for (size_t i = 0; i < std::size(this->callbacks); ++i) { this->callbacks[i].last_call_timestamp = 0; diff --git a/src/sim/sim.h b/src/sim/sim.h index 4ed35e60..6cae264c 100644 --- a/src/sim/sim.h +++ b/src/sim/sim.h @@ -5,10 +5,13 @@ #include #include #include +#include +#include #include "boost/interprocess/managed_shared_memory.hpp" #include "gui.h" #include "mujoco/mujoco.h" +#include "rcs/utils.h" namespace rcs { namespace sim { @@ -55,16 +58,42 @@ struct RenderingCallback { mjtNum last_call_timestamp; // in seconds }; +struct DynamicJointSchema { + std::vector joint_names; + std::vector joint_types; + std::vector qpos_sizes; + std::vector qvel_sizes; +}; + +struct DynamicJointState { + rcs::common::VectorXd qpos; + rcs::common::VectorXd qvel; +}; + class Sim { private: + struct DynamicJointSpec { + std::string name; + int type; + int qpos_adr; + int qvel_adr; + int qpos_size; + int qvel_size; + }; + SimConfig cfg; std::vector callbacks; std::vector any_callbacks; std::vector all_callbacks; std::vector rendering_callbacks; + std::vector dynamic_joint_specs; + std::unordered_map dynamic_joint_name_to_index; void invoke_callbacks(); bool invoke_condition_callbacks(); void invoke_rendering_callbacks(); + void init_dynamic_joint_specs(); + static int get_joint_qpos_size(int joint_type); + static int get_joint_qvel_size(int joint_type); size_t convergence_steps = 0; bool converged = true; std::optional gui; @@ -83,6 +112,10 @@ class Sim { void step(size_t k); void reset_callbacks(); void reset(); + DynamicJointSchema get_dynamic_joint_schema() const; + DynamicJointState get_dynamic_joint_state() const; + void set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state); /* NOTE: IMPORTANT, the callback is not necessarily called at exactly the * the requested interval. We invoke a callback if the elapsed simulation time * since the last call of the callback is greater than the requested time.