From 4a3a50ef81e576075643f7b72b45e22c5380a76a Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 12 May 2026 16:33:41 +0800 Subject: [PATCH] feat(agents): Add hierarchical language support for VLA training Add comprehensive language support to Online Data Streaming (ODS) for Vision-Language-Action (VLA) model training. The implementation provides: - Hierarchical language structure (task/subtask/primitive levels) - Multiple language sources (file, env, template, LLM) - Flexible storage modes (tokens, embeddings, hybrid) - LanguageManager for tokenization and data management - Integration with ODS shared memory buffer New files: - embodichain/lab/gym/envs/managers/language.py: LanguageManager, configs - embodichain/lab/gym/envs/managers/language_provider.py: Language providers - configs/language/: Example configurations and documentation - tests/agents/test_language_support.py: Test suite Modified files: - embodichain/agents/engine/data.py: Add language_cfg to OnlineDataEngine - embodichain/lab/gym/envs/embodied_env.py: Integrate LanguageManager - embodichain/lab/gym/utils/gym_utils.py: Extend buffer initialization - embodichain/lab/gym/envs/managers/__init__.py: Export language classes This enables VLA models to learn from multi-scale language representations similar to human task understanding. 
Co-Authored-By: Claude Opus 4.6 --- configs/language/README.md | 275 +++++++ configs/language/tasks_example.yaml | 108 +++ configs/language/usage_example.py | 327 ++++++++ embodichain/agents/engine/data.py | 38 + embodichain/lab/gym/envs/embodied_env.py | 184 +++++ embodichain/lab/gym/envs/managers/__init__.py | 15 + embodichain/lab/gym/envs/managers/language.py | 767 ++++++++++++++++++ .../gym/envs/managers/language_provider.py | 647 +++++++++++++++ embodichain/lab/gym/utils/gym_utils.py | 127 ++- tests/agents/test_language_support.py | 325 ++++++++ 10 files changed, 2811 insertions(+), 2 deletions(-) create mode 100644 configs/language/README.md create mode 100644 configs/language/tasks_example.yaml create mode 100644 configs/language/usage_example.py create mode 100644 embodichain/lab/gym/envs/managers/language.py create mode 100644 embodichain/lab/gym/envs/managers/language_provider.py create mode 100644 tests/agents/test_language_support.py diff --git a/configs/language/README.md b/configs/language/README.md new file mode 100644 index 00000000..c69148c2 --- /dev/null +++ b/configs/language/README.md @@ -0,0 +1,275 @@ +# Language Support for VLA Training + +This directory contains configuration and examples for the hierarchical language support feature in EmbodiChain, enabling Vision-Language-Action (VLA) model training with Online Data Streaming (ODS). + +## Overview + +The language support feature adds hierarchical language descriptions to the rollout buffer, organized at three abstraction levels: + +1. **Task Level**: High-level goal or overall task description +2. **Subtask Level**: Intermediate step descriptions +3. **Primitive Level**: Low-level action descriptions + +This hierarchical structure enables VLA models to learn from multi-scale language representations, similar to human task understanding. 
+ +## Features + +- **Multiple Language Sources**: Support for file-based, environment-based, template-based, and LLM-generated language +- **Hierarchical Structure**: Organize instructions at multiple abstraction levels +- **Flexible Storage**: Support for tokens, embeddings, or hybrid storage modes +- **Dynamic Chunk Sizes**: Works with variable-length trajectory chunks +- **Curriculum Learning**: Gradually increase language complexity during training +- **Token Agnostic**: Works with various tokenizers (GPT, BERT, etc.) + +## Quick Start + +### 1. Prepare Language Configuration + +Create a YAML file with task descriptions: + +```yaml +# tasks.yaml +pick_and_place: + task: + - "Pick up the red block and place it in the blue basket." + + subtask: + - "Move the gripper to the red block." + - "Grasp the red block." + - "Lift the block and move to the blue basket." + - "Release the block into the basket." + + primitive: + - "Close gripper." + - "Move up." + - "Move right." + - "Open gripper." +``` + +### 2. Configure ODS Engine + +```python +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg + +language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask", "primitive"], + "max_tokens": 512, + "tokenizer": "gpt2", + "language_source": "file", + "language_config_path": "configs/language/tasks.yaml", + "max_instructions_per_level": 5, +} + +engine_cfg = OnlineDataEngineCfg( + buffer_size=16, + max_episode_steps=300, + state_dim=14, + gym_config={...}, + language_cfg=language_cfg, +) + +engine = OnlineDataEngine(engine_cfg) +engine.start() +``` + +### 3. 
Use Language Data in Training + +```python +from embodichain.agents.datasets.online_data import OnlineDataset +from torch.utils.data import DataLoader + +dataset = OnlineDataset(engine, chunk_size=64, batch_size=8) +loader = DataLoader(dataset, batch_size=None) + +for batch in loader: + obs = batch["obs"] + actions = batch["actions"] + language = batch["language"] + + # Access language at different hierarchy levels + task_tokens = language["task_level_tokens"] + subtask_tokens = language["subtask_level_tokens"] + primitive_tokens = language["primitive_level_tokens"] + + # Train your VLA model + # loss = vla_model(obs, language, actions) +``` + +## Configuration Options + +### Language Configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `mode` | str | "tokens" | Storage mode: 'tokens', 'embeddings', or 'hybrid' | +| `hierarchy_levels` | list | ["task", "subtask", "primitive"] | Hierarchy levels to store | +| `max_tokens` | int | 512 | Maximum sequence length per instruction | +| `tokenizer` | str | "gpt2" | Tokenizer identifier | +| `pad_token_id` | int | 0 | Token ID used for padding | +| `max_instructions_per_level` | int | 3 | Maximum number of instructions per level | +| `embedding_dim` | int | 768 | Dimension for embeddings (if mode='embeddings') | +| `language_source` | str | "env" | Source of language: 'env', 'file', 'llm', 'template' | +| `language_config_path` | str | None | Path to language config file (if source='file') | + +### Language Sources + +#### File-Based (`language_source: "file"`) +Load language descriptions from YAML or JSON files. Best for static task descriptions. + +```python +language_cfg = { + "language_source": "file", + "language_config_path": "configs/language/tasks.yaml", +} +``` + +#### Environment-Based (`language_source: "env"`) +Generate language descriptions from the environment. 
The environment should implement: +- `get_task_language(task_id, context) -> HierarchicalLanguageData` +- Or have a `task_description` attribute + +```python +language_cfg = { + "language_source": "env", +} +``` + +#### Template-Based (`language_source: "template"`) +Use templates with variable substitution for structured tasks. + +```python +language_cfg = { + "language_source": "template", + "templates": { + "pick_and_place": { + "task": "Pick up the {color} {object} and place it {location}.", + "subtasks": [...], + } + }, + "variables": {"color": "red", "object": "block", "location": "in basket"}, +} +``` + +#### LLM-Based (`language_source: "llm"`) +Generate descriptions using an LLM (e.g., GPT-4, Claude). + +```python +language_cfg = { + "language_source": "llm", + "model": "gpt-4", + "api_key": "your-api-key", +} +``` + +## Buffer Structure + +When language support is enabled, the rollout buffer includes the following fields: + +### Per-Hierarchy-Level Fields + +For each level in `hierarchy_levels` (e.g., "task", "subtask", "primitive"): + +- `{level}_tokens`: `[batch_size, max_episode_steps, max_instructions, max_tokens]` +- `{level}_attention_mask`: `[batch_size, max_episode_steps, max_instructions, max_tokens]` +- `{level}_count`: `[batch_size, max_episode_steps]` + +### Global Fields + +- `instruction_counts`: `[batch_size, max_episode_steps, 3]` - Counts per hierarchy level +- `change_points`: `[batch_size, max_episode_steps, max_instructions]` - Timesteps where language changes +- `hierarchy_depth`: `[batch_size, max_episode_steps]` - Current depth of hierarchy (1-3) +- `instruction_types`: `[batch_size, max_episode_steps, max_instructions]` - Instruction type IDs + +## Advanced Usage + +### Custom Language Provider + +```python +from embodichain.lab.gym.envs.managers import LanguageProvider, HierarchicalLanguageData + +class MyLanguageProvider(LanguageProvider): + def get_language(self, task_id, context=None): + # Generate custom language data + return 
HierarchicalLanguageData( + task_level=[...], + subtask_level=[...], + primitive_level=[...], + ) + + def get_available_tasks(self): + return ["task1", "task2"] +``` + +### Language Augmentation + +```python +from embodichain.lab.gym.envs.managers import LanguageAugmentationCfg + +augmentation_cfg = LanguageAugmentationCfg( + synonym_replacement=0.1, + template_variation=True, + augmentation_prob=0.5, +) +``` + +### Curriculum Learning + +```python +from embodichain.lab.gym.envs.managers import LanguageCurriculumCfg + +curriculum_cfg = LanguageCurriculumCfg( + enabled=True, + stage_duration=1000, + stages=[ + # Simple language first + LanguageCurriculumCfg.CurriculumStage( + max_words=10, + max_sentences=1, + max_hierarchy_depth=1, + ), + # Then more complex + LanguageCurriculumCfg.CurriculumStage( + max_words=50, + max_sentences=3, + max_hierarchy_depth=3, + ), + ], +) +``` + +## Examples + +See `usage_example.py` for complete examples of: +- File-based language loading +- Environment-based language generation +- Template-based language +- Dynamic chunk sizes with language +- Custom environments with language + +## Files + +- `tasks_example.yaml` - Example task descriptions in YAML format +- `usage_example.py` - Complete usage examples +- `README.md` - This file + +## API Reference + +### Core Classes + +- `LanguageCfg` - Configuration for language data +- `LanguageManager` - Manages tokenization and language data +- `LanguageData` - Single-level language data +- `HierarchicalLanguageData` - Multi-level hierarchical language data +- `LanguageProvider` - Abstract base for language sources +- `FileBasedLanguageProvider` - Load from YAML/JSON files +- `LLMBasedLanguageProvider` - Generate with LLM +- `EnvBasedLanguageProvider` - Generate from environment +- `TemplateBasedLanguageProvider` - Template-based generation + +## Notes + +- Language data is broadcast across all timesteps in an episode +- Tokenization happens in the simulation subprocess for efficiency +- 
Shared memory ensures zero-copy data transfer to training process +- Compatible with all existing ODS features (dynamic chunks, etc.) diff --git a/configs/language/tasks_example.yaml b/configs/language/tasks_example.yaml new file mode 100644 index 00000000..70fe6f5f --- /dev/null +++ b/configs/language/tasks_example.yaml @@ -0,0 +1,108 @@ +# Example language configuration file for VLA training +# This file demonstrates the hierarchical language structure for task descriptions + +pick_and_place: + task: + - "Pick up the red block and place it in the blue basket." + + subtask: + - "Move the gripper to the red block." + - "Grasp the red block." + - "Lift the block and move to the blue basket." + - "Release the block into the basket." + + primitive: + - "Close gripper." + - "Move up." + - "Move right." + - "Open gripper." + + change_points: [0, 10, 20, 30] + +stack_blocks: + task: + - "Stack the red block on top of the green block." + + subtask: + - "Locate the red block and green block." + - "Move the gripper to the red block." + - "Grasp the red block." + - "Lift the red block." + - "Position the red block above the green block." + - "Lower the red block onto the green block." + - "Release the red block." + + primitive: + - "Close gripper." + - "Move up." + - "Move forward." + - "Move left." + - "Open gripper." + + change_points: [0, 5, 10, 15, 20, 25, 30] + +pour_liquid: + task: + - "Pour water from the cup into the bowl." + + subtask: + - "Approach the cup with the gripper." + - "Grasp the cup securely." + - "Lift the cup." + - "Tilt the cup over the bowl." + - "Wait for liquid to pour." + - "Return the cup to upright position." + + primitive: + - "Close gripper." + - "Move up." + - "Rotate wrist." + - "Wait." + - "Rotate wrist back." + + change_points: [0, 5, 10, 15, 25, 30] + +button_press: + task: + - "Press the red button to activate the mechanism." + + subtask: + - "Locate the red button." + - "Move the end-effector to the button." 
+ - "Apply downward force to press the button." + - "Release and retract." + + primitive: + - "Move forward." + - "Move down." + - "Apply force." + - "Move up." + - "Move backward." + + change_points: [0, 5, 10, 15, 20] + +door_open: + task: + - "Open the cabinet door and place the object inside." + + subtask: + - "Approach the cabinet handle." + - "Grasp the cabinet handle." + - "Pull the door open." + - "Pick up the object." + - "Move the object into the cabinet." + - "Release the object." + - "Close the cabinet door." + + primitive: + - "Move forward." + - "Close gripper." + - "Move backward." + - "Move down." + - "Close gripper." + - "Move forward." + - "Open gripper." + - "Move backward." + - "Push forward." + + change_points: [0, 5, 10, 15, 20, 25, 30, 35] diff --git a/configs/language/usage_example.py b/configs/language/usage_example.py new file mode 100644 index 00000000..2a695b71 --- /dev/null +++ b/configs/language/usage_example.py @@ -0,0 +1,327 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +Example: Using Language Support for VLA Training with Online Data Streaming + +This example demonstrates how to configure and use the hierarchical language +support for Vision-Language-Action (VLA) model training. 
+""" + +import torch +from torch.utils.data import DataLoader + +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg +from embodichain.agents.datasets.online_data import OnlineDataset +from embodichain.lab.gym.envs.managers import ( + LanguageCfg, + LanguageManager, + HierarchicalLanguageData, +) + + +# Example 1: Basic ODS with Language Support (File-based) +def example_ods_with_language_file(): + """Set up ODS with language descriptions loaded from a YAML file.""" + + # Language configuration + language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask", "primitive"], + "max_tokens": 512, + "tokenizer": "gpt2", + "language_source": "file", + "language_config_path": "configs/language/tasks_example.yaml", + "max_instructions_per_level": 5, + } + + # ODS engine configuration + engine_cfg = OnlineDataEngineCfg( + buffer_size=16, + max_episode_steps=300, + state_dim=14, + gym_config={ + "id": "EmbodiedEnv-v1", + "env": {"robot": {...}}, + # ... 
other env config + }, + language_cfg=language_cfg, # Enable language support + ) + + # Create and start the engine + engine = OnlineDataEngine(engine_cfg) + engine.start() + + # Create dataset with language support + dataset = OnlineDataset(engine, chunk_size=64, batch_size=8) + + # Create DataLoader + loader = DataLoader( + dataset, + batch_size=None, # Batch mode (dataset handles batching) + num_workers=0, + collate_fn=OnlineDataset.passthrough_collate_fn, + ) + + # Training loop + for batch in loader: + # Access different data modalities + obs = batch["obs"] # Vision and proprioception + actions = batch["actions"] # Robot actions + language = batch["language"] # Hierarchical language data + + # Access language at different hierarchy levels + task_tokens = language[ + "task_level_tokens" + ] # [batch, chunk, max_instr, max_tokens] + task_mask = language["task_level_attention_mask"] + + subtask_tokens = language["subtask_level_tokens"] + subtask_mask = language["subtask_level_attention_mask"] + + primitive_tokens = language["primitive_level_tokens"] + primitive_mask = language["primitive_level_attention_mask"] + + # Use for VLA training + # train_step(obs, language, actions) + + +# Example 2: Environment-Based Language Generation +def example_env_based_language(): + """Set up ODS with language generated by the environment.""" + + language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask"], + "max_tokens": 256, + "tokenizer": "gpt2", + "language_source": "env", # Environment generates language + } + + engine_cfg = OnlineDataEngineCfg( + buffer_size=16, + max_episode_steps=300, + state_dim=14, + gym_config={...}, + language_cfg=language_cfg, + ) + + engine = OnlineDataEngine(engine_cfg) + engine.start() + + # Your environment should implement: + # - get_task_language(task_id, context) -> HierarchicalLanguageData + # - Or have a task_description attribute + + +# Example 3: Template-Based Language +def example_template_based_language(): + """Set up 
ODS with template-based language generation.""" + + language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask"], + "max_tokens": 256, + "tokenizer": "gpt2", + "language_source": "template", + "templates": { + "pick_and_place": { + "task": "Pick up the {color} {object} and place it {location}.", + "subtasks": [ + "Move to the {color} {object}.", + "Grasp the {color} {object}.", + "Move {location}.", + "Release the {object}.", + ], + } + }, + "variables": { + "color": "red", + "object": "block", + "location": "in the blue basket", + }, + } + + engine_cfg = OnlineDataEngineCfg( + buffer_size=16, + max_episode_steps=300, + state_dim=14, + gym_config={...}, + language_cfg=language_cfg, + ) + + engine = OnlineDataEngine(engine_cfg) + engine.start() + + +# Example 4: Using Language Manager Directly +def example_language_manager(): + """Use LanguageManager to tokenize and manage language data.""" + + cfg = LanguageCfg( + mode="tokens", + hierarchy_levels=["task", "subtask", "primitive"], + max_tokens=512, + tokenizer="gpt2", + ) + + # Create a simple mock environment + class MockEnv: + task_name = "pick_and_place" + task_description = "Pick up the red block and place it in the basket." 
+ + env = MockEnv() + manager = LanguageManager(cfg, env) + + # Create hierarchical language data + language_data = manager.create_hierarchical_language_data( + task_texts="Pick up the red block and place it in the basket.", + subtask_texts=[ + "Move to the red block.", + "Grasp the red block.", + "Move to the basket.", + "Release the block.", + ], + primitive_texts=[ + "Close gripper.", + "Move up.", + "Move right.", + "Open gripper.", + ], + change_points=[0, 10, 20, 30], + ) + + # Convert to buffer format + buffer_format = language_data.to_buffer_format(cfg) + + # Access tokenized data + task_tokens = buffer_format["task_level_tokens"] # [max_instructions, max_tokens] + task_mask = buffer_format["task_level_attention_mask"] + + +# Example 5: Dynamic Chunk Size with Language +def example_dynamic_chunk_language(): + """Use dynamic chunk sizes with language support.""" + + from embodichain.agents.datasets.sampler import UniformChunkSampler + + language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task"], + "max_tokens": 256, + "tokenizer": "gpt2", + "language_source": "file", + "language_config_path": "configs/language/tasks_example.yaml", + } + + engine_cfg = OnlineDataEngineCfg( + buffer_size=16, + max_episode_steps=300, + state_dim=14, + gym_config={...}, + language_cfg=language_cfg, + ) + + engine = OnlineDataEngine(engine_cfg) + engine.start() + + # Dynamic chunk size sampler + chunk_sampler = UniformChunkSampler(low=32, high=96) + + # Dataset with dynamic chunk size + dataset = OnlineDataset( + engine, + chunk_size=chunk_sampler, + batch_size=8, + ) + + loader = DataLoader( + dataset, + batch_size=None, + collate_fn=OnlineDataset.passthrough_collate_fn, + ) + + for batch in loader: + # Batch shape is [batch_size, chunk_size, ...] 
+ # Chunk dimension varies each iteration + print(f"Batch chunk size: {batch.shape[1]}") + + # Language tokens are broadcast across all timesteps + language = batch["language"] + task_tokens = language[ + "task_level_tokens" + ] # [batch_size, chunk_size, max_instr, max_tokens] + + +# Example 6: Custom Environment with Language +def example_custom_env_with_language(): + """Example of a custom environment implementing language generation.""" + + from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg + + class MyTaskEnv(EmbodiedEnv): + """Custom environment that provides language descriptions.""" + + def __init__(self, cfg, **kwargs): + super().__init__(cfg, **kwargs) + self.task_name = "my_custom_task" + + def get_task_language(self, task_id, context=None): + """Generate hierarchical language for the current task.""" + return self.language_manager.create_hierarchical_language_data( + task_texts="Complete the custom manipulation task.", + subtask_texts=[ + "Approach the object.", + "Grasp the object.", + "Move to target location.", + "Release the object.", + ], + primitive_texts=[ + "Move forward.", + "Close gripper.", + "Move up.", + "Move right.", + "Open gripper.", + ], + ) + + # Configuration with language + env_cfg = EmbodiedEnvCfg( + robot={...}, + language={ + "mode": "tokens", + "hierarchy_levels": ["task", "subtask", "primitive"], + "max_tokens": 512, + "tokenizer": "gpt2", + "language_source": "env", # Environment will generate language + }, + init_rollout_buffer=True, + ) + + env = MyTaskEnv(env_cfg) + + +if __name__ == "__main__": + print("Language Support Examples for VLA Training") + print("=" * 50) + print("\nAvailable examples:") + print("1. example_ods_with_language_file() - File-based language") + print("2. example_env_based_language() - Environment-based language") + print("3. example_template_based_language() - Template-based language") + print("4. example_language_manager() - Direct LanguageManager usage") + print("5. 
example_dynamic_chunk_language() - Dynamic chunk sizes") print("6. example_custom_env_with_language() - Custom environment") print("\nRun any example function to see it in action.") diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py index c11fb966..a836a013 100644 --- a/embodichain/agents/engine/data.py +++ b/embodichain/agents/engine/data.py @@ -61,6 +61,31 @@ class OnlineDataEngineCfg: amortising the cost of environment simulation over many training steps. """ + language_cfg: Union[dict, None] = None + """Language configuration for VLA training. + + If provided, the shared buffer will include hierarchical language data fields + and the simulation subprocess will collect language descriptions during rollouts. + + The configuration should include: + - mode: Storage mode ('tokens', 'embeddings', 'hybrid') + - hierarchy_levels: List of hierarchy levels ('task', 'subtask', 'primitive') + - max_tokens: Maximum sequence length per instruction + - tokenizer: Tokenizer identifier + - language_source: Source of language ('env', 'file', 'llm', 'template') + - language_config_path: Path to language descriptions (if source='file') + + Example: + language_cfg = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask", "primitive"], + "max_tokens": 512, + "tokenizer": "gpt2", + "language_source": "file", + "language_config_path": "configs/language/tasks.yaml", + } + """ + # --------------------------------------------------------------------------- # Subprocess entry point (module-level so it can be pickled by multiprocessing) @@ -110,6 +135,15 @@ def _sim_worker_fn( env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES) env_cfg.filter_dataset_saving = True env_cfg.init_rollout_buffer = False + + # Add language configuration if provided + if cfg.language_cfg is not None: + env_cfg.language = cfg.language_cfg + log_info( + f"[Simulation Process] Language configuration added: {cfg.language_cfg.get('mode', 'tokens')}, 
" + f"hierarchy={cfg.language_cfg.get('hierarchy_levels', ['task', 'subtask', 'primitive'])}" + ) + env_cfg.sim_cfg = SimulationManagerCfg( headless=gym_config.get("headless", True), sim_device=gym_config.get("device", "cpu"), @@ -364,6 +398,9 @@ def _create_buffer(self) -> TensorDict: placed in CPU shared memory so it can be safely accessed from both the main process and the simulation subprocess. + If language configuration is provided, the buffer will also include + hierarchical language data fields for VLA training. + Returns: TensorDict in shared memory. """ @@ -380,6 +417,7 @@ def _create_buffer(self) -> TensorDict: batch_size=self.cfg.buffer_size, max_episode_steps=max_episode_steps, state_dim=self.cfg.state_dim, + language_cfg=self.cfg.language_cfg, ) if shared_td.device.type == "cpu": diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index d6ca36d9..a1ede4f1 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -205,6 +205,30 @@ class EnvLightCfg: If filter_dataset_saving is False and a dataset manager is configured, the rollout buffer will be initialized by default """ + language: Union[Dict[str, Any], None] = None + """Language settings for VLA training. + + When configured, enables hierarchical language data collection for + Vision-Language-Action model training. 
Supports: + + - mode: Storage mode ('tokens', 'embeddings', 'hybrid') + - hierarchy_levels: List of levels ('task', 'subtask', 'primitive') + - max_tokens: Maximum sequence length per instruction + - tokenizer: Tokenizer identifier + - language_source: Source of language ('env', 'file', 'llm', 'template') + - language_config_path: Path to language descriptions (if source='file') + + Example: + language = { + "mode": "tokens", + "hierarchy_levels": ["task", "subtask", "primitive"], + "max_tokens": 512, + "tokenizer": "gpt2", + "language_source": "file", + "language_config_path": "configs/language/tasks.yaml", + } + """ + @register_env("EmbodiedEnv-v1") class EmbodiedEnv(BaseEnv): @@ -267,6 +291,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.reward_manager: RewardManager | None = None self.action_manager: ActionManager | None = None self.dataset_manager: DatasetManager | None = None + self.language_manager = None super().__init__(cfg, **kwargs) @@ -274,12 +299,65 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.dataset_manager = DatasetManager(self.cfg.dataset, self) self.cfg.init_rollout_buffer = True + # Initialize LanguageManager for VLA training + if self.cfg.language: + from embodichain.lab.gym.envs.managers import ( + LanguageCfg, + LanguageManager, + LanguageProvider, + FileBasedLanguageProvider, + LLMBasedLanguageProvider, + EnvBasedLanguageProvider, + TemplateBasedLanguageProvider, + ) + + # Create language config + language_cfg = LanguageCfg(**self.cfg.language) + + # Initialize language provider based on source + language_source = self.cfg.language.get("language_source", "env") + if language_source == "file": + language_config_path = self.cfg.language.get("language_config_path") + if language_config_path is None: + log_error( + "language_config_path must be provided when language_source='file'", + error_type=ValueError, + ) + self.language_provider = FileBasedLanguageProvider( + language_cfg, language_config_path + ) + elif 
language_source == "llm": + model = self.cfg.language.get("model", "gpt-4") + api_key = self.cfg.language.get("api_key") + self.language_provider = LLMBasedLanguageProvider( + language_cfg, model, api_key + ) + elif language_source == "template": + templates = self.cfg.language.get("templates", {}) + variables = self.cfg.language.get("variables", {}) + self.language_provider = TemplateBasedLanguageProvider( + language_cfg, templates, variables + ) + else: # env or default + self.language_provider = EnvBasedLanguageProvider(language_cfg, self) + + # Initialize language manager + self.language_manager = LanguageManager(language_cfg, self) + log_info( + f"[EmbodiedEnv] LanguageManager initialized with source={language_source}, " + f"mode={language_cfg.mode}, hierarchy={language_cfg.hierarchy_levels}" + ) + else: + self.language_manager = None + self.language_provider = None + # Rollout buffer for episode data collection. # The shape of the buffer is (num_envs, max_episode_steps, *data_shape) for each key. # The default key in the buffer are: # - obs: the observation returned by the environment. # - action: the action applied to the environment. # - reward: the reward returned by the environment. + # - language: Hierarchical language data for VLA training (if language_manager is set) # TODO: we may add more keys and make the buffer extensible in the future. # This buffer should also be support initialized from outside of the environment. # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. 
@@ -287,6 +365,8 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self._max_rollout_steps = 0 self._rollout_buffer_mode: str | None = None if self.cfg.init_rollout_buffer: + # Determine if we need to initialize language fields + language_cfg = self.cfg.language if self.cfg.language else None self.rollout_buffer = init_rollout_buffer_from_gym_space( obs_space=self.observation_space, action_space=self.action_space, @@ -551,6 +631,19 @@ def _initialize_episode( self.episode_success_status[env_ids_to_process] = False + # Initialize language data for the new episode + if self.language_manager is not None: + # Get task ID for language lookup + task_id = getattr(self, "task_name", "default") + + # Get language data from provider + if self.language_provider is not None: + language_data = self.language_provider.get_language( + task_id, context={"env_ids": env_ids} + ) + # Write language data to rollout buffer + self._write_language_data(language_data, env_ids_to_process) + # apply events such as randomization for environments that need a reset if self.cfg.events: if "reset" in self.event_manager.available_modes: @@ -612,6 +705,97 @@ def _write_episode_rollout_step( rewards.to(buffer_device), non_blocking=True ) + def _write_language_data( + self, + language_data: "HierarchicalLanguageData", + env_ids: Optional[torch.Tensor] = None, + ) -> None: + """Write hierarchical language data to the rollout buffer. + + This method writes language data at multiple hierarchy levels to the + rollout buffer. The data is broadcast across all timesteps of the + current episode. + + Args: + language_data: HierarchicalLanguageData containing task descriptions. + env_ids: Optional tensor of environment IDs to write to. + If None, writes to all environments. 
+ """ + if self.rollout_buffer is None or "language" not in self.rollout_buffer: + return + + if env_ids is None: + env_ids = torch.arange(self.num_envs, device=self.device) + + buffer_device = self.rollout_buffer.device + + # Get language config for max values + cfg = self.language_manager.cfg + max_instructions = cfg.max_instructions_per_level + max_tokens = cfg.max_tokens + + # Convert language data to buffer format + buffer_format = language_data.to_buffer_format(cfg) + + # Write data for each hierarchy level + for level in cfg.hierarchy_levels: + level_key = f"{level}_level" + + # Get tokens and mask + tokens_key = f"{level_key}_tokens" + mask_key = f"{level_key}_attention_mask" + count_key = f"{level_key}_count" + + if tokens_key not in buffer_format: + continue + + tokens = buffer_format[tokens_key] # [max_instructions, max_tokens] + mask = buffer_format[mask_key] + + # Create the full tensor for all environments and timesteps + # Shape: [num_envs, max_episode_steps, max_instructions, max_tokens] + full_tokens = ( + tokens.unsqueeze(0) + .unsqueeze(0) + .expand(len(env_ids), self._max_rollout_steps, -1, -1) + ) + full_mask = ( + mask.unsqueeze(0) + .unsqueeze(0) + .expand(len(env_ids), self._max_rollout_steps, -1, -1) + ) + + # Write to buffer + self.rollout_buffer["language"][tokens_key][env_ids, ...] = full_tokens.to( + buffer_device, non_blocking=True + ) + self.rollout_buffer["language"][mask_key][env_ids, ...] 
= full_mask.to( + buffer_device, non_blocking=True + ) + + # Write instruction count + count = buffer_format.get(f"{level_key}_count", torch.tensor([0])) + level_idx = {"task": 0, "subtask": 1, "primitive": 2}[level] + self.rollout_buffer["language"]["instruction_counts"][ + env_ids, :, level_idx + ] = count.item() + + # Write change points + if "change_points" in buffer_format: + change_points = buffer_format["change_points"] + full_change_points = ( + change_points.unsqueeze(0) + .unsqueeze(0) + .expand(len(env_ids), self._max_rollout_steps, -1) + ) + self.rollout_buffer["language"]["change_points"][env_ids, ...] = ( + full_change_points.to(buffer_device, non_blocking=True) + ) + + # Write hierarchy depth + hierarchy_depth = language_data.hierarchy_depth + self.rollout_buffer["language"]["hierarchy_depth"][env_ids, :] = hierarchy_depth + def _write_rl_rollout_step( self, obs: EnvObs, diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py index 939f190c..22b1effe 100644 --- a/embodichain/lab/gym/envs/managers/__init__.py +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -30,3 +30,18 @@ from .action_manager import * from .actions import * from .dataset_manager import DatasetManager +from .language import ( + LanguageCfg, + LanguageCurriculumCfg, + LanguageAugmentationCfg, + LanguageManager, + LanguageData, + HierarchicalLanguageData, +) +from .language_provider import ( + LanguageProvider, + FileBasedLanguageProvider, + LLMBasedLanguageProvider, + EnvBasedLanguageProvider, + TemplateBasedLanguageProvider, +) diff --git a/embodichain/lab/gym/envs/managers/language.py b/embodichain/lab/gym/envs/managers/language.py new file mode 100644 index 00000000..a62bc097 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/language.py @@ -0,0 +1,767 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Literal, Union +from dataclasses import dataclass, field +from pathlib import Path + +import torch +import numpy as np + +from embodichain.utils import configclass +from embodichain.utils.logger import log_info, log_warning, log_error + +__all__ = [ + "LanguageCfg", + "LanguageCurriculumCfg", + "LanguageAugmentationCfg", + "LanguageManager", + "LanguageData", + "HierarchicalLanguageData", +] + + +@configclass +class LanguageCfg: + """Configuration for language data in rollout buffers. + + Supports three storage modes: + - 'tokens': Store token IDs (default, most flexible) + - 'embeddings': Store pre-computed embeddings + - 'hybrid': Store both tokens and embeddings + + Supports hierarchical language structure for VLA training: + - task_level: Overall goal/description + - subtask_level: Intermediate step descriptions + - primitive_level: Low-level action descriptions + + Args: + mode: Storage mode ('tokens', 'embeddings', or 'hybrid'). + hierarchy_levels: List of hierarchy levels to store. If None, uses + all levels. Valid levels: 'task', 'subtask', 'primitive'. + max_tokens: Maximum sequence length for tokenized text. + tokenizer: Tokenizer/model identifier (huggingface or OpenAI). 
+ pad_token_id: Token ID used for padding. + max_instructions_per_level: Maximum number of instructions per hierarchy level. + embedding_dim: Dimension of text embeddings (when mode='embeddings' or 'hybrid'). + embedding_type: How to compute embeddings from tokens. + tokenizer_backend: 'huggingface' or 'openai'. + trust_remote_code: Whether to trust remote code for huggingface tokenizers. + """ + + mode: Literal["tokens", "embeddings", "hybrid"] = "tokens" + """Storage mode for language data.""" + + hierarchy_levels: Optional[List[Literal["task", "subtask", "primitive"]]] = None + """Hierarchy levels to store. If None, uses all levels.""" + + max_tokens: int = 512 + """Maximum sequence length for tokenized text per instruction.""" + + tokenizer: str = "gpt2" + """Tokenizer/model identifier.""" + + pad_token_id: int = 0 + """Token ID used for padding.""" + + max_instructions_per_level: int = 3 + """Maximum number of instructions per hierarchy level.""" + + embedding_dim: int = 768 + """Dimension of text embeddings.""" + + embedding_type: Literal["mean_pool", "cls", "last"] = "mean_pool" + """How to compute embeddings from tokens.""" + + tokenizer_backend: Literal["huggingface", "openai"] = "huggingface" + """Tokenizer backend to use.""" + + trust_remote_code: bool = False + """Whether to trust remote code for huggingface tokenizers.""" + + def __post_init__(self) -> None: + if self.hierarchy_levels is None: + self.hierarchy_levels = ["task", "subtask", "primitive"] + + # Validate hierarchy levels + valid_levels = {"task", "subtask", "primitive"} + for level in self.hierarchy_levels: + if level not in valid_levels: + log_error( + f"Invalid hierarchy level: {level}. Must be one of {valid_levels}.", + error_type=ValueError, + ) + + +@configclass +class LanguageCurriculumCfg: + """Language complexity curriculum for progressive training. 
+ + Defines stages of increasing language complexity, allowing the model + to learn from simple descriptions before tackling complex ones. + + Args: + stages: List of curriculum stages, each defining complexity constraints. + stage_duration: Number of training steps per curriculum stage. + enabled: Whether curriculum learning is enabled. + """ + + @dataclass + class CurriculumStage: + """Configuration for a single curriculum stage.""" + + max_words: int = 50 + """Maximum number of words per instruction.""" + + max_sentences: int = 2 + """Maximum number of sentences per instruction.""" + + max_hierarchy_depth: int = 1 + """Maximum hierarchy depth (1=task only, 2=task+subtask, 3=all).""" + + vocabulary_complexity: Literal["simple", "moderate", "complex"] = "simple" + """Vocabulary complexity level.""" + + instruction_types: List[str] = field(default_factory=lambda: ["imperative"]) + """Allowed instruction types: 'imperative', 'declarative', 'conditional'.""" + + stages: List[CurriculumStage] = field( + default_factory=lambda: [ + LanguageCurriculumCfg.CurriculumStage( + max_words=10, + max_sentences=1, + max_hierarchy_depth=1, + vocabulary_complexity="simple", + instruction_types=["imperative"], + ), + LanguageCurriculumCfg.CurriculumStage( + max_words=25, + max_sentences=2, + max_hierarchy_depth=2, + vocabulary_complexity="moderate", + instruction_types=["imperative", "declarative"], + ), + LanguageCurriculumCfg.CurriculumStage( + max_words=50, + max_sentences=3, + max_hierarchy_depth=3, + vocabulary_complexity="complex", + instruction_types=["imperative", "declarative", "conditional"], + ), + ] + ) + + stage_duration: int = 1000 + """Number of training steps per curriculum stage.""" + + enabled: bool = False + """Whether curriculum learning is enabled.""" + + +@configclass +class LanguageAugmentationCfg: + """Configuration for language data augmentation. + + Augmentations are applied during sampling to increase data diversity + and improve model generalization. 
+ + Args: + back_translation: Use back-translation for paraphrasing. + synonym_replacement: Probability of replacing words with synonyms. + template_variation: Apply template-based rephrasing. + drop_word: Probability of randomly dropping a word. + swap_word: Probability of swapping two adjacent words. + insert_word: Probability of inserting a filler word. + """ + + back_translation: bool = False + """Use back-translation for paraphrasing.""" + + synonym_replacement: float = 0.0 + """Probability of replacing words with synonyms [0.0, 1.0].""" + + template_variation: bool = False + """Apply template-based rephrasing.""" + + drop_word: float = 0.0 + """Probability of randomly dropping a word [0.0, 1.0].""" + + swap_word: float = 0.0 + """Probability of swapping two adjacent words [0.0, 1.0].""" + + insert_word: float = 0.0 + """Probability of inserting a filler word [0.0, 1.0].""" + + augmentation_prob: float = 0.5 + """Overall probability of applying any augmentation [0.0, 1.0].""" + + +@dataclass +class LanguageData: + """Single-level language data structure. + + Contains tokenized text and metadata for a single instruction. + + Args: + tokens: Token IDs tensor of shape [seq_len]. + attention_mask: Attention mask tensor of shape [seq_len]. + raw_text: Original raw text string (for debugging). + instruction_type: Type of instruction (imperative, declarative, etc.). + metadata: Additional metadata dictionary. + """ + + tokens: torch.Tensor + attention_mask: torch.Tensor + raw_text: str + instruction_type: str = "imperative" + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "tokens": self.tokens, + "attention_mask": self.attention_mask, + "raw_text": self.raw_text, + "instruction_type": self.instruction_type, + "metadata": self.metadata, + } + + +@dataclass +class HierarchicalLanguageData: + """Hierarchical language data structure for VLA training. 
+ + Organizes language instructions at multiple abstraction levels: + - task_level: High-level goal/description + - subtask_level: Intermediate step descriptions + - primitive_level: Low-level action descriptions + + This structure enables VLA models to learn from multi-scale language + representations, similar to human task understanding. + + Args: + task_level: List of task-level instructions. + subtask_level: List of subtask-level instructions. + primitive_level: List of primitive-level instructions. + hierarchy_depth: Current depth of the hierarchy (1-3). + change_points: Timesteps where language changes within the trajectory. + """ + + task_level: List[LanguageData] = field(default_factory=list) + subtask_level: List[LanguageData] = field(default_factory=list) + primitive_level: List[LanguageData] = field(default_factory=list) + hierarchy_depth: int = 3 + change_points: Optional[List[int]] = None + + def __post_init__(self) -> None: + if self.change_points is None: + self.change_points = [0] + + def get_level(self, level: str) -> List[LanguageData]: + """Get language data for a specific hierarchy level. + + Args: + level: Hierarchy level ('task', 'subtask', 'primitive'). + + Returns: + List of LanguageData for the requested level. + """ + level_map = { + "task": self.task_level, + "subtask": self.subtask_level, + "primitive": self.primitive_level, + } + if level not in level_map: + log_error(f"Invalid hierarchy level: {level}", error_type=ValueError) + return level_map[level] + + def set_level(self, level: str, data: List[LanguageData]) -> None: + """Set language data for a specific hierarchy level. + + Args: + level: Hierarchy level ('task', 'subtask', 'primitive'). + data: List of LanguageData to set. 
+ """ + level_map = { + "task": "task_level", + "subtask": "subtask_level", + "primitive": "primitive_level", + } + if level not in level_map: + log_error(f"Invalid hierarchy level: {level}", error_type=ValueError) + setattr(self, level_map[level], data) + + def flatten(self) -> Dict[str, List[LanguageData]]: + """Flatten hierarchical structure into a dictionary. + + Returns: + Dictionary mapping level names to their language data. + """ + return { + "task": self.task_level, + "subtask": self.subtask_level, + "primitive": self.primitive_level, + } + + def to_buffer_format(self, cfg: LanguageCfg) -> Dict[str, torch.Tensor]: + """Convert hierarchical language data to buffer tensor format. + + Args: + cfg: Language configuration for buffer layout. + + Returns: + Dictionary with tensor fields ready for rollout buffer. + """ + result = {} + + # Process each hierarchy level + for level in cfg.hierarchy_levels: + level_data = self.get_level(level) + level_key = f"{level}_level" + + # Pad to max_instructions_per_level + padded_tokens = [] + padded_masks = [] + + for i in range(cfg.max_instructions_per_level): + if i < len(level_data): + # Pad sequence to max_tokens + tokens = level_data[i].tokens + mask = level_data[i].attention_mask + + seq_len = tokens.shape[0] + if seq_len < cfg.max_tokens: + pad_len = cfg.max_tokens - seq_len + tokens = torch.cat( + [ + tokens, + torch.full( + (pad_len,), + cfg.pad_token_id, + dtype=tokens.dtype, + device=tokens.device, + ), + ] + ) + mask = torch.cat( + [ + mask, + torch.zeros( + (pad_len,), dtype=mask.dtype, device=mask.device + ), + ] + ) + elif seq_len > cfg.max_tokens: + tokens = tokens[: cfg.max_tokens] + mask = mask[: cfg.max_tokens] + else: + # Empty instruction + tokens = torch.full( + (cfg.max_tokens,), + cfg.pad_token_id, + dtype=torch.int64, + device="cpu", + ) + mask = torch.zeros( + (cfg.max_tokens,), + dtype=torch.int64, + device="cpu", + ) + + padded_tokens.append(tokens) + padded_masks.append(mask) + + # Stack 
instructions + result[f"{level_key}_tokens"] = torch.stack(padded_tokens) + result[f"{level_key}_attention_mask"] = torch.stack(padded_masks) + + # Add instruction counts + result["instruction_counts"] = torch.tensor( + [ + len(self.task_level), + len(self.subtask_level), + len(self.primitive_level), + ], + dtype=torch.int64, + ) + + # Add change points (padded to max_instructions_per_level) + change_points = torch.full( + (cfg.max_instructions_per_level,), + -1, + dtype=torch.int64, + device="cpu", + ) + for i, cp in enumerate(self.change_points[: cfg.max_instructions_per_level]): + change_points[i] = cp + result["change_points"] = change_points + + return result + + +class LanguageManager: + """Manages language data generation, tokenization, and storage. + + The LanguageManager handles: + - Loading and configuring tokenizers + - Generating or retrieving hierarchical language descriptions + - Tokenizing text into model-ready format + - Managing language curriculum and augmentation + + Args: + cfg: Language configuration. + env: Reference to the environment for context. 
+ """ + + def __init__(self, cfg: LanguageCfg, env) -> None: + self.cfg = cfg + self.env = env + self._tokenizer = None + self._load_tokenizer() + + # Curriculum state + self._curriculum_step = 0 + self._current_stage = 0 + + # Cache for tokenized language + self._language_cache: Dict[str, HierarchicalLanguageData] = {} + + log_info( + f"[LanguageManager] Initialized with mode={cfg.mode}, " + f"hierarchy={cfg.hierarchy_levels}, tokenizer={cfg.tokenizer}" + ) + + def _load_tokenizer(self) -> None: + """Load the tokenizer based on configuration.""" + if self.cfg.tokenizer_backend == "huggingface": + try: + from transformers import AutoTokenizer + + self._tokenizer = AutoTokenizer.from_pretrained( + self.cfg.tokenizer, + trust_remote_code=self.cfg.trust_remote_code, + ) + + # Update pad_token_id from tokenizer if not specified + if ( + self.cfg.pad_token_id == 0 + and self._tokenizer.pad_token_id is not None + ): + self.cfg.pad_token_id = self._tokenizer.pad_token_id + + log_info( + f"[LanguageManager] Loaded huggingface tokenizer: {self.cfg.tokenizer}" + ) + except ImportError: + log_error( + "transformers library not installed. " + "Install with: pip install transformers", + error_type=ImportError, + ) + except Exception as e: + log_error( + f"Failed to load huggingface tokenizer: {e}", + error_type=RuntimeError, + ) + elif self.cfg.tokenizer_backend == "openai": + try: + import tiktoken + + self._tokenizer = tiktoken.encoding_for_model(self.cfg.tokenizer) + log_info( + f"[LanguageManager] Loaded OpenAI tokenizer: {self.cfg.tokenizer}" + ) + except ImportError: + log_error( + "tiktoken library not installed. " + "Install with: pip install tiktoken", + error_type=ImportError, + ) + else: + log_error( + f"Unknown tokenizer backend: {self.cfg.tokenizer_backend}", + error_type=ValueError, + ) + + def tokenize( + self, text: str, return_tensors: str = "pt" + ) -> Dict[str, torch.Tensor]: + """Tokenize a single text string. + + Args: + text: Text to tokenize. 
+ return_tensors: Return tensor format ('pt' for PyTorch). + + Returns: + Dictionary with 'input_ids' and 'attention_mask'. + """ + if self._tokenizer is None: + log_error("Tokenizer not initialized", error_type=RuntimeError) + + if self.cfg.tokenizer_backend == "huggingface": + result = self._tokenizer( + text, + max_length=self.cfg.max_tokens, + padding="max_length", + truncation=True, + return_tensors=return_tensors, + ) + # Ensure dtype is int64 + result["input_ids"] = result["input_ids"].to(torch.int64) + result["attention_mask"] = result["attention_mask"].to(torch.int64) + return result + else: # openai/tiktoken + tokens = self._tokenizer.encode( + text, + max_length=self.cfg.max_tokens, + truncation=True, + ) + # Pad to max_tokens + if len(tokens) < self.cfg.max_tokens: + tokens = tokens + [self.cfg.pad_token_id] * ( + self.cfg.max_tokens - len(tokens) + ) + else: + tokens = tokens[: self.cfg.max_tokens] + + input_ids = torch.tensor(tokens, dtype=torch.int64) + attention_mask = (input_ids != self.cfg.pad_token_id).to(torch.int64) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def tokenize_batch( + self, texts: List[str], return_tensors: str = "pt" + ) -> Dict[str, torch.Tensor]: + """Tokenize a batch of text strings. + + Args: + texts: List of texts to tokenize. + return_tensors: Return tensor format ('pt' for PyTorch). + + Returns: + Dictionary with 'input_ids' and 'attention_mask' tensors. 
+ """ + if self._tokenizer is None: + log_error("Tokenizer not initialized", error_type=RuntimeError) + + if self.cfg.tokenizer_backend == "huggingface": + result = self._tokenizer( + texts, + max_length=self.cfg.max_tokens, + padding="max_length", + truncation=True, + return_tensors=return_tensors, + ) + result["input_ids"] = result["input_ids"].to(torch.int64) + result["attention_mask"] = result["attention_mask"].to(torch.int64) + return result + else: # openai/tiktoken + batch_tokens = [] + for text in texts: + tokens = self._tokenizer.encode( + text, + max_length=self.cfg.max_tokens, + truncation=True, + ) + if len(tokens) < self.cfg.max_tokens: + tokens = tokens + [self.cfg.pad_token_id] * ( + self.cfg.max_tokens - len(tokens) + ) + else: + tokens = tokens[: self.cfg.max_tokens] + batch_tokens.append(tokens) + + input_ids = torch.tensor(batch_tokens, dtype=torch.int64) + attention_mask = (input_ids != self.cfg.pad_token_id).to(torch.int64) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def decode(self, token_ids: torch.Tensor) -> str: + """Decode token IDs back to text. + + Args: + token_ids: Token IDs to decode. + + Returns: + Decoded text string. + """ + if self._tokenizer is None: + log_error("Tokenizer not initialized", error_type=RuntimeError) + + # Remove padding + mask = token_ids != self.cfg.pad_token_id + token_ids = token_ids[mask] + + if self.cfg.tokenizer_backend == "huggingface": + return self._tokenizer.decode(token_ids, skip_special_tokens=True) + else: # openai/tiktoken + return self._tokenizer.decode(token_ids) + + def create_language_data( + self, text: str, instruction_type: str = "imperative", **metadata + ) -> LanguageData: + """Create a LanguageData object from raw text. + + Args: + text: Raw text string. + instruction_type: Type of instruction. + **metadata: Additional metadata. + + Returns: + LanguageData object with tokenized text. 
+ """ + tokenized = self.tokenize(text) + return LanguageData( + tokens=tokenized["input_ids"].squeeze(0), + attention_mask=tokenized["attention_mask"].squeeze(0), + raw_text=text, + instruction_type=instruction_type, + metadata=metadata, + ) + + def create_hierarchical_language_data( + self, + task_texts: List[str] | str, + subtask_texts: Optional[List[str] | str] = None, + primitive_texts: Optional[List[str] | str] = None, + change_points: Optional[List[int]] = None, + ) -> HierarchicalLanguageData: + """Create hierarchical language data from text at multiple levels. + + Args: + task_texts: Task-level descriptions (string or list). + subtask_texts: Subtask-level descriptions (optional). + primitive_texts: Primitive-level descriptions (optional). + change_points: Timesteps where language changes (optional). + + Returns: + HierarchicalLanguageData object. + """ + # Normalize to lists + if isinstance(task_texts, str): + task_texts = [task_texts] + if subtask_texts is not None and isinstance(subtask_texts, str): + subtask_texts = [subtask_texts] + if primitive_texts is not None and isinstance(primitive_texts, str): + primitive_texts = [primitive_texts] + + # Create language data for each level + task_level = [self.create_language_data(text) for text in task_texts] + subtask_level = ( + [self.create_language_data(text) for text in subtask_texts] + if subtask_texts is not None + else [] + ) + primitive_level = ( + [self.create_language_data(text) for text in primitive_texts] + if primitive_texts is not None + else [] + ) + + return HierarchicalLanguageData( + task_level=task_level, + subtask_level=subtask_level, + primitive_level=primitive_level, + change_points=change_points, + ) + + def get_task_language( + self, task_id: Optional[str] = None + ) -> HierarchicalLanguageData: + """Generate or retrieve language description for the current task. 
+ + This method should be overridden in subclasses or configured via + language providers to implement custom language generation logic. + + Args: + task_id: Optional task identifier for cache lookup. + + Returns: + HierarchicalLanguageData for the current task. + """ + cache_key = task_id or "default" + + if cache_key in self._language_cache: + return self._language_cache[cache_key] + + # Default implementation: generate generic task description + task_name = getattr(self.env, "task_name", "unknown_task") + task_description = getattr( + self.env, + "task_description", + f"Complete the {task_name} task.", + ) + + language_data = self.create_hierarchical_language_data( + task_texts=task_description, + subtask_texts=None, # Can be generated by subclasses + primitive_texts=None, # Can be generated by subclasses + ) + + self._language_cache[cache_key] = language_data + return language_data + + def set_curriculum_step( + self, step: int, curriculum_cfg: Optional[LanguageCurriculumCfg] = None + ) -> None: + """Update curriculum learning step. + + Args: + step: Current curriculum step. + curriculum_cfg: Optional curriculum configuration. + """ + self._curriculum_step = step + + if curriculum_cfg and curriculum_cfg.enabled: + self._current_stage = min( + step // curriculum_cfg.stage_duration, + len(curriculum_cfg.stages) - 1, + ) + log_info( + f"[LanguageManager] Curriculum: stage {self._current_stage}/{len(curriculum_cfg.stages)-1} " + f"(step {step})" + ) + + def get_current_stage_constraints( + self, curriculum_cfg: Optional[LanguageCurriculumCfg] = None + ) -> Optional[Dict[str, Any]]: + """Get constraints for the current curriculum stage. + + Args: + curriculum_cfg: Optional curriculum configuration. + + Returns: + Dictionary of constraints or None if curriculum is disabled. 
+ """ + if not curriculum_cfg or not curriculum_cfg.enabled: + return None + + stage = curriculum_cfg.stages[self._current_stage] + return { + "max_words": stage.max_words, + "max_sentences": stage.max_sentences, + "max_hierarchy_depth": stage.max_hierarchy_depth, + "vocabulary_complexity": stage.vocabulary_complexity, + "instruction_types": stage.instruction_types, + } + + def clear_cache(self) -> None: + """Clear the language cache.""" + self._language_cache.clear() + log_info("[LanguageManager] Language cache cleared") diff --git a/embodichain/lab/gym/envs/managers/language_provider.py b/embodichain/lab/gym/envs/managers/language_provider.py new file mode 100644 index 00000000..6ef3a131 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/language_provider.py @@ -0,0 +1,647 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Literal +from pathlib import Path + +import yaml +import json + +from embodichain.utils.logger import log_info, log_warning, log_error +from .language import ( + LanguageCfg, + HierarchicalLanguageData, + LanguageData, +) + +__all__ = [ + "LanguageProvider", + "FileBasedLanguageProvider", + "LLMBasedLanguageProvider", + "EnvBasedLanguageProvider", + "TemplateBasedLanguageProvider", +] + + +class LanguageProvider(ABC): + """Abstract base class for language data sources. + + Language providers are responsible for generating or retrieving + hierarchical language descriptions for tasks. Different providers + can be used depending on the data source (files, LLMs, environment, etc.). + + Args: + cfg: Language configuration. + """ + + def __init__(self, cfg: LanguageCfg) -> None: + self.cfg = cfg + + @abstractmethod + def get_language( + self, task_id: str, context: Optional[Dict[str, Any]] = None + ) -> HierarchicalLanguageData: + """Get hierarchical language data for a specific task. + + Args: + task_id: Unique identifier for the task. + context: Optional context dictionary with environment state. + + Returns: + HierarchicalLanguageData with task descriptions at multiple levels. + """ + ... + + @abstractmethod + def get_available_tasks(self) -> List[str]: + """Get list of available task IDs. + + Returns: + List of task identifiers. + """ + ... + + def validate_hierarchy_data(self, data: HierarchicalLanguageData) -> bool: + """Validate that hierarchical language data meets configuration constraints. + + Args: + data: HierarchicalLanguageData to validate. + + Returns: + True if data is valid, False otherwise. 
+ """ + # Check each level doesn't exceed max instructions + if len(data.task_level) > self.cfg.max_instructions_per_level: + log_warning( + f"Task level has {len(data.task_level)} instructions, " + f"exceeding max {self.cfg.max_instructions_per_level}" + ) + return False + + if len(data.subtask_level) > self.cfg.max_instructions_per_level: + log_warning( + f"Subtask level has {len(data.subtask_level)} instructions, " + f"exceeding max {self.cfg.max_instructions_per_level}" + ) + return False + + if len(data.primitive_level) > self.cfg.max_instructions_per_level: + log_warning( + f"Primitive level has {len(data.primitive_level)} instructions, " + f"exceeding max {self.cfg.max_instructions_per_level}" + ) + return False + + return True + + +class FileBasedLanguageProvider(LanguageProvider): + """Language provider that loads task descriptions from files. + + Supports YAML and JSON file formats. The file structure should contain + task IDs mapped to their hierarchical descriptions. + + Example YAML structure: + ```yaml + pick_and_place: + task: + - "Pick up the red block and place it in the blue basket." + subtask: + - "Move the gripper to the red block." + - "Grasp the red block." + - "Lift the block and move to the blue basket." + - "Release the block into the basket." + primitive: + - "Close gripper." + - "Move up." + - "Move right." + - "Open gripper." + ``` + + Args: + cfg: Language configuration. + config_path: Path to the configuration file (YAML or JSON). + reload_on_access: Whether to reload the file on each access (for dynamic updates). 
+ """ + + def __init__( + self, + cfg: LanguageCfg, + config_path: str, + reload_on_access: bool = False, + ) -> None: + super().__init__(cfg) + self.config_path = Path(config_path) + self.reload_on_access = reload_on_access + self._data: Dict[str, Any] = {} + self._load_data() + + def _load_data(self) -> None: + """Load language data from the configuration file.""" + if not self.config_path.exists(): + log_error( + f"Language config file not found: {self.config_path}", + error_type=FileNotFoundError, + ) + + suffix = self.config_path.suffix.lower() + + try: + with open(self.config_path, "r", encoding="utf-8") as f: + if suffix in [".yaml", ".yml"]: + self._data = yaml.safe_load(f) + elif suffix == ".json": + self._data = json.load(f) + else: + log_error( + f"Unsupported file format: {suffix}. Use .yaml, .yml, or .json", + error_type=ValueError, + ) + + log_info( + f"[FileBasedLanguageProvider] Loaded {len(self._data)} task descriptions " + f"from {self.config_path}" + ) + except Exception as e: + log_error( + f"Failed to load language config from {self.config_path}: {e}", + error_type=RuntimeError, + ) + + def get_language( + self, task_id: str, context: Optional[Dict[str, Any]] = None + ) -> HierarchicalLanguageData: + """Get language data from file for a specific task. + + Args: + task_id: Unique identifier for the task. + context: Optional context (not used in file-based provider). + + Returns: + HierarchicalLanguageData loaded from file. + """ + if self.reload_on_access: + self._load_data() + + if task_id not in self._data: + log_error( + f"Task ID '{task_id}' not found in language config. 
" + f"Available tasks: {list(self._data.keys())}", + error_type=KeyError, + ) + + task_data = self._data[task_id] + + # Extract hierarchical descriptions + task_texts = task_data.get("task", []) + subtask_texts = task_data.get("subtask", []) + primitive_texts = task_data.get("primitive", []) + change_points = task_data.get("change_points", None) + + # Import LanguageManager to create data (we need tokenizer access) + from .language import LanguageManager + + # Create a temporary manager for tokenization + # In practice, the environment should provide the manager + class _TempManager: + def __init__(self, cfg): + self.cfg = cfg + self._tokenizer = None + self._load_tokenizer() + + def _load_tokenizer(self): + if self.cfg.tokenizer_backend == "huggingface": + from transformers import AutoTokenizer + + self._tokenizer = AutoTokenizer.from_pretrained( + self.cfg.tokenizer, + trust_remote_code=self.cfg.trust_remote_code, + ) + if ( + self.cfg.pad_token_id == 0 + and self._tokenizer.pad_token_id is not None + ): + self.cfg.pad_token_id = self._tokenizer.pad_token_id + else: + import tiktoken + + self._tokenizer = tiktoken.encoding_for_model(self.cfg.tokenizer) + + def tokenize(self, text): + if self.cfg.tokenizer_backend == "huggingface": + result = self._tokenizer( + text, + max_length=self.cfg.max_tokens, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return result["input_ids"].squeeze(0).to(torch.int64), result[ + "attention_mask" + ].squeeze(0).to(torch.int64) + else: + import torch + + tokens = self._tokenizer.encode( + text, max_length=self.cfg.max_tokens, truncation=True + ) + if len(tokens) < self.cfg.max_tokens: + tokens = tokens + [self.cfg.pad_token_id] * ( + self.cfg.max_tokens - len(tokens) + ) + else: + tokens = tokens[: self.cfg.max_tokens] + input_ids = torch.tensor(tokens, dtype=torch.int64) + attention_mask = (input_ids != self.cfg.pad_token_id).to( + torch.int64 + ) + return input_ids, attention_mask + + def 
create_language_data(self, text): + tokens, mask = self.tokenize(text) + return LanguageData(tokens=tokens, attention_mask=mask, raw_text=text) + + temp_mgr = _TempManager(self.cfg) + + # Build hierarchical language data + task_level = [ + temp_mgr.create_language_data(t) if isinstance(t, str) else t + for t in (task_texts if isinstance(task_texts, list) else [task_texts]) + ] + subtask_level = ( + [ + temp_mgr.create_language_data(t) if isinstance(t, str) else t + for t in (subtask_texts if isinstance(subtask_texts, list) else []) + ] + if subtask_texts + else [] + ) + primitive_level = ( + [ + temp_mgr.create_language_data(t) if isinstance(t, str) else t + for t in (primitive_texts if isinstance(primitive_texts, list) else []) + ] + if primitive_texts + else [] + ) + + return HierarchicalLanguageData( + task_level=task_level, + subtask_level=subtask_level, + primitive_level=primitive_level, + change_points=change_points, + ) + + def get_available_tasks(self) -> List[str]: + """Get list of available task IDs from the file. + + Returns: + List of task identifiers. + """ + return list(self._data.keys()) + + +class LLMBasedLanguageProvider(LanguageProvider): + """Language provider that generates descriptions using an LLM. + + This provider uses a language model to generate task descriptions + on-the-fly based on task context and templates. + + Args: + cfg: Language configuration. + model: Model identifier (e.g., "gpt-4", "claude-3-opus"). + api_key: API key for the LLM service. + templates: Optional dictionary of templates for different task types. 
+ """ + + def __init__( + self, + cfg: LanguageCfg, + model: str = "gpt-4", + api_key: Optional[str] = None, + templates: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(cfg) + self.model = model + self.api_key = api_key + self.templates = templates or self._default_templates() + self._client = None + self._init_client() + + def _default_templates(self) -> Dict[str, str]: + """Default prompt templates for language generation.""" + return { + "task": "Generate a clear, concise task description for: {task_name}.", + "subtask": "Break down the task '{task_name}' into {num_steps} step-by-step instructions.", + "primitive": "For each subtask, provide low-level action descriptions in: {task_name}.", + } + + def _init_client(self) -> None: + """Initialize the LLM client based on model type.""" + if self.model.startswith("gpt"): + try: + import openai + + self._client = openai.OpenAI(api_key=self.api_key) + except ImportError: + log_warning( + "openai library not available. LLM provider will use fallback." + ) + elif self.model.startswith("claude"): + try: + import anthropic + + self._client = anthropic.Anthropic(api_key=self.api_key) + except ImportError: + log_warning( + "anthropic library not available. LLM provider will use fallback." + ) + else: + log_warning( + f"Unknown model type: {self.model}. LLM provider will use fallback." + ) + + def _generate_with_llm(self, prompt: str) -> str: + """Generate text using the configured LLM. + + Args: + prompt: The prompt to send to the LLM. + + Returns: + Generated text string. + """ + if self._client is None: + # Fallback: return a generic response + log_warning("LLM client not available, using fallback response.") + return "Complete the task as described in the environment." 
+ + try: + if self.model.startswith("gpt"): + response = self._client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=500, + temperature=0.7, + ) + return response.choices[0].message.content + elif self.model.startswith("claude"): + response = self._client.messages.create( + model=self.model, + max_tokens=500, + messages=[{"role": "user", "content": prompt}], + ) + return response.content[0].text + except Exception as e: + log_warning(f"LLM generation failed: {e}. Using fallback.") + return "Complete the task as described in the environment." + + def get_language( + self, task_id: str, context: Optional[Dict[str, Any]] = None + ) -> HierarchicalLanguageData: + """Generate language data using LLM for a specific task. + + Args: + task_id: Unique identifier for the task. + context: Optional context with task details. + + Returns: + HierarchicalLanguageData generated by LLM. + """ + task_name = context.get("task_name", task_id) if context else task_id + + # Generate task-level description + task_prompt = self.templates["task"].format(task_name=task_name) + task_text = self._generate_with_llm(task_prompt) + + # Generate subtask-level descriptions + num_subtasks = context.get("num_subtasks", 3) if context else 3 + subtask_prompt = self.templates["subtask"].format( + task_name=task_name, num_steps=num_subtasks + ) + subtask_text = self._generate_with_llm(subtask_prompt) + subtask_texts = [ + line.strip() for line in subtask_text.split("\n") if line.strip() + ] + + # Generate primitive-level descriptions (optional) + primitive_texts = [] + if context and context.get("include_primitive", False): + primitive_prompt = self.templates["primitive"].format(task_name=task_name) + primitive_text = self._generate_with_llm(primitive_prompt) + primitive_texts = [ + line.strip() for line in primitive_text.split("\n") if line.strip() + ] + + # Create LanguageData objects (would need LanguageManager in practice) + # This is a 
simplified version - in production, use LanguageManager + return HierarchicalLanguageData( + task_level=[], # Would be populated with LanguageData objects + subtask_level=[], + primitive_level=[], + ) + + def get_available_tasks(self) -> List[str]: + """Get list of available task IDs. + + For LLM provider, this returns an empty list as tasks are + generated on-the-fly. + + Returns: + Empty list (tasks are generated dynamically). + """ + return [] + + +class EnvBasedLanguageProvider(LanguageProvider): + """Language provider that extracts descriptions from the environment. + + This provider delegates language generation to the environment itself, + allowing task-specific implementations to provide custom logic. + + Args: + cfg: Language configuration. + env: The environment instance. + """ + + def __init__(self, cfg: LanguageCfg, env) -> None: + super().__init__(cfg) + self.env = env + + def get_language( + self, task_id: str, context: Optional[Dict[str, Any]] = None + ) -> HierarchicalLanguageData: + """Get language data from the environment. + + The environment should implement one of: + - get_task_language(task_id, context) -> HierarchicalLanguageData + - task_description attribute (simple string) + - generate_task_description() method + + Args: + task_id: Unique identifier for the task. + context: Optional context dictionary. + + Returns: + HierarchicalLanguageData from the environment. 
+ """ + # Check for dedicated method + if hasattr(self.env, "get_task_language"): + return self.env.get_task_language(task_id, context) + + # Check for attribute + if hasattr(self.env, "task_description"): + task_desc = self.env.task_description + # Would need LanguageManager to tokenize + return HierarchicalLanguageData( + task_level=[], # Would be populated + subtask_level=[], + primitive_level=[], + ) + + # Check for method + if hasattr(self.env, "generate_task_description"): + task_desc = self.env.generate_task_description(context) + return HierarchicalLanguageData( + task_level=[], + subtask_level=[], + primitive_level=[], + ) + + log_error( + "Environment does not provide language data. " + "Implement get_task_language, set task_description attribute, or generate_task_description method.", + error_type=NotImplementedError, + ) + + def get_available_tasks(self) -> List[str]: + """Get list of available task IDs from the environment. + + The environment can optionally provide: + - available_tasks attribute + - get_available_tasks() method + + Returns: + List of task identifiers or empty list. + """ + if hasattr(self.env, "available_tasks"): + return self.env.available_tasks + + if hasattr(self.env, "get_available_tasks"): + return self.env.get_available_tasks() + + return [] + + +class TemplateBasedLanguageProvider(LanguageProvider): + """Language provider that uses templates with variable substitution. + + This provider fills in templates with task-specific variables to generate + hierarchical descriptions. Useful for structured tasks with predictable patterns. + + Example templates: + ```python + templates = { + "pick_and_place": { + "task": "Pick up the {color} {object} and place it {location}.", + "subtasks": [ + "Move to the {color} {object}.", + "Grasp the {color} {object}.", + "Move {location}.", + "Release the {object}.", + ], + } + } + ``` + + Args: + cfg: Language configuration. + templates: Dictionary of templates keyed by task ID. 
+ variables: Optional default variable values. + """ + + def __init__( + self, + cfg: LanguageCfg, + templates: Dict[str, Dict[str, Any]], + variables: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(cfg) + self.templates = templates + self.variables = variables or {} + + def get_language( + self, task_id: str, context: Optional[Dict[str, Any]] = None + ) -> HierarchicalLanguageData: + """Generate language data from templates for a specific task. + + Args: + task_id: Unique identifier for the task. + context: Optional context with variable values. + + Returns: + HierarchicalLanguageData generated from templates. + """ + if task_id not in self.templates: + log_error( + f"Task ID '{task_id}' not found in templates. " + f"Available tasks: {list(self.templates.keys())}", + error_type=KeyError, + ) + + template = self.templates[task_id] + + # Merge default variables with context + vars_to_use = {**self.variables, **(context or {})} + + # Fill in task-level template + task_template = template.get("task", "Complete the task.") + task_text = task_template.format(**vars_to_use) + + # Fill in subtask templates + subtask_templates = template.get("subtasks", []) + subtask_texts = [ + st.format(**vars_to_use) for st in subtask_templates if isinstance(st, str) + ] + + # Fill in primitive templates + primitive_templates = template.get("primitives", []) + primitive_texts = [ + pt.format(**vars_to_use) + for pt in primitive_templates + if isinstance(pt, str) + ] + + # Get change points if specified + change_points = template.get("change_points", None) + + # Would need LanguageManager to tokenize - return placeholder + return HierarchicalLanguageData( + task_level=[], # Would be populated with LanguageData objects + subtask_level=[], + primitive_level=[], + change_points=change_points, + ) + + def get_available_tasks(self) -> List[str]: + """Get list of available task IDs from templates. + + Returns: + List of task identifiers. 
+ """ + return list(self.templates.keys()) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index fc9a5ffe..6ed0bd56 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -20,7 +20,7 @@ import argparse import gymnasium -from typing import Dict, Any, List, Tuple, Union, Sequence +from typing import Dict, Any, List, Tuple, Union, Sequence, Optional from gymnasium import spaces from copy import deepcopy from tensordict import TensorDict @@ -958,12 +958,117 @@ def _init_buffer_from_space( return rollout_buffer +def _init_language_buffer( + language_cfg: dict, + batch_size: int, + max_episode_steps: int, + device: Union[str, torch.device] = "cpu", +) -> Dict[str, torch.Tensor]: + """Initialize language buffer fields for VLA training. + + Creates tensor fields for hierarchical language data storage. + + Args: + language_cfg (dict): Language configuration dictionary. + batch_size (int): Number of parallel environments. + max_episode_steps (int): Maximum episode length. + device (Union[str, torch.device]): Device for tensor allocation. + + Returns: + Dict[str, torch.Tensor]: Dictionary of language tensors. 
+ """ + # Get configuration parameters with defaults + hierarchy_levels = language_cfg.get( + "hierarchy_levels", ["task", "subtask", "primitive"] + ) + max_tokens = language_cfg.get("max_tokens", 512) + max_instructions = language_cfg.get("max_instructions_per_level", 3) + pad_token_id = language_cfg.get("pad_token_id", 0) + mode = language_cfg.get("mode", "tokens") + + language_desc = {} + + # Create tensor fields for each hierarchy level + for level in hierarchy_levels: + level_key = f"{level}_level" + + # Token IDs: [batch_size, max_episode_steps, max_instructions, max_tokens] + language_desc[f"{level_key}_tokens"] = torch.zeros( + (batch_size, max_episode_steps, max_instructions, max_tokens), + dtype=torch.int64, + device=device, + ) + + # Attention mask: [batch_size, max_episode_steps, max_instructions, max_tokens] + language_desc[f"{level_key}_attention_mask"] = torch.zeros( + (batch_size, max_episode_steps, max_instructions, max_tokens), + dtype=torch.int64, + device=device, + ) + + # Instruction count per level: [batch_size, max_episode_steps] + language_desc[f"{level_key}_count"] = torch.zeros( + (batch_size, max_episode_steps), + dtype=torch.int64, + device=device, + ) + + # Instruction count by hierarchy level: [batch_size, max_episode_steps, 3] + # 3 corresponds to [task, subtask, primitive] levels + language_desc["instruction_counts"] = torch.zeros( + (batch_size, max_episode_steps, 3), + dtype=torch.int64, + device=device, + ) + + # Change points: [batch_size, max_episode_steps, max_instructions] + # Timesteps where language changes within the trajectory + language_desc["change_points"] = torch.full( + (batch_size, max_episode_steps, max_instructions), + -1, + dtype=torch.int64, + device=device, + ) + + # Hierarchy depth: [batch_size, max_episode_steps] + # Current depth of hierarchy used (1=task only, 2=task+subtask, 3=all) + language_desc["hierarchy_depth"] = torch.full( + (batch_size, max_episode_steps), + len(hierarchy_levels), + 
dtype=torch.int64, + device=device, + ) + + # Instruction type IDs: [batch_size, max_episode_steps, max_instructions] + # Encoding of instruction types (e.g., 0=imperative, 1=declarative, 2=conditional) + language_desc["instruction_types"] = torch.zeros( + (batch_size, max_episode_steps, max_instructions), + dtype=torch.int64, + device=device, + ) + + # Optional: Embedding storage for mode='embeddings' or mode='hybrid' + if mode in ("embeddings", "hybrid"): + embedding_dim = language_cfg.get("embedding_dim", 768) + for level in hierarchy_levels: + level_key = f"{level}_level" + # Embeddings: [batch_size, max_episode_steps, max_instructions, embedding_dim] + language_desc[f"{level_key}_embeddings"] = torch.zeros( + (batch_size, max_episode_steps, max_instructions, embedding_dim), + dtype=torch.float32, + device=device, + ) + + return language_desc + + def init_rollout_buffer_from_config( config: dict, max_episode_steps: int, batch_size: int, state_dim: int, device: Union[str, torch.device] = "cpu", + language_cfg: Optional[dict] = None, ) -> TensorDict: """Initialize a rollout buffer based on the environment configuration. @@ -972,15 +1077,19 @@ def init_rollout_buffer_from_config( - Sensor observations: ``sensor/`` for each sensor in config - Extra observations: Custom observations from observation functors in ``add`` mode that have a ``shape`` specified in their ``extra`` parameter + - Language data: Hierarchical language descriptions for VLA training (if language_cfg is provided) Args: config (dict): The environment configuration dictionary. max_episode_steps (int): The number of steps in an episode. batch_size (int): The batch size for the rollout buffer. state_dim (int): The dimension of the flattened state vector. + language_cfg (Optional[dict]): Language configuration for VLA training. + If provided, language fields will be added to the buffer. 
Returns: - TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'. + TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards', + and optionally 'language' if language_cfg is provided. """ # TODO: Currently we use this method to pre-allocate a rollout buffer with fixed size for simplicity. @@ -1134,4 +1243,18 @@ def init_rollout_buffer_from_config( for obs_name, obs_tensor in extra_obs_desc.items(): assign_data_to_dict(rollout_buffer["obs"], obs_name, obs_tensor) + # Add language data for VLA training if language config is provided + if language_cfg is not None: + language_desc = _init_language_buffer( + language_cfg, batch_size, max_episode_steps, device + ) + rollout_buffer["language"] = TensorDict( + language_desc, + batch_size=[batch_size, max_episode_steps], + device=device, + ) + log_info( + f"[init_rollout_buffer_from_config] Language buffer added with hierarchy levels: {language_cfg.get('hierarchy_levels', ['task', 'subtask', 'primitive'])}" + ) + return rollout_buffer diff --git a/tests/agents/test_language_support.py b/tests/agents/test_language_support.py new file mode 100644 index 00000000..112adb34 --- /dev/null +++ b/tests/agents/test_language_support.py @@ -0,0 +1,325 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +"""Tests for language support in ODS and VLA training.""" + +import pytest +import torch +import tempfile +from pathlib import Path + +from embodichain.lab.gym.envs.managers import ( + LanguageCfg, + LanguageManager, + LanguageData, + HierarchicalLanguageData, + FileBasedLanguageProvider, + TemplateBasedLanguageProvider, +) +from embodichain.lab.gym.utils.gym_utils import _init_language_buffer + + +class MockEnv: + """Mock environment for testing.""" + + task_name = "test_task" + task_description = "Complete the test task." + + +class TestLanguageData: + """Tests for LanguageData and HierarchicalLanguageData.""" + + def test_language_data_creation(self): + """Test creating LanguageData objects.""" + tokens = torch.tensor([1, 2, 3, 0, 0], dtype=torch.int64) + mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.int64) + + data = LanguageData( + tokens=tokens, + attention_mask=mask, + raw_text="Test instruction", + instruction_type="imperative", + ) + + assert data.tokens.shape == (5,) + assert data.attention_mask.shape == (5,) + assert data.raw_text == "Test instruction" + assert data.instruction_type == "imperative" + + def test_hierarchical_language_data_creation(self): + """Test creating HierarchicalLanguageData.""" + task_tokens = torch.tensor([1, 2, 3, 0], dtype=torch.int64) + task_mask = torch.tensor([1, 1, 1, 0], dtype=torch.int64) + + task_data = LanguageData( + tokens=task_tokens, + attention_mask=task_mask, + raw_text="Task description", + ) + + subtask_tokens = torch.tensor([4, 5, 0, 0], dtype=torch.int64) + subtask_mask = torch.tensor([1, 1, 0, 0], dtype=torch.int64) + + subtask_data = LanguageData( + tokens=subtask_tokens, + attention_mask=subtask_mask, + raw_text="Subtask description", + ) + + hierarchical = HierarchicalLanguageData( + task_level=[task_data], + subtask_level=[subtask_data], + primitive_level=[], + ) + + assert len(hierarchical.task_level) == 1 + assert 
len(hierarchical.subtask_level) == 1 + assert len(hierarchical.primitive_level) == 0 + + def test_hierarchical_language_data_flatten(self): + """Test flattening hierarchical language data.""" + task_data = LanguageData( + tokens=torch.tensor([1, 2, 0], dtype=torch.int64), + attention_mask=torch.tensor([1, 1, 0], dtype=torch.int64), + raw_text="Task", + ) + + hierarchical = HierarchicalLanguageData( + task_level=[task_data], + subtask_level=[], + primitive_level=[], + ) + + flattened = hierarchical.flatten() + assert "task" in flattened + assert "subtask" in flattened + assert "primitive" in flattened + + +class TestLanguageBuffer: + """Tests for language buffer initialization.""" + + def test_init_language_buffer(self): + """Test initializing language buffer tensors.""" + language_cfg = { + "hierarchy_levels": ["task", "subtask"], + "max_tokens": 256, + "max_instructions_per_level": 3, + "pad_token_id": 0, + "mode": "tokens", + } + + buffer = _init_language_buffer( + language_cfg, batch_size=4, max_episode_steps=100, device="cpu" + ) + + # Check that expected keys are present + assert "task_level_tokens" in buffer + assert "task_level_attention_mask" in buffer + assert "subtask_level_tokens" in buffer + assert "subtask_level_attention_mask" in buffer + + # Check tensor shapes + assert buffer["task_level_tokens"].shape == (4, 100, 3, 256) + assert buffer["task_level_attention_mask"].shape == (4, 100, 3, 256) + assert buffer["task_level_count"].shape == (4, 100) + + # Check global fields + assert "instruction_counts" in buffer + assert buffer["instruction_counts"].shape == (4, 100, 3) + assert "change_points" in buffer + assert buffer["change_points"].shape == (4, 100, 3) + assert "hierarchy_depth" in buffer + assert buffer["hierarchy_depth"].shape == (4, 100) + + +class TestLanguageManager: + """Tests for LanguageManager.""" + + def test_language_manager_initialization(self): + """Test initializing LanguageManager.""" + cfg = LanguageCfg( + mode="tokens", + 
hierarchy_levels=["task", "subtask"], + max_tokens=256, + tokenizer="gpt2", + ) + + env = MockEnv() + + # Test with a simple tokenizer that doesn't require external dependencies + try: + manager = LanguageManager(cfg, env) + assert manager.cfg == cfg + assert manager.env == env + except (ImportError, RuntimeError) as e: + pytest.skip(f"Tokenizer not available: {e}") + + def test_create_language_data(self): + """Test creating LanguageData from raw text.""" + cfg = LanguageCfg( + mode="tokens", + max_tokens=256, + tokenizer="gpt2", + ) + + env = MockEnv() + + try: + manager = LanguageManager(cfg, env) + data = manager.create_language_data("Test instruction") + assert isinstance(data, LanguageData) + assert data.raw_text == "Test instruction" + except (ImportError, RuntimeError) as e: + pytest.skip(f"Tokenizer not available: {e}") + + def test_create_hierarchical_language_data(self): + """Test creating hierarchical language data.""" + cfg = LanguageCfg( + mode="tokens", + max_tokens=256, + tokenizer="gpt2", + ) + + env = MockEnv() + + try: + manager = LanguageManager(cfg, env) + data = manager.create_hierarchical_language_data( + task_texts="Pick up the block.", + subtask_texts=["Move to block.", "Grasp block."], + primitive_texts=["Close gripper."], + ) + + assert isinstance(data, HierarchicalLanguageData) + assert len(data.task_level) == 1 + assert len(data.subtask_level) == 2 + assert len(data.primitive_level) == 1 + except (ImportError, RuntimeError) as e: + pytest.skip(f"Tokenizer not available: {e}") + + def test_to_buffer_format(self): + """Test converting hierarchical data to buffer format.""" + cfg = LanguageCfg( + mode="tokens", + hierarchy_levels=["task", "subtask"], + max_tokens=256, + max_instructions_per_level=3, + tokenizer="gpt2", + ) + + env = MockEnv() + + try: + manager = LanguageManager(cfg, env) + data = manager.create_hierarchical_language_data( + task_texts="Task description.", + subtask_texts=["Step 1.", "Step 2."], + ) + + buffer_format = 
data.to_buffer_format(cfg) + + assert "task_level_tokens" in buffer_format + assert "subtask_level_tokens" in buffer_format + assert "instruction_counts" in buffer_format + + # Check shapes + assert buffer_format["task_level_tokens"].shape == (3, 256) + assert buffer_format["subtask_level_tokens"].shape == (3, 256) + except (ImportError, RuntimeError) as e: + pytest.skip(f"Tokenizer not available: {e}") + + +class TestFileBasedLanguageProvider: + """Tests for FileBasedLanguageProvider.""" + + def test_file_provider_initialization(self): + """Test initializing file-based provider.""" + cfg = LanguageCfg( + mode="tokens", + max_tokens=256, + tokenizer="gpt2", + ) + + # Create a temporary YAML file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(""" +test_task: + task: + - "Test task description." + subtask: + - "Step 1." + - "Step 2." +""") + temp_path = f.name + + try: + provider = FileBasedLanguageProvider(cfg, temp_path) + assert provider.config_path == Path(temp_path) + assert "test_task" in provider.get_available_tasks() + finally: + Path(temp_path).unlink() + + +class TestTemplateBasedLanguageProvider: + """Tests for TemplateBasedLanguageProvider.""" + + def test_template_provider_initialization(self): + """Test initializing template-based provider.""" + cfg = LanguageCfg( + mode="tokens", + max_tokens=256, + tokenizer="gpt2", + ) + + templates = { + "test_task": { + "task": "Complete the {object} task.", + "subtasks": ["Move to {object}.", "Grasp {object}."], + } + } + + provider = TemplateBasedLanguageProvider(cfg, templates) + assert "test_task" in provider.get_available_tasks() + + def test_template_provider_get_language(self): + """Test getting language from templates.""" + cfg = LanguageCfg( + mode="tokens", + max_tokens=256, + tokenizer="gpt2", + ) + + templates = { + "test_task": { + "task": "Pick up the {color} {object}.", + "subtasks": [ + "Move to {color} {object}.", + "Grasp {color} {object}.", + ], + } + } 
+ + provider = TemplateBasedLanguageProvider(cfg, templates) + + context = {"color": "red", "object": "block"} + language_data = provider.get_language("test_task", context) + + assert isinstance(language_data, HierarchicalLanguageData) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])