diff --git a/Cargo.lock b/Cargo.lock index d0cd77f85..04b561899 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3481,6 +3481,7 @@ dependencies = [ "bytes", "futures", "openshell-core", + "prost-types", "serde", "serde_json", "tar", @@ -3529,6 +3530,7 @@ dependencies = [ "miette", "nix", "openshell-core", + "prost-types", "rustix 1.1.4", "serde", "serde_json", diff --git a/TESTING.md b/TESTING.md index 7bcf2d203..d32dc385c 100644 --- a/TESTING.md +++ b/TESTING.md @@ -157,7 +157,9 @@ defaults to the image used by the `gateway` stage in `deploy/docker/Dockerfile.images`; set `OPENSHELL_E2E_GPU_PROBE_IMAGE` to override it. Per-device checks run only for NVIDIA CDI device IDs reported by the runtime's discovered devices list, so WSL2 hosts that expose only -`nvidia.com/gpu=all` skip the index-based cases. +`nvidia.com/gpu=all` skip the index-based cases. Exact CDI device selection is +passed through `--driver-config-json` with the active Docker or Podman driver +key. Run the Docker-backed Rust CLI e2e suite: diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 19fe2df83..2084082b5 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1211,12 +1211,6 @@ enum SandboxCommands { #[arg(long)] gpu: bool, - /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs - /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. - /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] - gpu_device: Option, - /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] cpu: Option, @@ -2552,7 +2546,6 @@ async fn main() -> Result<()> { no_keep, editor, gpu, - gpu_device, cpu, memory, driver_config_json, @@ -2637,7 +2630,6 @@ async fn main() -> Result<()> { &upload_specs, keep, gpu, - gpu_device.as_deref(), cpu.as_deref(), memory.as_deref(), driver_config_json.as_deref(), diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 7290ac05e..3739ee00a 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -1725,7 +1725,6 @@ pub async fn sandbox_create( uploads: &[(String, Option, bool)], keep: bool, gpu: bool, - gpu_device: Option<&str>, cpu: Option<&str>, memory: Option<&str>, driver_config_json: Option<&str>, @@ -1817,7 +1816,6 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), environment: environment.clone(), policy, providers: configured_providers, diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index be71c0a36..7061614cb 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,7 +787,6 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, - None, &[], None, None, @@ -827,7 +826,6 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { &[], true, false, - None, Some("500m"), Some("2Gi"), None, @@ -907,7 +905,6 @@ async fn sandbox_create_sends_driver_config_json() { false, None, None, - None, Some(r#"{"kubernetes":{"pod":{"priority_class_name":"batch-low"}}}"#), None, &[], @@ -984,7 +981,6 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, - None, &[], None, None, @@ -1043,7 +1039,6 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, - None, &[], None, None, @@ -1098,7 +1093,6 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, - None, &[], None, None, @@ -1145,7 +1139,6 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, - None, &[], None, None, @@ -1188,7 +1181,6 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, - None, &[], None, None, @@ -1235,7 +1227,6 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, - None, &[], None, None, @@ -1282,7 +1273,6 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, - None, &[], None, None, @@ -1329,7 +1319,6 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, - None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), @@ -1372,7 +1361,6 @@ async fn sandbox_create_sends_environment_variables() { None, None, None, - None, &[], None, None, diff --git a/crates/openshell-core/src/error.rs b/crates/openshell-core/src/error.rs index 6f04ebece..a149cf006 100644 --- a/crates/openshell-core/src/error.rs +++ b/crates/openshell-core/src/error.rs @@ -113,6 +113,9 @@ pub enum ComputeDriverError { /// The requested sandbox already exists. #[error("sandbox already exists")] AlreadyExists, + /// The request contains an invalid argument. + #[error("{0}")] + InvalidArgument(String), /// A precondition for the operation was not met. #[error("{0}")] Precondition(String), @@ -125,6 +128,7 @@ impl From for tonic::Status { fn from(err: ComputeDriverError) -> Self { match err { ComputeDriverError::AlreadyExists => Self::already_exists("sandbox already exists"), + ComputeDriverError::InvalidArgument(m) => Self::invalid_argument(m), ComputeDriverError::Precondition(m) => Self::failed_precondition(m), ComputeDriverError::Message(m) => Self::internal(m), } diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..9718b50f2 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -5,18 +5,18 @@ use crate::config::CDI_GPU_DEVICE_ALL; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Resolve a GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. A GPU request with no explicit CDI +/// devices uses the CDI all-GPU request; otherwise the driver-configured CDI +/// devices pass through unchanged. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option> { +pub fn cdi_gpu_device_ids(gpu: bool, cdi_devices: &[String]) -> Option> { gpu.then(|| { - if gpu_device.is_empty() { + if cdi_devices.is_empty() { vec![CDI_GPU_DEVICE_ALL.to_string()] } else { - vec![gpu_device.to_string()] + cdi_devices.to_vec() } }) } @@ -27,22 +27,31 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(false, &[]), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(true, &[]), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_explicit_device_ids_through() { assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), - Some(vec!["nvidia.com/gpu=0".to_string()]) + cdi_gpu_device_ids( + true, + &[ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] + ), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) ); } } diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index c3241cdd8..c975cfd18 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -22,6 +22,7 @@ pub mod net; pub mod paths; pub mod progress; pub mod proto; +pub mod proto_struct; pub mod sandbox_env; pub mod settings; pub mod telemetry; diff --git a/crates/openshell-core/src/proto_struct.rs b/crates/openshell-core/src/proto_struct.rs new file mode 100644 index 000000000..874c7c6fd --- /dev/null +++ b/crates/openshell-core/src/proto_struct.rs @@ -0,0 +1,120 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Helpers for decoding `google.protobuf.Struct` values. + +use serde::{Deserialize, Deserializer, de::Error as _}; + +/// Convert a protobuf Struct into a JSON object for typed serde decoding. +#[must_use] +pub fn struct_to_json_object( + config: &prost_types::Struct, +) -> serde_json::Map { + config + .fields + .iter() + .map(|(key, value)| (key.clone(), value_to_json(value))) + .collect() +} + +/// Convert a protobuf Struct into a JSON value for typed serde decoding. +#[must_use] +pub fn struct_to_json_value(config: &prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object(struct_to_json_object(config)) +} + +/// Convert a protobuf Value into a JSON value for typed serde decoding. +#[must_use] +pub fn value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NumberValue(num)) => serde_json::Number::from_f64(*num) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(val)) => serde_json::Value::String(val.clone()), + Some(prost_types::value::Kind::BoolValue(val)) => serde_json::Value::Bool(*val), + Some(prost_types::value::Kind::StructValue(val)) => { + let mut map = serde_json::Map::new(); + for (key, value) in &val.fields { + map.insert(key.clone(), value_to_json(value)); + } + serde_json::Value::Object(map) + } + Some(prost_types::value::Kind::ListValue(list)) => { + let values = list.values.iter().map(value_to_json).collect(); + serde_json::Value::Array(values) + } + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + } +} + +/// Deserialize a present field as a non-empty list of non-empty strings. +/// +/// Use with `#[serde(default, deserialize_with = "...")]` on +/// `Option>` fields. Missing fields use the option default; present +/// fields must be arrays and cannot be empty. +pub fn deserialize_optional_non_empty_string_list<'de, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let values = Vec::::deserialize(deserializer)?; + if values.is_empty() { + return Err(D::Error::custom("must be a non-empty list of strings")); + } + + for (idx, value) in values.iter().enumerate() { + if value.trim().is_empty() { + return Err(D::Error::custom(format!( + "[{idx}] must be a non-empty string" + ))); + } + } + + Ok(Some(values)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Default, Deserialize)] + #[serde(default)] + struct TestConfig { + #[serde( + default, + deserialize_with = "deserialize_optional_non_empty_string_list" + )] + devices: Option>, + } + + #[test] + fn optional_non_empty_string_list_defaults_when_absent() { + let config: TestConfig = serde_json::from_value(serde_json::json!({})).unwrap(); + + assert_eq!(config.devices, None); + } + + #[test] + fn optional_non_empty_string_list_parses_present_list() { + let config: TestConfig = + serde_json::from_value(serde_json::json!({"devices": ["nvidia.com/gpu=0"]})).unwrap(); + + assert_eq!(config.devices, Some(vec!["nvidia.com/gpu=0".to_string()])); + } + + #[test] + fn optional_non_empty_string_list_rejects_empty_list() { + let err = + serde_json::from_value::(serde_json::json!({"devices": []})).unwrap_err(); + + assert!(err.to_string().contains("non-empty list")); + } + + #[test] + fn optional_non_empty_string_list_rejects_empty_string() { + let err = + serde_json::from_value::(serde_json::json!({"devices": [""]})).unwrap_err(); + + assert!(err.to_string().contains("non-empty string")); + } +} diff --git a/crates/openshell-driver-docker/Cargo.toml b/crates/openshell-driver-docker/Cargo.toml index 4ddb1a913..df2fa25f0 100644 --- a/crates/openshell-driver-docker/Cargo.toml +++ b/crates/openshell-driver-docker/Cargo.toml @@ -27,6 +27,7 @@ tempfile = "3" url = { workspace = true } [dev-dependencies] +prost-types = { workspace = true } temp-env = "0.3" [lints] diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index ea57f44e4..4c6cfa8a7 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses `driver_config.cdi_devices` when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 49bd3d3f6..56dab7e4d 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -40,6 +40,9 @@ use openshell_core::proto::compute::v1::{ WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, }; +use openshell_core::proto_struct::{ + deserialize_optional_non_empty_string_list, struct_to_json_value, +}; use openshell_core::{Config, Error, Result as CoreResult}; use std::collections::HashMap; use std::io::Read; @@ -70,6 +73,39 @@ const HOST_OPENSHELL_INTERNAL: &str = "host.openshell.internal"; const HOST_DOCKER_INTERNAL: &str = "host.docker.internal"; const DOCKER_NETWORK_DRIVER: &str = "bridge"; +#[derive(Debug, Clone, Default, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] +struct DockerSandboxDriverConfig { + #[serde( + default, + deserialize_with = "deserialize_optional_non_empty_string_list" + )] + cdi_devices: Option>, +} + +impl DockerSandboxDriverConfig { + fn from_sandbox(sandbox: &DriverSandbox) -> Result { + let Some(template) = sandbox + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + else { + return Ok(Self::default()); + }; + + Self::from_template(template) + } + + fn from_template(template: &DriverSandboxTemplate) -> Result { + let Some(config) = template.driver_config.as_ref() else { + return Ok(Self::default()); + }; + + serde_json::from_value(struct_to_json_value(config)) + .map_err(|err| format!("invalid docker driver_config: {err}")) + } +} + /// Default image holding the Linux `openshell-sandbox` binary. The gateway /// pulls this image and extracts the binary to a host-side cache when no /// explicit `supervisor_bin` override or local build is available. @@ -370,12 +406,20 @@ impl DockerComputeDriver { .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox.spec.template is required"))?; + Self::validate_sandbox_template(template)?; + + let driver_config = + DockerSandboxDriverConfig::from_template(template).map_err(Status::invalid_argument)?; + Self::validate_gpu_request(spec.gpu, config.supports_gpu, &driver_config)?; + Ok(()) + } + + fn validate_sandbox_template(template: &DriverSandboxTemplate) -> Result<(), Status> { if template.image.trim().is_empty() { return Err(Status::failed_precondition( "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -409,7 +453,17 @@ impl DockerComputeDriver { )) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { + fn validate_gpu_request( + gpu: bool, + supports_gpu: bool, + driver_config: &DockerSandboxDriverConfig, + ) -> Result<(), Status> { + if !gpu && driver_config.cdi_devices.is_some() { + return Err(Status::invalid_argument( + "driver_config.cdi_devices requires gpu=true", + )); + } + if gpu && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", @@ -1723,14 +1777,29 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { - vec![DeviceRequest { - driver: Some("cdi".to_string()), - device_ids: Some(device_ids), - ..Default::default() - }] - }) +fn build_device_requests(sandbox: &DriverSandbox) -> Result>, Status> { + let Some(spec) = sandbox.spec.as_ref() else { + return Ok(None); + }; + let cdi_devices = DockerSandboxDriverConfig::from_sandbox(sandbox) + .map_err(Status::invalid_argument)? + .cdi_devices + .unwrap_or_default(); + if !spec.gpu && !cdi_devices.is_empty() { + return Err(Status::invalid_argument( + "driver_config.cdi_devices requires gpu=true", + )); + } + + Ok( + cdi_gpu_device_ids(spec.gpu, &cdi_devices).map(|device_ids| { + vec![DeviceRequest { + driver: Some("cdi".to_string()), + device_ids: Some(device_ids), + ..Default::default() + }] + }), + ) } fn build_container_create_body( @@ -1746,6 +1815,7 @@ fn build_container_create_body( .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox.spec.template is required"))?; let resource_limits = docker_resource_limits(template)?; + let device_requests = build_device_requests(sandbox)?; let mut labels = template.labels.clone(); labels.insert( LABEL_MANAGED_BY.to_string(), @@ -1775,7 +1845,7 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + device_requests, binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index 4a902a48b..b924c6ccd 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -42,13 +42,43 @@ fn test_sandbox() -> DriverSandbox { ..Default::default() }), gpu: false, - gpu_device: String::new(), sandbox_token: String::new(), }), status: None, } } +fn cdi_devices_config(device_ids: &[&str]) -> prost_types::Struct { + list_string_driver_config("cdi_devices", device_ids) +} + +fn cdi_device_typo_config(device_ids: &[&str]) -> prost_types::Struct { + list_string_driver_config("cdi_device", device_ids) +} + +fn list_string_driver_config(field: &str, values: &[&str]) -> prost_types::Struct { + prost_types::Struct { + fields: std::iter::once(( + field.to_string(), + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: values + .iter() + .map(|device_id| prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue( + (*device_id).to_string(), + )), + }) + .collect(), + }, + )), + }, + )) + .collect(), + } +} + fn runtime_config() -> DockerDriverRuntimeConfig { DockerDriverRuntimeConfig { default_image: "image:latest".to_string(), @@ -612,6 +642,52 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_invalid_cdi_devices_before_gpu_capability() { + let config = runtime_config(); + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("invalid docker driver_config")); + assert!(err.message().contains("non-empty list")); +} + +#[test] +fn validate_sandbox_rejects_unknown_driver_config_fields() { + let config = runtime_config(); + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + spec.template.as_mut().unwrap().driver_config = + Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("unknown field")); +} + +#[test] +fn validate_sandbox_rejects_template_errors_before_device_config() { + let config = runtime_config(); + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + let template = spec.template.as_mut().unwrap(); + template.agent_socket_path = "/tmp/agent.sock".to_string(); + template.driver_config = Some(cdi_devices_config(&[])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::FailedPrecondition); + assert!(err.message().contains("agent_socket_path")); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); @@ -663,7 +739,7 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -680,6 +756,35 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { ); } +#[test] +fn build_container_create_body_rejects_cdi_devices_without_gpu() { + let mut sandbox = test_sandbox(); + sandbox + .spec + .as_mut() + .unwrap() + .template + .as_mut() + .unwrap() + .driver_config = Some(cdi_devices_config(&["nvidia.com/gpu=0"])); + + let err = build_container_create_body(&sandbox, &runtime_config()).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("requires gpu=true")); +} + +#[test] +fn build_container_create_body_rejects_empty_cdi_devices() { + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.gpu = true; + spec.template.as_mut().unwrap().driver_config = Some(cdi_devices_config(&[])); + + let err = build_container_create_body(&sandbox, &runtime_config()).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("non-empty list")); +} + #[test] fn require_sandbox_identifier_rejects_when_id_and_name_are_empty() { // Regression test: `delete_sandbox` (and the other identifier-keyed diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 0bdcf3748..6ad0b27c8 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -93,7 +93,7 @@ openshell sandbox create \ ``` Resource keys use native Kubernetes resource names and quantity strings. The -POC parser renders the keys listed above and ignores unknown fields. +POC parser renders the keys listed above and rejects unknown fields. `pod.runtime_class_name` maps to PodSpec `runtimeClassName` and overrides the driver's configured `default_runtime_class_name`; the typed public `SandboxTemplate.runtime_class_name` still takes precedence when set. Use the diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index ba20b0725..f984fce09 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -28,6 +28,7 @@ use openshell_core::proto::compute::v1::{ GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; +use openshell_core::proto_struct::{struct_to_json_object, value_to_json}; use serde::Deserialize; use std::collections::BTreeMap; use std::pin::Pin; @@ -44,6 +45,8 @@ pub enum KubernetesDriverError { #[error("sandbox already exists")] AlreadyExists, #[error("{0}")] + InvalidArgument(String), + #[error("{0}")] Precondition(String), #[error("{0}")] Message(String), @@ -62,6 +65,7 @@ impl From for openshell_core::ComputeDriverError { fn from(err: KubernetesDriverError) -> Self { match err { KubernetesDriverError::AlreadyExists => Self::AlreadyExists, + KubernetesDriverError::InvalidArgument(m) => Self::InvalidArgument(m), KubernetesDriverError::Precondition(m) => Self::Precondition(m), KubernetesDriverError::Message(m) => Self::Message(m), } @@ -87,14 +91,38 @@ const GPU_RESOURCE_QUANTITY: &str = "1"; // translation layer; the RFC boundary is Struct at the gateway, typed config in // the selected driver. #[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] +#[serde(default, deny_unknown_fields)] struct KubernetesSandboxDriverConfig { pod: KubernetesPodDriverConfig, containers: KubernetesDriverContainersConfig, } +impl KubernetesSandboxDriverConfig { + fn from_sandbox(sandbox: &Sandbox) -> Result { + let Some(template) = sandbox + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + else { + return Ok(Self::default()); + }; + + Self::from_template(template) + } + + fn from_template(template: &SandboxTemplate) -> Result { + let Some(config) = template.driver_config.as_ref() else { + return Ok(Self::default()); + }; + + let json = serde_json::Value::Object(struct_to_json_object(config)); + serde_json::from_value(json) + .map_err(|err| format!("invalid kubernetes driver_config: {err}")) + } +} + #[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] +#[serde(default, deny_unknown_fields)] struct KubernetesPodDriverConfig { node_selector: BTreeMap, runtime_class_name: String, @@ -103,19 +131,19 @@ struct KubernetesPodDriverConfig { } #[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] +#[serde(default, deny_unknown_fields)] struct KubernetesDriverContainersConfig { agent: KubernetesContainerDriverConfig, } #[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] +#[serde(default, deny_unknown_fields)] struct KubernetesContainerDriverConfig { resources: KubernetesContainerResourceConfig, } #[derive(Debug, Clone, Default, Deserialize)] -#[serde(default)] +#[serde(default, deny_unknown_fields)] struct KubernetesContainerResourceConfig { requests: BTreeMap, limits: BTreeMap, @@ -245,6 +273,8 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { + let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) + .map_err(tonic::Status::invalid_argument)?; let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); if gpu_requested && !self.has_gpu_capacity().await.map_err(|err| { @@ -338,6 +368,8 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { + let _ = KubernetesSandboxDriverConfig::from_sandbox(sandbox) + .map_err(KubernetesDriverError::InvalidArgument)?; let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1129,18 +1161,8 @@ fn spec_pod_env(spec: Option<&SandboxSpec>) -> std::collections::HashMap KubernetesSandboxDriverConfig { - let Some(config) = template.driver_config.as_ref() else { - return KubernetesSandboxDriverConfig::default(); - }; - - let json = serde_json::Value::Object(proto_struct_to_json_object(config)); - match serde_json::from_value(json) { - Ok(config) => config, - Err(err) => { - warn!(error = %err, "Ignoring invalid Kubernetes driver_config"); - KubernetesSandboxDriverConfig::default() - } - } + KubernetesSandboxDriverConfig::from_template(template) + .expect("validated Kubernetes driver_config") } fn sandbox_to_k8s_spec( @@ -1751,7 +1773,7 @@ fn platform_config_bool(template: &SandboxTemplate, key: &str) -> Option { fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option { let config = template.platform_config.as_ref()?; let value = config.fields.get(key)?; - let json = proto_value_to_json(value); + let json = value_to_json(value); // Return None for null/empty objects so callers can distinguish // "field absent" from "field present but empty". match &json { @@ -1761,37 +1783,6 @@ fn platform_config_struct(template: &SandboxTemplate, key: &str) -> Option serde_json::Map { - config - .fields - .iter() - .map(|(key, value)| (key.clone(), proto_value_to_json(value))) - .collect() -} - -fn proto_value_to_json(value: &prost_types::Value) -> serde_json::Value { - match value.kind.as_ref() { - Some(prost_types::value::Kind::NumberValue(num)) => serde_json::Number::from_f64(*num) - .map_or(serde_json::Value::Null, serde_json::Value::Number), - Some(prost_types::value::Kind::StringValue(val)) => serde_json::Value::String(val.clone()), - Some(prost_types::value::Kind::BoolValue(val)) => serde_json::Value::Bool(*val), - Some(prost_types::value::Kind::StructValue(val)) => { - let mut map = serde_json::Map::new(); - for (key, value) in &val.fields { - map.insert(key.clone(), proto_value_to_json(value)); - } - serde_json::Value::Object(map) - } - Some(prost_types::value::Kind::ListValue(list)) => { - let values = list.values.iter().map(proto_value_to_json).collect(); - serde_json::Value::Array(values) - } - Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, - } -} - fn status_from_object(obj: &DynamicObject) -> Option { let status = obj.data.get("status")?; let status_obj = status.as_object()?; @@ -1904,7 +1895,7 @@ mod tests { } #[test] - fn driver_config_ignores_invalid_shape() { + fn driver_config_rejects_invalid_shape() { let template = SandboxTemplate { driver_config: Some(json_struct(serde_json::json!({ "pod": "not-an-object" @@ -1912,11 +1903,44 @@ mod tests { ..SandboxTemplate::default() }; - let config = kubernetes_driver_config(&template); + let err = KubernetesSandboxDriverConfig::from_template(&template).unwrap_err(); + + assert!(err.contains("invalid kubernetes driver_config")); + } + + #[test] + fn driver_config_rejects_unknown_fields() { + let template = SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "cdi_devices": ["nvidia.com/gpu=0"] + }))), + ..SandboxTemplate::default() + }; + + let err = KubernetesSandboxDriverConfig::from_template(&template).unwrap_err(); + + assert!(err.contains("unknown field")); + } + + #[test] + fn driver_config_from_sandbox_rejects_unknown_fields() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + template: Some(SandboxTemplate { + driver_config: Some(json_struct(serde_json::json!({ + "cdi_devices": ["nvidia.com/gpu=0"] + }))), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; - assert!(config.pod.node_selector.is_empty()); - assert!(config.containers.agent.resources.requests.is_empty()); - assert!(config.containers.agent.resources.limits.is_empty()); + let err = KubernetesSandboxDriverConfig::from_sandbox(&sandbox).unwrap_err(); + assert!(err.contains("unknown field")); } #[test] diff --git a/crates/openshell-driver-podman/Cargo.toml b/crates/openshell-driver-podman/Cargo.toml index 4a1c8de83..04ba326ca 100644 --- a/crates/openshell-driver-podman/Cargo.toml +++ b/crates/openshell-driver-podman/Cargo.toml @@ -35,6 +35,7 @@ thiserror = { workspace = true } miette = { workspace = true } [dev-dependencies] +prost-types = { workspace = true } temp-env = "0.3" [lints] diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 77b42ba37..e1117909a 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Sandbox `gpu_device` value when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | `driver_config.cdi_devices` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index f3aceb9bf..9170c4c63 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -4,9 +4,13 @@ //! Container spec construction for the Podman driver. use crate::config::PodmanComputeConfig; +use openshell_core::ComputeDriverError; use openshell_core::gpu::cdi_gpu_device_ids; -use openshell_core::proto::compute::v1::DriverSandbox; -use serde::Serialize; +use openshell_core::proto::compute::v1::{DriverSandbox, DriverSandboxTemplate}; +use openshell_core::proto_struct::{ + deserialize_optional_non_empty_string_list, struct_to_json_value, +}; +use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::BTreeMap; @@ -57,6 +61,40 @@ const SUPERVISOR_MOUNT_DIR: &str = openshell_core::driver_utils::SUPERVISOR_CONT /// Full path to the supervisor binary inside sandbox containers. const SUPERVISOR_BINARY_PATH: &str = openshell_core::driver_utils::SUPERVISOR_CONTAINER_BINARY; +#[derive(Debug, Clone, Default, Deserialize)] +#[serde(default, deny_unknown_fields)] +pub struct PodmanSandboxDriverConfig { + #[serde( + default, + deserialize_with = "deserialize_optional_non_empty_string_list" + )] + pub cdi_devices: Option>, +} + +impl PodmanSandboxDriverConfig { + pub fn from_sandbox(sandbox: &DriverSandbox) -> Result { + let Some(template) = sandbox + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + else { + return Ok(Self::default()); + }; + + Self::from_template(template) + } + + pub fn from_template(template: &DriverSandboxTemplate) -> Result { + let Some(config) = template.driver_config.as_ref() else { + return Ok(Self::default()); + }; + + serde_json::from_value(struct_to_json_value(config)).map_err(|err| { + ComputeDriverError::InvalidArgument(format!("invalid podman driver_config: {err}")) + }) + } +} + /// Build a Podman container name from the sandbox name. #[must_use] pub fn container_name(sandbox_name: &str) -> String { @@ -391,29 +429,53 @@ fn podman_pids_limit(value: i64) -> Option { } /// Build CDI GPU device list if GPU is requested. -fn build_devices(sandbox: &DriverSandbox) -> Option> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { - device_ids - .into_iter() - .map(|path| LinuxDevice { path }) - .collect() - }) +fn build_devices(sandbox: &DriverSandbox) -> Result>, ComputeDriverError> { + let Some(spec) = sandbox.spec.as_ref() else { + return Ok(None); + }; + let cdi_devices = PodmanSandboxDriverConfig::from_sandbox(sandbox)? + .cdi_devices + .unwrap_or_default(); + if !spec.gpu && !cdi_devices.is_empty() { + return Err(ComputeDriverError::InvalidArgument( + "driver_config.cdi_devices requires gpu=true".to_string(), + )); + } + + Ok( + cdi_gpu_device_ids(spec.gpu, &cdi_devices).map(|device_ids| { + device_ids + .into_iter() + .map(|path| LinuxDevice { path }) + .collect() + }), + ) } /// Build the Podman container creation JSON spec. #[cfg(test)] #[must_use] pub fn build_container_spec(sandbox: &DriverSandbox, config: &PodmanComputeConfig) -> Value { - build_container_spec_with_token(sandbox, config, None) + try_build_container_spec_with_token(sandbox, config, None) + .expect("container spec should be valid") } +#[cfg(test)] #[must_use] pub fn build_container_spec_with_token( sandbox: &DriverSandbox, config: &PodmanComputeConfig, token_host_path: Option<&std::path::Path>, ) -> Value { + try_build_container_spec_with_token(sandbox, config, token_host_path) + .expect("container spec should be valid") +} + +pub fn try_build_container_spec_with_token( + sandbox: &DriverSandbox, + config: &PodmanComputeConfig, + token_host_path: Option<&std::path::Path>, +) -> Result { let image = resolve_image(sandbox, config); let name = container_name(&sandbox.name); let vol = volume_name(&sandbox.id); @@ -421,7 +483,7 @@ pub fn build_container_spec_with_token( let env = build_env(sandbox, config, image); let labels = build_labels(sandbox); let resource_limits = build_resource_limits(sandbox, config); - let devices = build_devices(sandbox); + let devices = build_devices(sandbox)?; // Network configuration -- always bridge mode. // Matches libpod's network spec format `{name: {opts}}`; the unit-struct @@ -633,7 +695,7 @@ pub fn build_container_spec_with_token( }], }; - serde_json::to_value(container_spec).expect("ContainerSpec serialization cannot fail") + Ok(serde_json::to_value(container_spec).expect("ContainerSpec serialization cannot fail")) } fn hostadd_entries(config: &PodmanComputeConfig) -> Vec { @@ -839,12 +901,15 @@ mod tests { #[test] fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), ..Default::default() }); let config = test_config(); @@ -856,6 +921,65 @@ mod tests { ); } + #[test] + fn container_spec_rejects_cdi_devices_without_gpu() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + + let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!(err.to_string().contains("requires gpu=true")); + } + + #[test] + fn container_spec_rejects_empty_cdi_devices() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + gpu: true, + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_devices_config(&[])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + + let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!(err.to_string().contains("non-empty list")); + } + + #[test] + fn container_spec_rejects_unknown_driver_config_fields() { + use openshell_core::proto::compute::v1::{DriverSandboxSpec, DriverSandboxTemplate}; + + let mut sandbox = test_sandbox("test-id", "test-name"); + sandbox.spec = Some(DriverSandboxSpec { + gpu: true, + template: Some(DriverSandboxTemplate { + driver_config: Some(cdi_device_typo_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), + ..Default::default() + }); + let config = test_config(); + + let err = try_build_container_spec_with_token(&sandbox, &config, None).unwrap_err(); + assert!(matches!(err, ComputeDriverError::InvalidArgument(_))); + assert!(err.to_string().contains("unknown field")); + } + #[test] fn container_spec_includes_required_capabilities() { let sandbox = test_sandbox("test-id", "test-name"); @@ -1133,6 +1257,37 @@ mod tests { } } + fn cdi_devices_config(device_ids: &[&str]) -> prost_types::Struct { + list_string_driver_config("cdi_devices", device_ids) + } + + fn cdi_device_typo_config(device_ids: &[&str]) -> prost_types::Struct { + list_string_driver_config("cdi_device", device_ids) + } + + fn list_string_driver_config(field: &str, values: &[&str]) -> prost_types::Struct { + prost_types::Struct { + fields: std::iter::once(( + field.to_string(), + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: values + .iter() + .map(|device_id| prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue( + (*device_id).to_string(), + )), + }) + .collect(), + }, + )), + }, + )) + .collect(), + } + } + fn test_config() -> PodmanComputeConfig { PodmanComputeConfig { socket_path: std::path::PathBuf::from("/tmp/test.sock"), diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index 1358d8945..f11f38630 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -5,7 +5,7 @@ use crate::client::{PodmanApiError, PodmanClient}; use crate::config::PodmanComputeConfig; -use crate::container::{self, LABEL_MANAGED_FILTER, LABEL_SANDBOX_ID}; +use crate::container::{self, LABEL_MANAGED_FILTER, LABEL_SANDBOX_ID, PodmanSandboxDriverConfig}; use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; @@ -282,6 +282,12 @@ impl PodmanComputeDriver { sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); + let driver_config = PodmanSandboxDriverConfig::from_sandbox(sandbox)?; + if !gpu_requested && driver_config.cdi_devices.is_some() { + return Err(ComputeDriverError::InvalidArgument( + "driver_config.cdi_devices requires gpu=true".to_string(), + )); + } Self::validate_gpu_request(gpu_requested) } @@ -365,11 +371,11 @@ impl PodmanComputeDriver { }; // 3. Create container. - let spec = container::build_container_spec_with_token( + let spec = container::try_build_container_spec_with_token( sandbox, &self.config, token_host_path.as_deref(), - ); + )?; match self.client.create_container(&spec).await { Ok(_) => {} Err(PodmanApiError::Conflict(_)) => { diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 2e0c5a603..30fecd8be 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -36,13 +36,17 @@ use openshell_core::progress::{ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, - ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, + DriverSandboxTemplate as SandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, }; +use openshell_core::proto_struct::{ + deserialize_optional_non_empty_string_list, struct_to_json_value, +}; use openshell_vfio::SysfsRoot; use prost::Message; use sha2::{Digest, Sha256}; @@ -74,6 +78,40 @@ const DEFAULT_MEM_MIB: u32 = 2048; const DEFAULT_OVERLAY_DISK_MIB: u64 = 4096; const DEFAULT_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY: usize = 4; const MAX_REGISTRY_LAYER_DOWNLOAD_CONCURRENCY: usize = 16; + +#[derive(Debug, Clone, Default, serde::Deserialize)] +#[serde(default, deny_unknown_fields)] +struct VmSandboxDriverConfig { + #[serde( + default, + deserialize_with = "deserialize_optional_non_empty_string_list" + )] + gpu_device_ids: Option>, +} + +impl VmSandboxDriverConfig { + fn from_sandbox(sandbox: &Sandbox) -> Result { + let Some(template) = sandbox + .spec + .as_ref() + .and_then(|spec| spec.template.as_ref()) + else { + return Ok(Self::default()); + }; + + Self::from_template(template) + } + + fn from_template(template: &SandboxTemplate) -> Result { + let Some(config) = template.driver_config.as_ref() else { + return Ok(Self::default()); + }; + + serde_json::from_value(struct_to_json_value(config)) + .map_err(|err| format!("invalid vm driver_config: {err}")) + } +} + /// gvproxy host-loopback IP — gvproxy's TCP/UDP/ICMP forwarder NAT-rewrites /// this destination to the host's `127.0.0.1` and dials out from the host /// process. This is the only address that transparently reaches host-bound @@ -651,9 +689,12 @@ impl VmDriver { ))); } - let gpu_device = sandbox.spec.as_ref().map_or("", |s| s.gpu_device.as_str()); - let gpu_bdf = if is_gpu { - Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) + let gpu_device_id = vm_gpu_device_id(&sandbox)?; + let gpu_bdf = if let Some(gpu_device_id) = gpu_device_id.as_deref() { + Some( + self.assign_gpu_to_record(&sandbox.id, gpu_device_id) + .await?, + ) } else { None }; @@ -3035,29 +3076,68 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; + if let Some(template) = spec.template.as_ref() { + validate_vm_sandbox_template(template)?; + } + validate_vm_gpu_request(sandbox, gpu_enabled)?; + + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn validate_vm_sandbox_template(template: &SandboxTemplate) -> Result<(), Status> { + if !template.agent_socket_path.is_empty() { + return Err(Status::failed_precondition( + "vm sandboxes do not support template.agent_socket_path", + )); + } + if template.platform_config.is_some() { + return Err(Status::failed_precondition( + "vm sandboxes do not support template.platform_config", + )); + } + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn validate_vm_gpu_request(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Status> { + let spec = sandbox + .spec + .as_ref() + .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; + + let _ = vm_gpu_device_id(sandbox)?; if spec.gpu && !gpu_enabled { return Err(Status::failed_precondition( "GPU support is not enabled on this driver; start with --gpu", )); } + Ok(()) +} - if !spec.gpu && !spec.gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); +#[allow(clippy::result_large_err)] +fn vm_gpu_device_id(sandbox: &Sandbox) -> Result, Status> { + let Some(spec) = sandbox.spec.as_ref() else { + return Ok(None); + }; + let gpu_device_ids = VmSandboxDriverConfig::from_sandbox(sandbox) + .map_err(Status::invalid_argument)? + .gpu_device_ids + .unwrap_or_default(); + if !spec.gpu && !gpu_device_ids.is_empty() { + return Err(Status::invalid_argument( + "driver_config.gpu_device_ids requires gpu=true", + )); } - - if let Some(template) = spec.template.as_ref() { - if !template.agent_socket_path.is_empty() { - return Err(Status::failed_precondition( - "vm sandboxes do not support template.agent_socket_path", - )); - } - if template.platform_config.is_some() { - return Err(Status::failed_precondition( - "vm sandboxes do not support template.platform_config", - )); - } + if gpu_device_ids.len() > 1 { + return Err(Status::invalid_argument( + "vm driver currently supports at most one gpu_device_ids entry", + )); } - Ok(()) + + Ok(spec + .gpu + .then(|| gpu_device_ids.into_iter().next().unwrap_or_default())) } #[allow(clippy::result_large_err)] @@ -4995,6 +5075,33 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_device_ids_config(device_ids: &[&str]) -> Struct { + list_string_driver_config("gpu_device_ids", device_ids) + } + + fn gpu_device_id_typo_config(device_ids: &[&str]) -> Struct { + list_string_driver_config("gpu_device_id", device_ids) + } + + fn list_string_driver_config(field: &str, values: &[&str]) -> Struct { + Struct { + fields: std::iter::once(( + field.to_string(), + Value { + kind: Some(Kind::ListValue(prost_types::ListValue { + values: values + .iter() + .map(|device_id| Value { + kind: Some(Kind::StringValue((*device_id).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -5092,15 +5199,99 @@ mod tests { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("gpu_device_ids without gpu should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("gpu_device_ids requires gpu=true")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&["0000:2d:00.0", "0000:31:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("multiple GPUs should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one gpu_device_ids")); + } + + #[test] + fn validate_vm_sandbox_rejects_empty_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_ids_config(&[])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("empty GPU IDs should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("non-empty list")); + } + + #[test] + fn validate_vm_sandbox_rejects_unknown_driver_config_fields() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + template: Some(SandboxTemplate { + driver_config: Some(gpu_device_id_typo_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("unknown field should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("unknown field")); + } + + #[test] + fn validate_vm_sandbox_rejects_template_errors_before_device_config() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + gpu: true, + template: Some(SandboxTemplate { + agent_socket_path: "/tmp/agent.sock".to_string(), + driver_config: Some(gpu_device_ids_config(&[])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("template error should be rejected"); + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("agent_socket_path")); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 064eb3857..812b9c59a 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -1280,7 +1280,6 @@ fn driver_sandbox_spec_from_public( .map(|template| driver_sandbox_template_from_public(template, driver_kind)) .transpose()?, gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), sandbox_token: String::new(), }) } diff --git a/docs/reference/sandbox-compute-drivers.mdx b/docs/reference/sandbox-compute-drivers.mdx index 229bb1bdb..be63476fb 100644 --- a/docs/reference/sandbox-compute-drivers.mdx +++ b/docs/reference/sandbox-compute-drivers.mdx @@ -55,6 +55,13 @@ openshell sandbox create \ Driver config is for fields without a stable public flag. Prefer `--cpu`, `--memory`, and `--gpu` for portable resource intent. +Exact GPU device selection remains driver-owned and requires `--gpu`. Docker +and Podman accept `cdi_devices`; replace the top-level `docker` key with +`podman` when using the Podman driver, for example +`{"docker":{"cdi_devices":["nvidia.com/gpu=0"]}}`. The VM driver accepts +`gpu_device_ids`, for example `{"vm":{"gpu_device_ids":["0000:2d:00.0"]}}`; +the current VM implementation accepts at most one entry. + For Kubernetes, `pod.runtime_class_name` maps to PodSpec `runtimeClassName`. It overrides the gateway's configured default runtime class for that sandbox, while a typed `SandboxTemplate.runtime_class_name` value from the API still diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 6ca24cccd..1a54d0a06 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -74,6 +74,17 @@ For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. +Exact GPU device selection is driver-specific and requires `--gpu`. For Docker +or Podman, pass CDI IDs through `cdi_devices`. The top-level key must match the +active driver; replace `docker` with `podman` when using Podman: + +```shell +openshell sandbox create \ + --gpu \ + --driver-config-json '{"docker":{"cdi_devices":["nvidia.com/gpu=0"]}}' \ + -- claude +``` + ### Custom Containers Use `--from` to create a sandbox from the base image, another pre-built sandbox name, a local directory, or a container image: diff --git a/e2e/rust/tests/gpu_device_selection.rs b/e2e/rust/tests/gpu_device_selection.rs index 5f5314b9c..77605fe1a 100644 --- a/e2e/rust/tests/gpu_device_selection.rs +++ b/e2e/rust/tests/gpu_device_selection.rs @@ -11,7 +11,7 @@ use std::process::Stdio; use std::time::Duration; use openshell_e2e::harness::binary::openshell_cmd; -use openshell_e2e::harness::container::ContainerEngine; +use openshell_e2e::harness::container::{ContainerEngine, e2e_driver}; use openshell_e2e::harness::output::strip_ansi; use openshell_e2e::harness::sandbox::SandboxGuard; use serde_json::{Map, Value}; @@ -130,6 +130,22 @@ fn has_cdi_gpu_device(device_id: &str) -> bool { .any(|discovered| discovered == device_id) } +fn e2e_driver_config_key() -> &'static str { + match e2e_driver().as_deref() { + Some("podman") => "podman", + _ => "docker", + } +} + +fn cdi_devices_driver_config_json(device_ids: &[&str]) -> String { + serde_json::json!({ + e2e_driver_config_key(): { + "cdi_devices": device_ids + } + }) + .to_string() +} + fn runtime_gpu_lines(gpu_device: &str) -> Vec { let engine = ContainerEngine::from_env(); let image = gpu_probe_image(); @@ -174,9 +190,11 @@ fn runtime_gpu_lines(gpu_device: &str) -> Vec { async fn sandbox_gpu_lines(gpu_device: Option<&str>) -> Vec { let mut args = vec!["--gpu"]; + let driver_config_json; if let Some(gpu_device) = gpu_device { - args.push("--gpu-device"); - args.push(gpu_device); + driver_config_json = cdi_devices_driver_config_json(&[gpu_device]); + args.push("--driver-config-json"); + args.push(driver_config_json.as_str()); } args.extend(["--", "sh", "-lc", "nvidia-smi -L"]); @@ -271,16 +289,17 @@ async fn gpu_all_device_request_matches_plain_all_gpu_container() { #[tokio::test] async fn gpu_invalid_device_request_fails() { - let output = sandbox_create_output(&[ + let driver_config_json = cdi_devices_driver_config_json(&["nvidia.com/gpu=invalid"]); + let args = vec![ "--gpu", - "--gpu-device", - "nvidia.com/gpu=invalid", + "--driver-config-json", + driver_config_json.as_str(), "--", "sh", "-lc", "nvidia-smi -L", - ]) - .await; + ]; + let output = sandbox_create_output(&args).await; let output_lower = output.to_ascii_lowercase(); assert!( diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 190a04e87..dbcb9e818 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -85,10 +85,8 @@ message DriverSandboxSpec { DriverSandboxTemplate template = 6; // Request NVIDIA GPU resources for this sandbox. bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + reserved 10; + reserved "gpu_device"; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; diff --git a/proto/openshell.proto b/proto/openshell.proto index c2755aaf7..3b5d81fcd 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -319,10 +319,8 @@ message SandboxSpec { repeated string providers = 8; // Request NVIDIA GPU resources for this sandbox. bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + reserved 10; + reserved "gpu_device"; // Field 11 was `proposal_approval_mode`. The approval mode is now a // runtime setting (gateway or sandbox scope) read via UpdateConfig / // GetSandboxConfig, so it can be flipped on a running sandbox and