From f1335d2e401bd05a907d157d379fac9985e14017 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 8 Jun 2026 10:32:43 +0200 Subject: [PATCH 1/4] fix(cli): let gpu-device request GPUs Signed-off-by: Evan Lezar --- crates/openshell-cli/src/main.rs | 31 ++++++++++++- crates/openshell-cli/src/run.rs | 4 +- .../sandbox_create_lifecycle_integration.rs | 45 +++++++++++++++++++ 3 files changed, 77 insertions(+), 3 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index d4a19c4bf..c45627518 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1215,8 +1215,9 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. - /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + /// Specifying --gpu-device also requests GPU resources. + /// When omitted with --gpu, the driver uses its default GPU selection. + #[arg(long)] gpu_device: Option, /// CPU limit for the sandbox (for example: 500m, 1, 2.5). @@ -4371,6 +4372,32 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_device_parses_without_gpu_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu-device", + "nvidia.com/gpu=0", + ]) + .expect("sandbox create --gpu-device should parse without --gpu"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + gpu, gpu_device, .. + }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_device.as_deref(), Some("nvidia.com/gpu=0")); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index b29e2f633..e4ef4df04 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1799,7 +1799,9 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); + let requested_gpu = gpu + || gpu_device.is_some_and(|device_id| !device_id.is_empty()) + || image.as_deref().is_some_and(image_requests_gpu); let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 96821655d..7f83ad5e1 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -886,6 +886,51 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { assert!(!resources.fields.contains_key("requests")); } +#[tokio::test] +async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-device"), + None, + "openshell", + &[], + true, + false, + Some("nvidia.com/gpu=0"), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let spec = requests[0] + .spec + .as_ref() + .expect("sandbox spec should be sent"); + + assert!(spec.gpu); + assert_eq!(spec.gpu_device, "nvidia.com/gpu=0"); +} + #[tokio::test] async fn sandbox_create_sends_driver_config_json() { let server = run_server().await; From 055ddf13d0985cc1e989e3672b030d2ef28925f6 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 8 Jun 2026 11:18:41 +0200 Subject: [PATCH 2/4] feat(cli): add gpu count requests Signed-off-by: Evan Lezar --- architecture/compute-runtimes.md | 4 +- crates/openshell-cli/src/main.rs | 54 ++++++++++++- crates/openshell-cli/src/run.rs | 3 + .../sandbox_create_lifecycle_integration.rs | 59 ++++++++++++++ crates/openshell-driver-docker/README.md | 2 +- crates/openshell-driver-docker/src/lib.rs | 14 +++- crates/openshell-driver-docker/src/tests.rs | 19 +++++ crates/openshell-driver-kubernetes/README.md | 10 +-- .../openshell-driver-kubernetes/src/driver.rs | 76 ++++++++++++++++--- crates/openshell-driver-podman/README.md | 2 +- crates/openshell-driver-podman/src/driver.rs | 27 ++++++- crates/openshell-driver-vm/src/driver.rs | 51 +++++++++++++ crates/openshell-server/src/compute/mod.rs | 16 ++++ .../openshell-server/src/grpc/validation.rs | 70 +++++++++++++++++ docs/reference/sandbox-compute-drivers.mdx | 3 +- docs/sandboxes/manage-sandboxes.mdx | 19 +++++ proto/compute_driver.proto | 2 + proto/openshell.proto | 2 + 18 files changed, 407 insertions(+), 26 deletions(-) diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index b70a2fccc..efeae6d98 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -82,7 +82,9 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. +runtime still owns GPU device injection. GPU requests can include a driver-native +device identifier or a requested count; the gateway validates the request shape +and each runtime enforces the GPU allocation modes it supports. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index c45627518..8065cdfe5 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1217,9 +1217,13 @@ enum SandboxCommands { /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. /// Specifying --gpu-device also requests GPU resources. /// When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long)] + #[arg(long, conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] cpu: Option, @@ -2548,6 +2552,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, cpu, memory, driver_config_json, @@ -2629,6 +2634,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, cpu.as_deref(), memory.as_deref(), driver_config_json.as_deref(), @@ -4398,6 +4404,52 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "nvidia.com/gpu=0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index e4ef4df04..746714621 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1745,6 +1745,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, cpu: Option<&str>, memory: Option<&str>, driver_config_json: Option<&str>, @@ -1801,6 +1802,7 @@ pub async fn sandbox_create( }; let requested_gpu = gpu || gpu_device.is_some_and(|device_id| !device_id.is_empty()) + || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; @@ -1838,6 +1840,7 @@ pub async fn sandbox_create( spec: Some(SandboxSpec { gpu: requested_gpu, gpu_device: gpu_device.unwrap_or_default().to_string(), + gpu_count, policy, providers: configured_providers, template, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 7f83ad5e1..d1781baa8 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -788,6 +788,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -827,6 +828,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { true, false, None, + None, Some("500m"), Some("2Gi"), None, @@ -908,6 +910,7 @@ async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() { None, None, None, + None, &[], None, None, @@ -931,6 +934,53 @@ async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() { assert_eq!(spec.gpu_device, "nvidia.com/gpu=0"); } +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request_without_gpu_flag() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-count"), + None, + "openshell", + &[], + true, + false, + None, + Some(2), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let spec = requests[0] + .spec + .as_ref() + .expect("sandbox spec should be sent"); + + assert!(spec.gpu); + assert_eq!(spec.gpu_count, Some(2)); + assert!(spec.gpu_device.is_empty()); +} + #[tokio::test] async fn sandbox_create_sends_driver_config_json() { let server = run_server().await; @@ -951,6 +1001,7 @@ async fn sandbox_create_sends_driver_config_json() { None, None, None, + None, Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#), None, &[], @@ -1027,6 +1078,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -1085,6 +1137,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1139,6 +1192,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1185,6 +1239,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1227,6 +1282,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1273,6 +1329,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1319,6 +1376,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1365,6 +1423,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index ea57f44e4..486546f7c 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-based GPU requests are rejected. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 864d91f22..7dce66c01 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -375,7 +375,7 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + Self::validate_gpu_request(spec.gpu, spec.gpu_count, config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -409,7 +409,17 @@ impl DockerComputeDriver { )) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { + fn validate_gpu_request( + gpu: bool, + gpu_count: Option, + supports_gpu: bool, + ) -> Result<(), Status> { + if gpu_count.is_some() { + return Err(Status::invalid_argument( + "docker GPU count requests are not supported; use --gpu or --gpu-device", + )); + } + if gpu && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 4a902a48b..9343cdfb8 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -43,6 +43,7 @@ fn test_sandbox() -> DriverSandbox { }), gpu: false, gpu_device: String::new(), + gpu_count: None, sandbox_token: String::new(), }), status: None, @@ -612,6 +613,24 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count_request() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + spec.gpu_count = Some(2); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("GPU count requests are not supported") + ); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 0bdcf3748..2b34b48a3 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -62,9 +62,9 @@ the supervisor's network namespace mount setup on AppArmor-enabled nodes. ## GPU Support When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +for `nvidia.com/gpu` and requests the configured GPU count in the workload spec. +When no count is set, the driver requests one GPU resource. The sandbox image +must provide the user-space libraries needed by the agent workload. ## Driver Config POC @@ -97,5 +97,5 @@ POC parser renders the keys listed above and ignores unknown fields. `pod.runtime_class_name` maps to PodSpec `runtimeClassName` and overrides the driver's configured `default_runtime_class_name`; the typed public `SandboxTemplate.runtime_class_name` still takes precedence when set. Use the -public `gpu` flag for the default GPU request and `driver_config` only for -additional driver-owned resource details. +public `gpu` flag for the default GPU request, `gpu_count` for counted GPU +requests, and `driver_config` only for additional driver-owned resource details. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 449cee58d..475bae5a7 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -1165,7 +1165,14 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s_with_gpu_count( + template, + spec.gpu, + spec.gpu_count, + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1195,9 +1202,10 @@ fn sandbox_to_k8s_spec( let pod_env = spec_pod_env(spec); root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s( + sandbox_template_to_k8s_with_gpu_count( &SandboxTemplate::default(), spec.is_some_and(|s| s.gpu), + spec.and_then(|s| s.gpu_count), &pod_env, inject_workspace, params, @@ -1210,12 +1218,31 @@ fn sandbox_to_k8s_spec( ) } +#[cfg(test)] fn sandbox_template_to_k8s( template: &SandboxTemplate, gpu: bool, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, +) -> serde_json::Value { + sandbox_template_to_k8s_with_gpu_count( + template, + gpu, + None, + spec_environment, + inject_workspace, + params, + ) +} + +fn sandbox_template_to_k8s_with_gpu_count( + template: &SandboxTemplate, + gpu: bool, + gpu_count: Option, + spec_environment: &std::collections::HashMap, + inject_workspace: bool, + params: &SandboxPodParams<'_>, ) -> serde_json::Value { let driver_config = kubernetes_driver_config(template); @@ -1379,7 +1406,7 @@ fn sandbox_template_to_k8s( serde_json::Value::Array(volume_mounts), ); - if let Some(resources) = container_resources(template, gpu) { + if let Some(resources) = container_resources(template, gpu, gpu_count) { container.insert("resources".to_string(), resources); } apply_agent_driver_resources(&mut container, &driver_config.containers.agent.resources); @@ -1548,7 +1575,11 @@ fn app_armor_profile_to_k8s(profile: &AppArmorProfile) -> serde_json::Value { value } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: bool, + gpu_count: Option, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1582,7 +1613,11 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option Result<(), ComputeDriverError> { let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) + let gpu_count = sandbox.spec.as_ref().and_then(|s| s.gpu_count); + Self::validate_gpu_request(gpu_requested, gpu_count) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { + fn validate_gpu_request( + gpu_requested: bool, + gpu_count: Option, + ) -> Result<(), ComputeDriverError> { + if gpu_count.is_some() { + return Err(ComputeDriverError::Precondition( + "podman GPU count requests are not supported; use --gpu or --gpu-device" + .to_string(), + )); + } + if gpu_requested && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), @@ -651,6 +662,18 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_gpu_count() { + let err = PodmanComputeDriver::validate_gpu_request(true, Some(2)) + .expect_err("gpu count should be rejected"); + + assert!(matches!(err, ComputeDriverError::Precondition(_))); + assert!( + err.to_string() + .contains("GPU count requests are not supported") + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 445905a1e..bddf27bc8 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -2577,6 +2577,26 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; + if spec.gpu_count == Some(0) { + return Err(Status::invalid_argument("gpu_count must be greater than 0")); + } + + if spec.gpu_count.is_some() && !spec.gpu { + return Err(Status::invalid_argument("gpu_count requires gpu=true")); + } + + if spec.gpu_count.is_some() && !spec.gpu_device.is_empty() { + return Err(Status::invalid_argument( + "gpu_count is mutually exclusive with gpu_device", + )); + } + + if spec.gpu_count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "VM GPU sandboxes support only one GPU", + )); + } + if spec.gpu && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", @@ -4515,6 +4535,37 @@ mod tests { validate_vm_sandbox(&sandbox, true).expect("gpu should be accepted when enabled"); } + #[test] + fn validate_vm_sandbox_accepts_gpu_count_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + gpu_count: Some(1), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("one GPU should be accepted when enabled"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_above_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + gpu_count: Some(2), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU VM request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("support only one GPU")); + } + #[test] fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { let sandbox = Sandbox { diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 064eb3857..bc5adbd26 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1281,6 +1281,7 @@ fn driver_sandbox_spec_from_public( .transpose()?, gpu: spec.gpu, gpu_device: spec.gpu_device.clone(), + gpu_count: spec.gpu_count, sandbox_token: String::new(), }) } @@ -1857,6 +1858,21 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + gpu: true, + gpu_count: Some(2), + ..Default::default() + }; + + let driver = + driver_sandbox_spec_from_public(&public, None).expect("driver spec should map"); + + assert!(driver.gpu); + assert_eq!(driver.gpu_count, Some(2)); + } + #[test] fn select_driver_config_forwards_only_matching_driver_block() { let config = prost_types::Struct { diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 268c143d2..9ddb6627e 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,9 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.gpu* --- + validate_gpu_request_fields(spec)?; + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -144,6 +147,24 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_gpu_request_fields(spec: &openshell_core::proto::SandboxSpec) -> Result<(), Status> { + if spec.gpu_count == Some(0) { + return Err(Status::invalid_argument("gpu_count must be greater than 0")); + } + + if spec.gpu_count.is_some() && !spec.gpu { + return Err(Status::invalid_argument("gpu_count requires gpu=true")); + } + + if spec.gpu_count.is_some() && !spec.gpu_device.is_empty() { + return Err(Status::invalid_argument( + "gpu_count is mutually exclusive with gpu_device", + )); + } + + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -712,6 +733,55 @@ mod tests { assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + gpu: true, + gpu_count: Some(2), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + gpu: true, + gpu_count: Some(0), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu_count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_without_gpu() { + let spec = SandboxSpec { + gpu_count: Some(1), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu_count requires gpu=true")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_with_gpu_device() { + let spec = SandboxSpec { + gpu: true, + gpu_count: Some(1), + gpu_device: "nvidia.com/gpu=0".to_string(), + ..Default::default() + }; + let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("gpu_count is mutually exclusive with gpu_device") + ); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index 229bb1bdb..60c062120 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -53,7 +53,8 @@ openshell sandbox create \ ``` Driver config is for fields without a stable public flag. Prefer `--cpu`, -`--memory`, and `--gpu` for portable resource intent. +`--memory`, `--gpu`, `--gpu-count`, and `--gpu-device` for supported resource +intent. For Kubernetes, `pod.runtime_class_name` maps to PodSpec `runtimeClassName`. It overrides the gateway's configured default runtime class for that sandbox, diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 0db6d7678..6065a2913 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -70,6 +70,25 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs with `--gpu-count`: + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + +Request a driver-specific GPU device with `--gpu-device`; this also requests +GPU resources: + +```shell +openshell sandbox create --gpu-device nvidia.com/gpu=0 -- claude +``` + +Support for count and device selection is driver-dependent. Kubernetes honors +`--gpu-count` by setting the `nvidia.com/gpu` limit. Docker and Podman support +explicit CDI device IDs through `--gpu-device` but reject count-based selection. +VM gateways accept only one GPU, either through `--gpu`, `--gpu-count 1`, or +`--gpu-device`. + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 190a04e87..cffdc36bc 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -89,6 +89,8 @@ message DriverSandboxSpec { // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the // first available GPU. string gpu_device = 10; + // Optional number of GPUs requested. Mutually exclusive with gpu_device. + optional uint32 gpu_count = 12; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; diff --git a/proto/openshell.proto b/proto/openshell.proto index c2755aaf7..3a9625f35 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -323,6 +323,8 @@ message SandboxSpec { // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the // first available GPU. string gpu_device = 10; + // Optional number of GPUs requested. Mutually exclusive with gpu_device. + optional uint32 gpu_count = 12; // Field 11 was `proposal_approval_mode`. The approval mode is now a // runtime setting (gateway or sandbox scope) read via UpdateConfig / // GetSandboxConfig, so it can be flipped on a running sandbox and From a7760c76eecf8d7069f6a8ba0d7aad005f2c15fe Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 8 Jun 2026 14:59:38 +0200 Subject: [PATCH 3/4] fix(gateway): normalize gpu request intent Signed-off-by: Evan Lezar --- crates/openshell-server/src/grpc/sandbox.rs | 84 +++++++++++++++++-- .../openshell-server/src/grpc/validation.rs | 10 +-- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index e60ce3995..8463ed587 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -19,8 +19,8 @@ use openshell_core::proto::{ ExecSandboxInput, ExecSandboxRequest, ExecSandboxStderr, ExecSandboxStdout, GetSandboxRequest, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, RevokeSshSessionRequest, RevokeSshSessionResponse, - SandboxResponse, SandboxStreamEvent, SshRelayTarget, TcpForwardFrame, TcpForwardInit, - TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, + SandboxResponse, SandboxSpec, SandboxStreamEvent, SshRelayTarget, TcpForwardFrame, + TcpForwardInit, TcpRelayTarget, WatchSandboxRequest, relay_open, tcp_forward_init, }; use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use openshell_core::telemetry::{ @@ -100,7 +100,7 @@ fn emit_sandbox_create_telemetry( }; openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + effective_gpu_request(spec), spec.providers.len() as u64, spec.policy.is_some(), template_source, @@ -114,6 +114,10 @@ fn telemetry_compute_driver( TelemetryComputeDriver::from_driver_kind(driver_kind) } +fn effective_gpu_request(spec: &SandboxSpec) -> bool { + spec.gpu || spec.gpu_count.is_some() || !spec.gpu_device.is_empty() +} + async fn handle_create_sandbox_inner( state: &Arc, request: Request, @@ -143,8 +147,10 @@ async fn handle_create_sandbox_inner( } validate_provider_environment_keys_unique(state.store.as_ref(), &spec.providers).await?; - // Ensure the template always carries the resolved image. let mut spec = spec; + spec.gpu = effective_gpu_request(&spec); + + // Ensure the template always carries the resolved image. let template = spec.template.get_or_insert_with(SandboxTemplate::default); if template.image.is_empty() { template.image = state.compute.default_image().to_string(); @@ -2208,7 +2214,7 @@ mod tests { labels: std::iter::once(("team".to_string(), "agents".to_string())).collect(), resource_version: 0, }), - spec: Some(openshell_core::proto::SandboxSpec { + spec: Some(SandboxSpec { log_level: "debug".to_string(), policy: Some(openshell_core::proto::SandboxPolicy::default()), providers, @@ -2572,7 +2578,7 @@ mod tests { &state, Request::new(CreateSandboxRequest { name: "collision".to_string(), - spec: Some(openshell_core::proto::SandboxSpec { + spec: Some(SandboxSpec { providers: vec!["provider-a".to_string(), "provider-b".to_string()], ..Default::default() }), @@ -2588,6 +2594,72 @@ mod tests { assert!(err.message().contains("provider-b")); } + #[tokio::test] + async fn create_sandbox_gpu_count_implies_gpu() { + let state = test_server_state().await; + + let response = handle_create_sandbox( + &state, + Request::new(CreateSandboxRequest { + name: "gpu-count".to_string(), + spec: Some(SandboxSpec { + gpu_count: Some(2), + ..Default::default() + }), + labels: HashMap::new(), + }), + ) + .await + .unwrap() + .into_inner(); + + let sandbox = response.sandbox.unwrap(); + let spec = sandbox.spec.as_ref().unwrap(); + assert!(spec.gpu); + assert_eq!(spec.gpu_count, Some(2)); + + let stored = state + .store + .get_message_by_name::("gpu-count") + .await + .unwrap() + .unwrap(); + assert!(stored.spec.as_ref().unwrap().gpu); + } + + #[tokio::test] + async fn create_sandbox_gpu_device_implies_gpu() { + let state = test_server_state().await; + + let response = handle_create_sandbox( + &state, + Request::new(CreateSandboxRequest { + name: "gpu-device".to_string(), + spec: Some(SandboxSpec { + gpu_device: "nvidia.com/gpu=0".to_string(), + ..Default::default() + }), + labels: HashMap::new(), + }), + ) + .await + .unwrap() + .into_inner(); + + let sandbox = response.sandbox.unwrap(); + let spec = sandbox.spec.as_ref().unwrap(); + assert!(spec.gpu); + assert_eq!(spec.gpu_device, "nvidia.com/gpu=0"); + + let stored = state + .store + .get_message_by_name::("gpu-device") + .await + .unwrap() + .unwrap(); + assert!(stored.spec.as_ref().unwrap().gpu); + } + #[tokio::test] async fn attach_sandbox_provider_rejects_credential_key_collisions() { let state = test_server_state().await; diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 9ddb6627e..da43d339d 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -152,10 +152,6 @@ fn validate_gpu_request_fields(spec: &openshell_core::proto::SandboxSpec) -> Res return Err(Status::invalid_argument("gpu_count must be greater than 0")); } - if spec.gpu_count.is_some() && !spec.gpu { - return Err(Status::invalid_argument("gpu_count requires gpu=true")); - } - if spec.gpu_count.is_some() && !spec.gpu_device.is_empty() { return Err(Status::invalid_argument( "gpu_count is mutually exclusive with gpu_device", @@ -756,14 +752,12 @@ mod tests { } #[test] - fn validate_sandbox_spec_rejects_gpu_count_without_gpu() { + fn validate_sandbox_spec_accepts_gpu_count_without_gpu() { let spec = SandboxSpec { gpu_count: Some(1), ..Default::default() }; - let err = validate_sandbox_spec("gpu-sandbox", &spec).unwrap_err(); - assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_count requires gpu=true")); + assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } #[test] From abe5b792bc9c717b721630bff2ade38bc8882df3 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 8 Jun 2026 15:02:55 +0200 Subject: [PATCH 4/4] fix(kubernetes): reject explicit gpu devices Signed-off-by: Evan Lezar --- architecture/compute-runtimes.md | 3 +- crates/openshell-driver-kubernetes/README.md | 4 +- .../openshell-driver-kubernetes/src/driver.rs | 51 +++++++++++++++++++ docs/sandboxes/manage-sandboxes.mdx | 8 +-- 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index efeae6d98..dcf408f1c 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -84,7 +84,8 @@ dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The runtime still owns GPU device injection. GPU requests can include a driver-native device identifier or a requested count; the gateway validates the request shape -and each runtime enforces the GPU allocation modes it supports. +and each runtime enforces the GPU allocation modes it supports. Kubernetes uses +counted `nvidia.com/gpu` resources and rejects driver-native device identifiers. ## Deployment Shape diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 2b34b48a3..940d0db93 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -64,7 +64,9 @@ the supervisor's network namespace mount setup on AppArmor-enabled nodes. When a sandbox requests GPU support, the driver checks node allocatable capacity for `nvidia.com/gpu` and requests the configured GPU count in the workload spec. When no count is set, the driver requests one GPU resource. The sandbox image -must provide the user-space libraries needed by the agent workload. +must provide the user-space libraries needed by the agent workload. The driver +does not support explicit GPU device identifiers; use the public `gpu` flag or +`gpu_count`. ## Driver Config POC diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 475bae5a7..4b78c0f8e 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -173,6 +173,20 @@ impl std::fmt::Debug for KubernetesComputeDriver { } } +fn gpu_request_validation_error(sandbox: &Sandbox) -> Option<&'static str> { + if sandbox + .spec + .as_ref() + .is_some_and(|spec| !spec.gpu_device.is_empty()) + { + return Some( + "gpu_device is not supported by the kubernetes compute driver; use gpu or gpu_count", + ); + } + + None +} + impl KubernetesComputeDriver { pub async fn new(config: KubernetesComputeConfig) -> Result { let base_config = match kube::Config::incluster() { @@ -245,6 +259,10 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { + if let Some(message) = gpu_request_validation_error(sandbox) { + return Err(tonic::Status::invalid_argument(message)); + } + let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); if gpu_requested && !self.has_gpu_capacity().await.map_err(|err| { @@ -2288,6 +2306,39 @@ mod tests { ); } + #[test] + fn gpu_request_validation_error_rejects_gpu_device() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + gpu: true, + gpu_device: "nvidia.com/gpu=0".to_string(), + ..Default::default() + }), + ..Default::default() + }; + + assert_eq!( + gpu_request_validation_error(&sandbox), + Some( + "gpu_device is not supported by the kubernetes compute driver; use gpu or gpu_count" + ) + ); + } + + #[test] + fn gpu_request_validation_error_accepts_gpu_count() { + let sandbox = Sandbox { + spec: Some(SandboxSpec { + gpu: true, + gpu_count: Some(2), + ..Default::default() + }), + ..Default::default() + }; + + assert_eq!(gpu_request_validation_error(&sandbox), None); + } + #[test] fn gpu_sandbox_uses_template_runtime_class_name_when_set() { let template = SandboxTemplate { diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 6065a2913..26aac8fad 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -84,10 +84,10 @@ openshell sandbox create --gpu-device nvidia.com/gpu=0 -- claude ``` Support for count and device selection is driver-dependent. Kubernetes honors -`--gpu-count` by setting the `nvidia.com/gpu` limit. Docker and Podman support -explicit CDI device IDs through `--gpu-device` but reject count-based selection. -VM gateways accept only one GPU, either through `--gpu`, `--gpu-count 1`, or -`--gpu-device`. +`--gpu-count` by setting the `nvidia.com/gpu` limit and rejects +`--gpu-device`. Docker and Podman support explicit CDI device IDs through +`--gpu-device` but reject count-based selection. VM gateways accept only one +GPU, either through `--gpu`, `--gpu-count 1`, or `--gpu-device`. For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the