Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions python/rcs/_core/sim.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import rcs._core.common

__all__: list[str] = [
"CameraType",
"DynamicJointSchema",
"DynamicJointState",
"FrameSet",
"GuiClient",
"Sim",
Expand All @@ -32,6 +34,7 @@ __all__: list[str] = [
"tracking",
]
M = typing.TypeVar("M", bound=int)
N = typing.TypeVar("N", bound=int)

class CameraType:
"""
Expand Down Expand Up @@ -68,6 +71,18 @@ class CameraType:
@property
def value(self) -> int: ...

class DynamicJointSchema:
joint_names: list[str]
joint_types: list[int]
qpos_sizes: list[int]
qvel_sizes: list[int]
def __init__(self) -> None: ...

class DynamicJointState:
qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]
qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]
def __init__(self) -> None: ...

class FrameSet:
def __init__(
self,
Expand All @@ -93,20 +108,28 @@ class Sim:
def _start_gui_server(self, id: str) -> None: ...
def _stop_gui_server(self) -> None: ...
def get_config(self) -> SimConfig: ...
def get_dynamic_joint_schema(self) -> DynamicJointSchema: ...
def get_dynamic_joint_state(self) -> DynamicJointState: ...
def is_converged(self) -> bool: ...
def reset(self) -> None: ...
def set_config(self, cfg: SimConfig) -> bool: ...
def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ...
def step(self, k: int) -> None: ...
def step_until_convergence(self) -> None: ...
def sync_gui(self) -> None: ...

class SimCameraConfig(rcs._core.common.BaseCameraConfig):
type: CameraType
def __copy__(self) -> SimCameraConfig: ...
def __deepcopy__(self, arg0: dict) -> SimCameraConfig: ...
def __init__(
self, identifier: str, frame_rate: int, resolution_width: int, resolution_height: int, type: CameraType = ...
) -> None: ...

class SimCameraSet:
def __init__(self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True) -> None: ...
def __init__(
self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True, max_buffer_frames: int = 100
) -> None: ...
def buffer_size(self) -> int: ...
def clear_buffer(self) -> None: ...
def get_latest_frameset(self) -> FrameSet | None: ...
Expand All @@ -116,7 +139,7 @@ class SimCameraSet:

class SimConfig:
async_control: bool
frequency: int
frequency: float
max_convergence_steps: int
realtime: bool
def __copy__(self) -> SimConfig: ...
Expand All @@ -125,7 +148,7 @@ class SimConfig:
self,
async_control: bool = False,
realtime: bool = False,
frequency: float = 30,
frequency: float = 30.0,
max_convergence_steps: int = 500,
) -> None: ...

Expand All @@ -142,6 +165,7 @@ class SimGripperConfig(rcs._core.common.GripperConfig):
collision_geoms_fingers: list[str]
epsilon_inner: float
epsilon_outer: float
gripper_type: rcs._core.common.GripperType
ignored_collision_geoms: list[str]
joints: list[str]
max_actuator_width: float
Expand All @@ -165,8 +189,9 @@ class SimGripperConfig(rcs._core.common.GripperConfig):
actuator: str = "actuator8",
max_actuator_width: float = 255.0,
min_actuator_width: float = 0.0,
gripper_type: rcs._core.common.GripperType = ...,
) -> None: ...
def add_postfix(self, id: str) -> None: ...
def add_prefix(self, id: str) -> None: ...

class SimGripperState(rcs._core.common.GripperState):
def __init__(self) -> None: ...
Expand All @@ -193,9 +218,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
actuators: list[str]
arm_collision_geoms: list[str]
base: str
dof: int
joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]
joint_rotational_tolerance: float
joints: list[str]
mjcf_scene_path: str
seconds_between_callbacks: float
trajectory_trace: bool
def __copy__(self) -> SimRobotConfig: ...
Expand All @@ -208,7 +234,6 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
kinematic_model_path: str = "assets/scenes/fr3_empty_world/robot.xml",
joint_rotational_tolerance: float = 0.0008726646259971648,
seconds_between_callbacks: float = 0.1,
mjcf_scene_path: str = "assets/scenes/fr3_empty_world/scene.xml",
trajectory_trace: bool = False,
arm_collision_geoms: list[str] = [
"fr3_link0_collision",
Expand All @@ -229,6 +254,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
"fr3_joint6",
"fr3_joint7",
],
q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None = None,
actuators: list[str] = [
"fr3_joint1",
"fr3_joint2",
Expand All @@ -239,8 +265,10 @@ class SimRobotConfig(rcs._core.common.RobotConfig):
"fr3_joint7",
],
base: str = "base",
dof: int = 7,
joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ...,
) -> None: ...
def add_postfix(self, id: str) -> None: ...
def add_prefix(self, id: str) -> None: ...

class SimRobotState(rcs._core.common.RobotState):
def __init__(self) -> None: ...
Expand Down Expand Up @@ -323,7 +351,7 @@ class SimTilburgHandConfig(rcs._core.common.HandConfig):
max_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ...,
min_joint_position: numpy.ndarray[tuple[typing.Literal[16]], numpy.dtype[numpy.float64]] = ...,
) -> None: ...
def add_postfix(self, id: str) -> None: ...
def add_prefix(self, id: str) -> None: ...

class SimTilburgHandState(rcs._core.common.HandState):
def __init__(self) -> None: ...
Expand Down
13 changes: 12 additions & 1 deletion python/rcs/envs/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,26 @@ def __init__(self, env):
super().__init__(env)
assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation."
self.sim = cast(sim.Sim, self.get_wrapper_attr("sim"))
self._state_spec = self.sim.get_state_spec()
self._include_state_spec_in_next_step = True

def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
observation = dict(observation)
sim_state = self.sim.get_state()
observation[self.STATE_KEY] = sim_state
observation[self.STATE_SPEC_KEY] = self.sim.get_state_spec()
observation[self.STATE_SIZE_KEY] = sim_state.shape[0]
if self._include_state_spec_in_next_step:
observation[self.STATE_SPEC_KEY] = self._state_spec
self._include_state_spec_in_next_step = False
return observation, info

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
obs, info = super().reset(seed=seed, options=options)
self._include_state_spec_in_next_step = True
return obs, info


class GripperWrapperSim(ActObsInfoWrapper):
def __init__(self, env):
Expand Down
81 changes: 60 additions & 21 deletions python/rcs/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import mujoco as mj
import mujoco.viewer
import numpy as np
from rcs._core.sim import DynamicJointSchema as _DynamicJointSchema
from rcs._core.sim import DynamicJointState as _DynamicJointState
from rcs._core.sim import GuiClient as _GuiClient
from rcs._core.sim import Sim as _Sim
from rcs.sim import SimConfig, egl_bootstrap
Expand Down Expand Up @@ -43,8 +45,6 @@ def gui_loop(gui_uuid: str, close_event):


class Sim(_Sim):
STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION

def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None):
mjmdl = Path(mjmdl)
if mjmdl.suffix == ".xml":
Expand All @@ -64,31 +64,70 @@ def __init__(self, mjmdl: str | PathLike, cfg: SimConfig | None = None):
if cfg is not None:
self.set_config(cfg)

def get_state_spec(self) -> int:
return int(self.STATE_SPEC)
def get_state_spec(self) -> dict[str, list[str] | list[int]]:
return self.get_dynamic_joint_schema()

def get_state_size(self, spec: int | None = None) -> int:
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
return mj.mj_stateSize(self.model, state_spec)
def get_state_size(self, spec: dict[str, list[str] | list[int]] | None = None) -> int:
state_spec = self.get_state_spec() if spec is None else spec
qpos_size = sum(int(value) for value in state_spec["qpos_sizes"])
qvel_size = sum(int(value) for value in state_spec["qvel_sizes"])
return qpos_size + qvel_size

def get_state(self, spec: int | None = None) -> np.ndarray:
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64)
mj.mj_getState(self.model, self.data, state, state_spec)
return state
def get_state(self, spec: dict[str, list[str] | list[int]] | None = None) -> np.ndarray:
del spec
dynamic_joint_state = self.get_dynamic_joint_state()
return np.concatenate((dynamic_joint_state["qpos"], dynamic_joint_state["qvel"]))

def set_state(self, state: np.ndarray, spec: int | None = None):
state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec)
def set_state(
self,
state: np.ndarray,
spec: dict[str, list[str] | list[int]] | None = None,
):
state_spec = self.get_state_spec() if spec is None else spec
state_array = np.asarray(state, dtype=np.float64)
expected_size = self.get_state_size(int(state_spec))
expected_size = self.get_state_size(state_spec)
if state_array.shape != (expected_size,):
msg = (
f"Expected MuJoCo state with shape ({expected_size},), "
f"got {state_array.shape} for spec {int(state_spec)}."
)
msg = f"Expected state with shape ({expected_size},), got {state_array.shape}."
raise ValueError(msg)
mj.mj_setState(self.model, self.data, state_array, state_spec)
mj.mj_forward(self.model, self.data)

qpos_size = sum(int(value) for value in state_spec["qpos_sizes"])
dynamic_joint_state = {
"qpos": state_array[:qpos_size],
"qvel": state_array[qpos_size:],
}
self.set_dynamic_joint_state(state_spec, dynamic_joint_state)

def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]:
schema = super().get_dynamic_joint_schema()
return {
"joint_names": list(schema.joint_names),
"joint_types": list(schema.joint_types),
"qpos_sizes": list(schema.qpos_sizes),
"qvel_sizes": list(schema.qvel_sizes),
}

def get_dynamic_joint_state(self) -> dict[str, np.ndarray]:
state = super().get_dynamic_joint_state()
return {
"qpos": np.asarray(state.qpos, dtype=np.float64),
"qvel": np.asarray(state.qvel, dtype=np.float64),
}

def set_dynamic_joint_state(
self,
schema: dict[str, list[str] | list[int]],
state: dict[str, np.ndarray],
):
dynamic_joint_schema = _DynamicJointSchema()
dynamic_joint_schema.joint_names = list(schema["joint_names"])
dynamic_joint_schema.joint_types = [int(value) for value in schema["joint_types"]]
dynamic_joint_schema.qpos_sizes = [int(value) for value in schema["qpos_sizes"]]
dynamic_joint_schema.qvel_sizes = [int(value) for value in schema["qvel_sizes"]]

dynamic_joint_state = _DynamicJointState()
dynamic_joint_state.qpos = np.asarray(state["qpos"], dtype=np.float64)
dynamic_joint_state.qvel = np.asarray(state["qvel"], dtype=np.float64)
super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state)

def close_gui(self):
if self._stop_event is not None:
Expand Down
24 changes: 19 additions & 5 deletions python/rcs/sim_state_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ def sim_state(self) -> np.ndarray:
return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64)

@property
def sim_state_spec(self) -> int:
return int(self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY, 0))
def sim_state_spec(self) -> dict[str, Any] | None:
schema = self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY)
return dict(schema) if schema is not None else None


class DuckDBUnavailableError(RuntimeError):
Expand Down Expand Up @@ -149,9 +150,18 @@ def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, pre
raise ValueError(msg)


def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep):
def restore_sim_step(
env: gym.Env,
recorded_step: RecordedSimStep,
sim_state_spec: dict[str, Any] | None = None,
):
resolved_spec = sim_state_spec or recorded_step.sim_state_spec
if resolved_spec is None:
msg = "Recorded sim state is missing its schema."
raise ValueError(msg)

sim = env.get_wrapper_attr("sim")
sim.set_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec)
sim.set_state(recorded_step.sim_state, spec=resolved_spec)


def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]:
Expand Down Expand Up @@ -190,9 +200,13 @@ def replay_trajectory(
msg = "No recorded sim states found in the requested trajectory."
raise ValueError(msg)

sim_state_spec = next(
(recorded_step.sim_state_spec for recorded_step in recorded_steps if recorded_step.sim_state_spec), None
)

env.reset()
for recorded_step in recorded_steps:
restore_sim_step(env, recorded_step)
restore_sim_step(env, recorded_step, sim_state_spec=sim_state_spec)
if output_dir is not None:
save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env))
if sleep_s > 0:
Expand Down
Loading
Loading