diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 5425b7b575eb..c9ccced66540 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -15,6 +15,7 @@ import copy import json import math +from collections.abc import Iterable from dataclasses import dataclass from typing import Any, Callable, Literal @@ -704,6 +705,9 @@ def _remove_action_video_padding_from_latent( def prepare_latents( self, image: torch.Tensor | None = None, + video: list[Image.Image] | torch.Tensor | np.ndarray | None = None, + condition_frame_indexes_vision: Iterable[int] = (0, 1), + condition_video_keep: Literal["first", "last"] = "first", num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -737,6 +741,8 @@ def prepare_latents( action_mode = action.mode if action is not None else None is_image = num_frames == 1 has_image_condition = (image is not None and not is_image) or action_mode is not None + # Video-to-video conditioning: a top-level `video` without an action run. + has_video_condition = video is not None and action is None # video_processor.preprocess handles PIL/np/tensor → [1, 3, H, W] in [-1, 1], resized to (height, width). conditioning_frame_2d: torch.Tensor | None = None @@ -745,6 +751,19 @@ def prepare_latents( device=device, dtype=dtype ) + conditioning_frames_3d: torch.Tensor | None = None + condition_indexes_vision: tuple[int, ...] = tuple(condition_frame_indexes_vision) + if has_video_condition: + conditioning_frames_3d = self.video_processor.preprocess_video(video, height=height, width=width).to( + device=device, dtype=dtype + ) + temporal_compression = int(self.vae.config.scale_factor_temporal) + max_cond_frames = max(condition_indexes_vision) * temporal_compression + 1 + if condition_video_keep == "first": + conditioning_frames_3d = conditioning_frames_3d[:, :, :max_cond_frames] + else: + conditioning_frames_3d = conditioning_frames_3d[:, :, -max_cond_frames:] + action_domain_id: torch.Tensor | None = None action_condition_mask: torch.Tensor | None = None raw_action_dim_resolved: int | None = ( @@ -789,7 +808,17 @@ def prepare_latents( ) else: vision_tensor = torch.zeros(1, 3, num_frames, height, width, dtype=dtype, device=device) - if conditioning_frame_2d is not None: + if conditioning_frames_3d is not None: + # Video-to-video: place the leading conditioning frames at the start, repeat-pad the tail with the + # last conditioning frame, then mark the conditioned latent indexes clean (encoded as a whole below). + t_fill = min(conditioning_frames_3d.shape[2], num_frames) + vision_tensor[:, :, :t_fill] = conditioning_frames_3d[:, :, :t_fill] + if t_fill < num_frames: + vision_tensor[:, :, t_fill:] = vision_tensor[:, :, t_fill - 1 : t_fill].expand( + -1, -1, num_frames - t_fill, -1, -1 + ) + vision_condition_frames = list(condition_indexes_vision) + elif conditioning_frame_2d is not None: # Single conditioning frame at t=0, repeat-pad the rest with that same frame. vision_tensor[:, :, 0] = conditioning_frame_2d if num_frames > 1: @@ -928,6 +957,8 @@ def check_inputs( enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], action: "CosmosActionCondition | None" = None, + video: list[Image.Image] | torch.Tensor | np.ndarray | None = None, + condition_frame_indexes_vision: Iterable[int] = (0, 1), ) -> None: if not isinstance(prompt, (str, list)) or ( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) @@ -958,6 +989,8 @@ def check_inputs( raise ValueError( "Pass action conditioning via `action.image` / `action.video`, not the top-level `image` argument." ) + if video is not None: + raise ValueError("Pass action conditioning via `action.video`, not the top-level `video` argument.") if not getattr(self.transformer.config, "action_gen", False): raise ValueError("`action` requires a transformer trained with action_gen=True.") if action.mode == "forward_dynamics" and action.raw_actions is not None: @@ -976,6 +1009,27 @@ def check_inputs( sf = int(self.vae.config.scale_factor_spatial) if height % sf != 0 or width % sf != 0: raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + if image is not None and video is not None: + raise ValueError("Pass either `image` (image-to-video) or `video` (video-to-video), not both.") + if video is not None: + if num_frames == 1: + raise ValueError("`video` conditioning requires `num_frames` > 1.") + if isinstance(condition_frame_indexes_vision, (str, bytes)) or not all( + isinstance(index, int) and index >= 0 for index in condition_frame_indexes_vision + ): + raise ValueError( + f"`condition_frame_indexes_vision` must be a list of non-negative ints, e.g. [0, 1]; got " + f"{condition_frame_indexes_vision!r}." + ) + indexes = tuple(condition_frame_indexes_vision) + if not indexes: + raise ValueError("`condition_frame_indexes_vision` must contain at least one index.") + latent_t = (num_frames - 1) // int(self.vae.config.scale_factor_temporal) + 1 + if max(indexes) >= latent_t: + raise ValueError( + f"`condition_frame_indexes_vision` {indexes} contains an index outside the latent timeline " + f"(latent_frames={latent_t} for num_frames={num_frames})." + ) @staticmethod def _build_action_json_prompt( @@ -1198,6 +1252,9 @@ def __call__( prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, + video: list[Image.Image] | torch.Tensor | np.ndarray | None = None, + condition_frame_indexes_vision: Iterable[int] = (0, 1), + condition_video_keep: Literal["first", "last"] = "first", num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -1223,9 +1280,13 @@ def __call__( enable_safety_check: bool = True, ) -> Cosmos3OmniPipelineOutput: r""" - Run the Cosmos 3 omni pipeline end-to-end: encode the (optional) conditioning image, denoise vision and + Run the Cosmos 3 omni pipeline end-to-end: encode the (optional) conditioning image/video, denoise vision and (optional) sound latents jointly, and decode them back into a video and audio waveform. + The generation mode is selected from the inputs: text-to-image when `num_frames == 1`, image-to-video when + `image` is supplied, video-to-video (generation) when `video` is supplied (without `action`), action-conditioned generation + when `action` is supplied, and text-to-video otherwise. + Args: prompt (`str` or `List[str]`): The prompt to guide generation. Lists are collapsed to the first entry — the pipeline runs one sample @@ -1235,6 +1296,20 @@ def __call__( image (`torch.Tensor` or `PIL.Image.Image`, *optional*): Optional conditioning frame for image-to-video. The pipeline anchors frame 0 to this image and denoises the remaining frames. Ignored when `num_frames == 1`. Not used for action runs (pass `action` instead). + Mutually exclusive with `video`. + video (`List[PIL.Image.Image]`, `torch.Tensor`, or `np.ndarray`, *optional*): + Optional conditioning clip for video-to-video. The leading frames are kept clean at the latent indexes + given by `condition_frame_indexes_vision` and the remaining frames are denoised. Each frame is + preprocessed (resized to `height`/`width`) like the `image` input. The canonical input is a list of PIL + frames, e.g. from `diffusers.utils.load_video`. Mutually exclusive with `image`; not used for action + runs (pass `action.video` instead). + condition_frame_indexes_vision (`List[int]`, *optional*): + Latent frame indexes to keep clean when `video` conditioning is supplied, e.g. `[0, 1]` (the default), + i.e. the first two latent frames (a 5 pixel-frame clip under 4x temporal compression). Only consulted + for video-to-video. + condition_video_keep (`str`, *optional*, defaults to `"first"`): + Which end of a longer source `video` to take the conditioning frames from: `"first"` or `"last"`. Only + consulted for video-to-video. num_frames (`int`, *optional*, defaults to `None`): Number of frames to generate. Use `1` for text-to-image. Defaults to `189` (≈ 7.9 s at 24 FPS) for non-action modes when omitted (`None`). Must be `None` for action runs, where frame count is derived @@ -1327,6 +1402,8 @@ def __call__( enable_sound, callback_on_step_end_tensor_inputs, action, + video=video, + condition_frame_indexes_vision=condition_frame_indexes_vision, ) # `action_mode` is the only action field consumed directly in __call__ (prompt template + output slicing); @@ -1405,6 +1482,9 @@ def __call__( action_condition_frame_indexes, ) = self.prepare_latents( image=image, + video=video, + condition_frame_indexes_vision=condition_frame_indexes_vision, + condition_video_keep=condition_video_keep, num_frames=num_frames, height=height, width=width,