diff --git a/python/rcs/__main__.py b/python/rcs/__main__.py
index 344060d4..95d30a6d 100644
--- a/python/rcs/__main__.py
+++ b/python/rcs/__main__.py
@@ -4,9 +4,11 @@ import typer
 
 from rcs.envs.storage_wrapper import StorageWrapper
 from rcs.lerobot_joint_converter import (
+    DEFAULT_BINARIZE_GRIPPER,
     DEFAULT_CAMERAS,
     DEFAULT_DATASET_PATHS,
     DEFAULT_FPS,
+    DEFAULT_GRIPPER_BINARIZE_THRESHOLD,
     DEFAULT_GRIPPER_TYPE,
     DEFAULT_HF_DATA_DIR,
     DEFAULT_IMAGE_BATCH_SIZE,
@@ -144,6 +146,15 @@ def lerobot_convert(
     per_robot_arm_dim: Annotated[
         int, typer.Option(help="Per-robot arm joint/action dimension without gripper. Example: --per-robot-arm-dim 7")
     ] = DEFAULT_PER_ROBOT_ARM_DIM,
+    binarize_gripper: Annotated[
+        bool, typer.Option(help="Binarize gripper values before export. Example: --binarize-gripper")
+    ] = DEFAULT_BINARIZE_GRIPPER,
+    gripper_binarize_threshold: Annotated[
+        float,
+        typer.Option(
+            help="Threshold used when binarizing gripper values; values above this become 1.0. Example: --gripper-binarize-threshold 0.2"
+        ),
+    ] = DEFAULT_GRIPPER_BINARIZE_THRESHOLD,
     success: Annotated[bool, typer.Option(help="Only include successful episodes. Example: --success")] = True,
     n: Annotated[int, typer.Option(help="Maximum number of episodes to convert. -1 means all. Example: --n 50")] = -1,
     video_encoding: Annotated[bool, typer.Option(help="Should the image data be video encoded")] = False,
@@ -164,6 +175,8 @@ def lerobot_convert(
         cameras=cameras,
         image_batch_size=image_batch_size,
         per_robot_arm_dim=per_robot_arm_dim,
+        binarize_gripper=binarize_gripper,
+        gripper_binarize_threshold=gripper_binarize_threshold,
         success=success,
         n=n,
         video_encoding=video_encoding,
diff --git a/python/rcs/lerobot_joint_converter.py b/python/rcs/lerobot_joint_converter.py
index 60b800c3..613ef5c9 100644
--- a/python/rcs/lerobot_joint_converter.py
+++ b/python/rcs/lerobot_joint_converter.py
@@ -27,6 +27,8 @@
 DEFAULT_ROBOT_KEYS = ["left", "right"]
 DEFAULT_JOINTS = False
 DEFAULT_GRIPPER_TYPE = "Robotiq2F85"
+DEFAULT_BINARIZE_GRIPPER = False
+DEFAULT_GRIPPER_BINARIZE_THRESHOLD = 0.2
 
 
 @dataclass(frozen=True)
@@ -98,6 +100,8 @@ def __init__(
         cameras: list[CamConversionConfig] | None = None,
         image_batch_size: int = DEFAULT_IMAGE_BATCH_SIZE,
         per_robot_arm_dim: int = DEFAULT_PER_ROBOT_ARM_DIM,
+        binarize_gripper: bool = DEFAULT_BINARIZE_GRIPPER,
+        gripper_binarize_threshold: float = DEFAULT_GRIPPER_BINARIZE_THRESHOLD,
         video_encoding: bool = False,
         video_backend: str | None = None,
     ):
@@ -115,6 +119,8 @@
         self.per_robot_arm_dim = per_robot_arm_dim
         self.per_robot_state_dim = self.per_robot_arm_dim + 1
         self.state_dim = len(self.robot_keys) * self.per_robot_state_dim
+        self.binarize_gripper = binarize_gripper
+        self.gripper_binarize_threshold = gripper_binarize_threshold
         self.source_sql = self._build_source_sql(self.dataset_paths)
         self.video_encoding = video_encoding
 
@@ -137,6 +143,11 @@ def __init__(
             video_backend=video_backend,
         )
 
+    def _maybe_binarize_gripper(self, gripper: np.ndarray) -> np.ndarray:
+        if not self.binarize_gripper:
+            return gripper.astype(np.float32)
+        return (gripper > self.gripper_binarize_threshold).astype(np.float32)
+
     def _build_features(self) -> dict[str, dict[str, Any]]:
         state_names = []
         for robot_key in self.robot_keys:
@@ -268,6 +279,7 @@ def _build_observation_state(self, row: pd.Series) -> np.ndarray:
                     f"joints={joints_vec.shape}, gripper={gripper_vec.shape}"
                 )
                 raise ValueError(msg)
+            gripper_vec = self._maybe_binarize_gripper(gripper_vec)
             vectors.append(np.concatenate([joints_vec, gripper_vec]).astype(np.float32))
 
         return np.concatenate(vectors).astype(np.float32)
@@ -300,6 +312,7 @@ def _convert_action_to_joint_space(self, row: pd.Series) -> np.ndarray | None:
                     f"absolute_action={absolute_action_vec.shape}, action_gripper={action_gripper_vec.shape}"
                 )
                 raise ValueError(msg)
+            action_gripper_vec = self._maybe_binarize_gripper(action_gripper_vec)
 
         if self.joints:
             arm_action_vec = absolute_action_vec.astype(np.float32)
@@ -415,6 +428,8 @@ def run_conversion(
     cameras: list[CamConversionConfig] | None = None,
     image_batch_size: int = DEFAULT_IMAGE_BATCH_SIZE,
     per_robot_arm_dim: int = DEFAULT_PER_ROBOT_ARM_DIM,
+    binarize_gripper: bool = DEFAULT_BINARIZE_GRIPPER,
+    gripper_binarize_threshold: float = DEFAULT_GRIPPER_BINARIZE_THRESHOLD,
     success: bool = True,
     n: int = -1,
     video_encoding: bool = False,
@@ -434,6 +449,8 @@
         cameras=cameras,
         image_batch_size=image_batch_size,
         per_robot_arm_dim=per_robot_arm_dim,
+        binarize_gripper=binarize_gripper,
+        gripper_binarize_threshold=gripper_binarize_threshold,
         video_encoding=video_encoding,
         video_backend=video_backend,
     )
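
Note on the new behavior: a minimal standalone sketch of the binarization semantics introduced in _maybe_binarize_gripper above, not the converter itself. It assumes only NumPy; the function name here is illustrative, and the 0.2 default mirrors DEFAULT_GRIPPER_BINARIZE_THRESHOLD.

    import numpy as np

    def binarize(gripper: np.ndarray, threshold: float = 0.2) -> np.ndarray:
        # Strictly-greater comparison, as in _maybe_binarize_gripper:
        # values above the threshold become 1.0, everything else 0.0.
        return (gripper > threshold).astype(np.float32)

    print(binarize(np.array([0.0, 0.19, 0.2, 0.21, 1.0], dtype=np.float32)))
    # [0. 0. 0. 1. 1.]

Two details worth noting from the diff: the comparison is strict, so a value exactly at the threshold maps to 0.0, and when --binarize-gripper is left unset the converter still casts gripper values to float32 but leaves them continuous.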