From 52a1df93c3fe831d4b78e085df0127f9bdc0a6b7 Mon Sep 17 00:00:00 2001
From: Florent Benoit
Date: Thu, 23 Apr 2026 17:05:19 +0200
Subject: [PATCH] feat(openshell-vm): allow to have tty with exec

add support on both side (rust and python)

fixes https://github.com/NVIDIA/OpenShell/issues/936

Signed-off-by: Florent Benoit
---
 .../scripts/openshell-vm-exec-agent.py        | 200 ++++++++-
 crates/openshell-vm/src/exec.rs               | 431 +++++++++++++++++-
 2 files changed, 623 insertions(+), 8 deletions(-)

diff --git a/crates/openshell-vm/scripts/openshell-vm-exec-agent.py b/crates/openshell-vm/scripts/openshell-vm-exec-agent.py
index d7ffd81df..f2b384cf9 100644
--- a/crates/openshell-vm/scripts/openshell-vm-exec-agent.py
+++ b/crates/openshell-vm/scripts/openshell-vm-exec-agent.py
@@ -3,11 +3,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import base64
+import fcntl
 import json
 import os
+import pty
+import select
 import socket
+import struct
 import subprocess
 import sys
+import termios
 import threading
 
 
@@ -42,6 +47,12 @@ def validate_env(env_items):
     return env
 
 
+def set_winsize(fd, cols, rows):
+    """Set the window size of the terminal behind ``fd`` via TIOCSWINSZ."""
+    winsize = struct.pack("HHHH", rows, cols, 0, 0)
+    fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
+
+
 def stream_reader(pipe, frame_type, sock_file, lock):
     try:
         while True:
@@ -79,6 +90,8 @@ def stdin_writer(proc, sock_file, sock, lock):
                 proc.stdin.flush()
             elif kind == "stdin_close":
                 break
+            elif kind == "resize":
+                pass  # pipe mode has no terminal to resize
             else:
                 send_frame(
                     sock_file,
@@ -96,14 +109,10 @@ def stdin_writer(proc, sock_file, sock, lock):
             pass
 
 
-def handle_client(conn):
-    sock_file = conn.makefile("rwb", buffering=0)
+def handle_client_pipe(conn, request, sock_file):
+    """Handle a client connection using pipes (non-TTY mode)."""
     lock = threading.Lock()
     try:
-        request = recv_line(sock_file)
-        if request is None:
-            return
-
         argv = request.get("argv") or ["sh"]
         cwd = request.get("cwd")
         env = os.environ.copy()
@@ -153,6 +162,185 @@ def handle_client(conn):
         conn.close()
 
 
+def claim_controlling_tty():
+    """Child-side setup: new session with stdin as its controlling TTY."""
+    os.setsid()
+    try:
+        fcntl.ioctl(0, termios.TIOCSCTTY, 0)
+    except OSError:
+        # Best effort: the command still runs, only job control is lost.
+        pass
+
+
+def read_initial_frame(conn, sock_file, timeout=0.5):
+    """Return the first client frame if it arrives within ``timeout`` seconds.
+
+    Readiness is checked with select() rather than a socket timeout: once a
+    read on the unbuffered file from ``conn.makefile()`` times out, every
+    later read on that file raises, which would kill stdin forwarding.
+    """
+    try:
+        readable, _, _ = select.select([conn], [], [], timeout)
+        if not readable:
+            return None
+        return recv_line(sock_file)
+    except (ValueError, OSError):
+        return None
+
+
+def handle_client_tty(conn, request, sock_file):
+    """Handle a client connection with PTY allocation.
+
+    Runs the requested command on a new pseudo-terminal, streams PTY output
+    as ``stdout`` frames, applies ``stdin``/``resize`` frames from the client
+    and ends with an ``exit`` frame (or ``error`` on failure).
+    """
+    lock = threading.Lock()
+    master_fd = -1
+    slave_fd = -1
+    try:
+        argv = request.get("argv") or ["sh"]
+        cwd = request.get("cwd")
+        env = os.environ.copy()
+        env.update(validate_env(request.get("env") or []))
+        env.setdefault("TERM", "xterm-256color")
+
+        master_fd, slave_fd = pty.openpty()
+
+        # The host sends its window size right after the ExecRequest, so
+        # apply it before the child starts. Any other frame that arrives
+        # first is kept and replayed by pty_stdin_writer, not dropped.
+        first_frame = read_initial_frame(conn, sock_file)
+        if first_frame is not None and first_frame.get("type") == "resize":
+            try:
+                set_winsize(
+                    slave_fd,
+                    first_frame.get("cols", 80),
+                    first_frame.get("rows", 24),
+                )
+            except (OSError, struct.error):
+                pass
+            first_frame = None
+
+        proc = subprocess.Popen(
+            argv,
+            cwd=cwd or "/",
+            env=env,
+            stdin=slave_fd,
+            stdout=slave_fd,
+            stderr=slave_fd,
+            preexec_fn=claim_controlling_tty,
+        )
+        # The child holds its own copies; keeping ours open would keep the
+        # master from noticing when the child side goes away.
+        os.close(slave_fd)
+        slave_fd = -1
+
+        def pty_reader():
+            try:
+                while True:
+                    try:
+                        chunk = os.read(master_fd, 8192)
+                    except OSError:
+                        break
+                    if not chunk:
+                        break
+                    send_frame(
+                        sock_file,
+                        lock,
+                        {
+                            "type": "stdout",
+                            "data": base64.b64encode(chunk).decode("ascii"),
+                        },
+                    )
+            except Exception:
+                pass
+
+        def apply_client_frame(frame):
+            """Apply one client frame; return False to stop reading."""
+            kind = frame.get("type")
+            if kind == "stdin":
+                payload = base64.b64decode(frame.get("data", ""))
+                try:
+                    os.write(master_fd, payload)
+                except OSError:
+                    return False
+            elif kind == "resize":
+                try:
+                    set_winsize(
+                        master_fd,
+                        frame.get("cols", 80),
+                        frame.get("rows", 24),
+                    )
+                except (OSError, struct.error):
+                    pass
+            elif kind == "stdin_close":
+                return False
+            else:
+                send_frame(
+                    sock_file,
+                    lock,
+                    {"type": "error", "message": f"unknown frame type: {kind}"},
+                )
+                return False
+            return True
+
+        def pty_stdin_writer():
+            try:
+                if first_frame is not None and not apply_client_frame(first_frame):
+                    return
+                while True:
+                    frame = recv_line(sock_file)
+                    if frame is None or not apply_client_frame(frame):
+                        break
+            except (BrokenPipeError, OSError):
+                pass
+
+        reader_thread = threading.Thread(target=pty_reader, daemon=True)
+        stdin_thread = threading.Thread(target=pty_stdin_writer, daemon=True)
+        reader_thread.start()
+        stdin_thread.start()
+
+        code = proc.wait()
+        reader_thread.join(timeout=2)
+        send_frame(sock_file, lock, {"type": "exit", "code": code})
+    except Exception as exc:
+        try:
+            send_frame(sock_file, lock, {"type": "error", "message": str(exc)})
+        except Exception:
+            pass
+    finally:
+        for fd in (slave_fd, master_fd):
+            if fd >= 0:
+                try:
+                    os.close(fd)
+                except OSError:
+                    pass
+        try:
+            sock_file.close()
+        except Exception:
+            pass
+        conn.close()
+
+
+def handle_client(conn):
+    """Read the ExecRequest line and dispatch to the TTY or pipe handler."""
+    sock_file = conn.makefile("rwb", buffering=0)
+    try:
+        request = recv_line(sock_file)
+    except Exception:
+        request = None
+    if request is None:
+        sock_file.close()
+        conn.close()
+        return
+
+    if request.get("tty"):
+        handle_client_tty(conn, request, sock_file)
+    else:
+        handle_client_pipe(conn, request, sock_file)
+
+
 def main():
     if not hasattr(socket, "AF_VSOCK"):
         print("AF_VSOCK is not available", file=sys.stderr)
diff --git a/crates/openshell-vm/src/exec.rs b/crates/openshell-vm/src/exec.rs
index 6195556e1..f3198a6be 100644
--- a/crates/openshell-vm/src/exec.rs
+++ b/crates/openshell-vm/src/exec.rs
@@ -3,12 +3,13 @@
 
 use std::fs::{self, File};
 use std::io::{BufRead, BufReader, Read, Write};
 use std::os::unix::net::UnixStream;
 use std::path::{Path, PathBuf};
 use std::thread;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use base64::Engine as _;
+use nix::sys::termios::{self, SetArg, Termios};
 use serde::{Deserialize, Serialize};
 
 use crate::VmError;
@@ -87,6 +88,7 @@ struct ExecRequest {
 enum ClientFrame {
     Stdin { data: String },
     StdinClose,
+    Resize { cols: u16, rows: u16 },
 }
 
 #[derive(Debug, Deserialize)]
@@ -98,6 +100,53 @@ enum ServerFrame {
     Error { message: String },
 }
 
+/// Puts the host terminal (stdin) into raw mode and restores the saved
+/// settings when dropped.
+struct RawModeGuard {
+    original: Termios,
+}
+
+impl RawModeGuard {
+    /// Switches stdin to raw mode, remembering the previous settings.
+    ///
+    /// Fails with [`VmError::Exec`] if the attributes cannot be read or
+    /// written (`tcgetattr` / `tcsetattr` error).
+    fn enter() -> Result<Self, VmError> {
+        let stdin = std::io::stdin();
+        let original =
+            termios::tcgetattr(&stdin).map_err(|e| VmError::Exec(format!("tcgetattr: {e}")))?;
+        let mut raw = original.clone();
+        termios::cfmakeraw(&mut raw);
+        termios::tcsetattr(&stdin, SetArg::TCSANOW, &raw)
+            .map_err(|e| VmError::Exec(format!("tcsetattr: {e}")))?;
+        Ok(Self { original })
+    }
+}
+
+impl Drop for RawModeGuard {
+    fn drop(&mut self) {
+        // Best effort: nothing useful can be done if the restore fails.
+        let _ = termios::tcsetattr(std::io::stdin(), SetArg::TCSANOW, &self.original);
+    }
+}
+
+/// Returns the `(cols, rows)` reported by `TIOCGWINSZ` on stdout, or `None`
+/// if the ioctl fails or reports a zero dimension.
+fn get_terminal_size() -> Option<(u16, u16)> {
+    let fd = std::os::unix::io::AsRawFd::as_raw_fd(&std::io::stdout());
+    // SAFETY: `winsize` is a plain C struct of integers, so all-zero bytes
+    // are a valid value.
+    let mut ws: libc::winsize = unsafe { std::mem::zeroed() };
+    // SAFETY: `ws` is a live, writable `winsize`, which is the argument type
+    // TIOCGWINSZ expects; the ioctl only writes into it.
+    let rc = unsafe { libc::ioctl(fd, libc::TIOCGWINSZ, &mut ws) };
+    if rc == 0 && ws.ws_col > 0 && ws.ws_row > 0 {
+        Some((ws.ws_col, ws.ws_row))
+    } else {
+        None
+    }
+}
+
 pub fn vm_exec_socket_path(rootfs: &Path) -> PathBuf {
     // Prefer XDG_RUNTIME_DIR (per-user, restricted permissions on Linux),
     // fall back to /tmp. Ownership/symlink validation happens in
@@ -495,9 +544,22 @@ pub fn exec_running_vm(options: VmExecOptions) -> Result {
     };
     send_json_line(&mut writer, &request)?;
 
+    let tty = options.tty;
+    // In TTY mode, report the initial window size before the stdin pump
+    // starts, then switch the local terminal to raw mode. The guard is bound
+    // to a named local so the terminal is restored only when its scope ends.
+    let _raw_guard = if tty {
+        if let Some((cols, rows)) = get_terminal_size() {
+            send_json_line(&mut writer, &ClientFrame::Resize { cols, rows })?;
+        }
+        Some(RawModeGuard::enter()?)
+    } else {
+        None
+    };
+
     let stdin_writer = writer;
     thread::spawn(move || {
-        let _ = pump_stdin(stdin_writer);
+        let _ = pump_stdin(stdin_writer, tty);
     });
 
     let mut reader = BufReader::new(&mut stream);
@@ -724,10 +786,18 @@ fn send_json_line(writer: &mut UnixStream, value: &T) -> Result<()
         .map_err(|e| VmError::Exec(format!("write VM exec request: {e}")))
 }
 
-fn pump_stdin(mut writer: UnixStream) -> Result<(), VmError> {
+/// Forwards local stdin to the VM as `Stdin` frames until stdin reaches EOF.
+///
+/// When `tty` is set, the terminal size is re-checked after every read and a
+/// `Resize` frame is sent whenever it changed.
+// NOTE(review): because the size is only polled after a stdin read returns,
+// a window resize reaches the guest on the next keystroke, not immediately.
+// Reacting to SIGWINCH would close that gap.
+fn pump_stdin(mut writer: UnixStream, tty: bool) -> Result<(), VmError> {
     let stdin = std::io::stdin();
     let mut stdin = stdin.lock();
     let mut buf = [0u8; 8192];
+    let mut last_size: Option<(u16, u16)> = None;
 
     loop {
         let read = stdin
@@ -736,6 +806,22 @@ fn pump_stdin(mut writer: UnixStream) -> Result<(), VmError> {
         if read == 0 {
             break;
         }
+
+        if tty {
+            if let Some(size) = get_terminal_size() {
+                if last_size != Some(size) {
+                    last_size = Some(size);
+                    let _ = send_json_line(
+                        &mut writer,
+                        &ClientFrame::Resize {
+                            cols: size.0,
+                            rows: size.1,
+                        },
+                    );
+                }
+            }
+        }
+
         let frame = ClientFrame::Stdin {
             data: base64::engine::general_purpose::STANDARD.encode(&buf[..read]),
         };
@@ -765,3 +851,344 @@ fn now_ms() -> Result {
         .map_err(|e| VmError::RuntimeState(format!("read system clock: {e}")))?;
     Ok(duration.as_millis())
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // ── ExecRequest serialization ────────────────────────────────────
+
+    #[test]
+    fn exec_request_serializes_with_tty() {
+        let req = ExecRequest {
+            argv: vec!["sh".into()],
+            env: vec!["TERM=xterm".into()],
+            cwd: None,
+            tty: true,
+        };
+        let json: serde_json::Value = serde_json::to_value(&req).unwrap();
+        assert_eq!(json["argv"], serde_json::json!(["sh"]));
+        assert_eq!(json["tty"], true);
+        assert_eq!(json["cwd"], serde_json::Value::Null);
+    }
+
+    #[test]
+    fn exec_request_serializes_without_tty() {
+        let req = ExecRequest {
+            argv: vec!["echo".into(), "hello".into()],
+            env: vec![],
+            cwd: Some("/tmp".into()),
+            tty: false,
+        };
+        let json: serde_json::Value = serde_json::to_value(&req).unwrap();
+        assert_eq!(json["tty"], false);
+        assert_eq!(json["cwd"], "/tmp");
+    }
+
+    // ── ClientFrame serialization ────────────────────────────────────
+
+    #[test]
+    fn client_frame_stdin_serializes() {
+        let frame = ClientFrame::Stdin {
+            data: "aGVsbG8=".into(),
+        };
+        let json: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(json["type"], "stdin");
+        assert_eq!(json["data"], "aGVsbG8=");
+    }
+
+    #[test]
+    fn client_frame_stdin_close_serializes() {
+        let frame = ClientFrame::StdinClose;
+        let json: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(json["type"], "stdin_close");
+    }
+
+    #[test]
+    fn client_frame_resize_serializes() {
+        let frame = ClientFrame::Resize {
+            cols: 120,
+            rows: 40,
+        };
+        let json: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(json["type"], "resize");
+        assert_eq!(json["cols"], 120);
+        assert_eq!(json["rows"], 40);
+    }
+
+    // ── ServerFrame deserialization ───────────────────────────────────
+
+    #[test]
+    fn server_frame_stdout_deserializes() {
+        let json = r#"{"type":"stdout","data":"aGVsbG8="}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Stdout { data } if data == "aGVsbG8="));
+    }
+
+    #[test]
+    fn server_frame_stderr_deserializes() {
+        let json = r#"{"type":"stderr","data":"ZXJy"}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Stderr { data } if data == "ZXJy"));
+    }
+
+    #[test]
+    fn server_frame_exit_deserializes() {
+        let json = r#"{"type":"exit","code":42}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Exit { code: 42 }));
+    }
+
+    #[test]
+    fn server_frame_error_deserializes() {
+        let json = r#"{"type":"error","message":"boom"}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Error { message } if message == "boom"));
+    }
+
+    #[test]
+    fn server_frame_unknown_type_fails() {
+        let json = r#"{"type":"unknown","data":"x"}"#;
+        assert!(serde_json::from_str::<ServerFrame>(json).is_err());
+    }
+
+    // ── ClientFrame ↔ ServerFrame round-trip compatibility ───────────
+    // Verify that what the Rust host serializes can be parsed by the
+    // Python agent (same JSON shape), and vice versa.
+
+    #[test]
+    fn resize_frame_has_expected_json_shape() {
+        let frame = ClientFrame::Resize { cols: 80, rows: 24 };
+        let s = serde_json::to_string(&frame).unwrap();
+        let v: serde_json::Value = serde_json::from_str(&s).unwrap();
+        assert_eq!(v["type"].as_str().unwrap(), "resize");
+        assert!(v["cols"].is_u64());
+        assert!(v["rows"].is_u64());
+    }
+
+    // ── validate_env_vars ────────────────────────────────────────────
+
+    #[test]
+    fn validate_env_vars_accepts_valid() {
+        let items = vec![
+            "HOME=/root".to_string(),
+            "PATH=/usr/bin".to_string(),
+            "_UNDERSCORE=1".to_string(),
+            "A1B2=val".to_string(),
+        ];
+        assert!(validate_env_vars(&items).is_ok());
+    }
+
+    #[test]
+    fn validate_env_vars_rejects_missing_equals() {
+        let items = vec!["NOEQUALS".to_string()];
+        assert!(validate_env_vars(&items).is_err());
+    }
+
+    #[test]
+    fn validate_env_vars_rejects_empty_key() {
+        let items = vec!["=value".to_string()];
+        assert!(validate_env_vars(&items).is_err());
+    }
+
+    #[test]
+    fn validate_env_vars_rejects_leading_digit() {
+        let items = vec!["1BAD=val".to_string()];
+        assert!(validate_env_vars(&items).is_err());
+    }
+
+    #[test]
+    fn validate_env_vars_rejects_special_chars() {
+        let items = vec!["BAD-KEY=val".to_string()];
+        assert!(validate_env_vars(&items).is_err());
+    }
+
+    // ── decode_payload ───────────────────────────────────────────────
+
+    #[test]
+    fn decode_payload_valid_base64() {
+        let decoded = decode_payload("aGVsbG8=").unwrap();
+        assert_eq!(decoded, b"hello");
+    }
+
+    #[test]
+    fn decode_payload_empty() {
+        let decoded = decode_payload("").unwrap();
+        assert!(decoded.is_empty());
+    }
+
+    #[test]
+    fn decode_payload_invalid_base64() {
+        assert!(decode_payload("!!!not-base64!!!").is_err());
+    }
+
+    // ── Resize frame edge cases ──────────────────────────────────────
+
+    #[test]
+    fn resize_frame_max_dimensions() {
+        let frame = ClientFrame::Resize {
+            cols: u16::MAX,
+            rows: u16::MAX,
+        };
+        let json: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(json["cols"], u64::from(u16::MAX));
+        assert_eq!(json["rows"], u64::from(u16::MAX));
+    }
+
+    #[test]
+    fn resize_frame_minimum_dimensions() {
+        let frame = ClientFrame::Resize { cols: 1, rows: 1 };
+        let json: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(json["cols"], 1);
+        assert_eq!(json["rows"], 1);
+    }
+
+    // ── Wire format: newline-delimited JSON ──────────────────────────
+    // The protocol sends one JSON object per line. Verify that
+    // serialized frames produce valid single-line JSON that the
+    // Python agent can split on '\n' and json.loads().
+
+    #[test]
+    fn client_frames_serialize_to_single_line_json() {
+        let frames: Vec<ClientFrame> = vec![
+            ClientFrame::Stdin {
+                data: "dGVzdA==".into(),
+            },
+            ClientFrame::StdinClose,
+            ClientFrame::Resize { cols: 80, rows: 24 },
+        ];
+        for frame in &frames {
+            let s = serde_json::to_string(frame).unwrap();
+            assert!(!s.contains('\n'), "frame should be single-line: {s}");
+            let _: serde_json::Value = serde_json::from_str(&s).unwrap();
+        }
+    }
+
+    #[test]
+    fn exec_request_serializes_to_single_line_json() {
+        let req = ExecRequest {
+            argv: vec!["bash".into(), "-c".into(), "echo 'hello world'".into()],
+            env: vec!["HOME=/root".into(), "TERM=xterm-256color".into()],
+            cwd: Some("/home/user".into()),
+            tty: true,
+        };
+        let s = serde_json::to_string(&req).unwrap();
+        assert!(!s.contains('\n'));
+        let _: serde_json::Value = serde_json::from_str(&s).unwrap();
+    }
+
+    // ── Stdin data encode → decode round-trip ────────────────────────
+    // Mirrors the flow: host encodes payload as base64 in a Stdin
+    // frame, guest decodes with decode_payload().
+
+    #[test]
+    fn stdin_payload_round_trip() {
+        let original = b"echo hello\n";
+        let encoded = base64::engine::general_purpose::STANDARD.encode(original);
+        let frame = ClientFrame::Stdin {
+            data: encoded.clone(),
+        };
+        let json = serde_json::to_string(&frame).unwrap();
+        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
+        let decoded = decode_payload(parsed["data"].as_str().unwrap()).unwrap();
+        assert_eq!(decoded, original);
+    }
+
+    #[test]
+    fn stdin_payload_round_trip_binary() {
+        let original: Vec<u8> = (0..=255).collect();
+        let encoded = base64::engine::general_purpose::STANDARD.encode(&original);
+        let decoded = decode_payload(&encoded).unwrap();
+        assert_eq!(decoded, original);
+    }
+
+    // ── Python agent compatibility ───────────────────────────────────
+    // The Python agent parses frames with json.loads() and dispatches
+    // on frame["type"]. These tests verify the exact field names and
+    // values match what the Python code expects.
+
+    #[test]
+    fn exec_request_tty_field_matches_python_dispatch() {
+        // Python: request.get("tty") — must be a JSON boolean
+        let req = ExecRequest {
+            argv: vec!["sh".into()],
+            env: vec![],
+            cwd: None,
+            tty: true,
+        };
+        let v: serde_json::Value = serde_json::to_value(&req).unwrap();
+        assert!(v["tty"].is_boolean());
+        assert!(v["tty"].as_bool().unwrap());
+
+        let req_no_tty = ExecRequest {
+            argv: vec!["echo".into()],
+            env: vec![],
+            cwd: None,
+            tty: false,
+        };
+        let v: serde_json::Value = serde_json::to_value(&req_no_tty).unwrap();
+        assert!(!v["tty"].as_bool().unwrap());
+    }
+
+    #[test]
+    fn resize_type_tag_is_snake_case() {
+        // Python: kind == "resize" — must be lowercase snake_case
+        let frame = ClientFrame::Resize { cols: 80, rows: 24 };
+        let v: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(v["type"].as_str().unwrap(), "resize");
+    }
+
+    #[test]
+    fn stdin_close_type_tag_is_snake_case() {
+        // Python: kind == "stdin_close"
+        let frame = ClientFrame::StdinClose;
+        let v: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert_eq!(v["type"].as_str().unwrap(), "stdin_close");
+    }
+
+    #[test]
+    fn resize_fields_are_integers_not_strings() {
+        // Python: frame.get("cols", 80) — expects int, not string
+        let frame = ClientFrame::Resize {
+            cols: 200,
+            rows: 50,
+        };
+        let v: serde_json::Value = serde_json::to_value(&frame).unwrap();
+        assert!(v["cols"].is_u64());
+        assert!(v["rows"].is_u64());
+    }
+
+    // ── ServerFrame: Python agent output ─────────────────────────────
+    // These mirror the exact JSON the Python agent produces with
+    // json.dumps(frame, separators=(",", ":"))
+
+    #[test]
+    fn server_frame_parses_compact_json() {
+        // Python uses separators=(",", ":") — no spaces
+        let compact = r#"{"type":"stdout","data":"aGk="}"#;
+        let frame: ServerFrame = serde_json::from_str(compact).unwrap();
+        assert!(matches!(frame, ServerFrame::Stdout { data } if data == "aGk="));
+    }
+
+    #[test]
+    fn server_frame_exit_code_zero() {
+        let json = r#"{"type":"exit","code":0}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Exit { code: 0 }));
+    }
+
+    #[test]
+    fn server_frame_exit_code_negative() {
+        let json = r#"{"type":"exit","code":-1}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Exit { code: -1 }));
+    }
+
+    #[test]
+    fn server_frame_tolerates_extra_fields() {
+        // Future-proofing: agent may add fields we don't know about
+        let json = r#"{"type":"exit","code":0,"extra":"ignored"}"#;
+        let frame: ServerFrame = serde_json::from_str(json).unwrap();
+        assert!(matches!(frame, ServerFrame::Exit { code: 0 }));
+    }
+}