Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 155 additions & 6 deletions crates/openshell-vm/scripts/openshell-vm-exec-agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
# SPDX-License-Identifier: Apache-2.0

import base64
import fcntl
import json
import os
import pty
import socket
import struct
import subprocess
import sys
import termios
import threading


Expand Down Expand Up @@ -42,6 +46,11 @@ def validate_env(env_items):
return env


def set_winsize(fd, cols, rows):
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)


def stream_reader(pipe, frame_type, sock_file, lock):
try:
while True:
Expand Down Expand Up @@ -79,6 +88,8 @@ def stdin_writer(proc, sock_file, sock, lock):
proc.stdin.flush()
elif kind == "stdin_close":
break
elif kind == "resize":
pass
else:
send_frame(
sock_file,
Expand All @@ -96,14 +107,10 @@ def stdin_writer(proc, sock_file, sock, lock):
pass


def handle_client(conn):
sock_file = conn.makefile("rwb", buffering=0)
def handle_client_pipe(conn, request, sock_file):
"""Handle a client connection using pipes (non-TTY mode)."""
lock = threading.Lock()
try:
request = recv_line(sock_file)
if request is None:
return

argv = request.get("argv") or ["sh"]
cwd = request.get("cwd")
env = os.environ.copy()
Expand Down Expand Up @@ -153,6 +160,148 @@ def handle_client(conn):
conn.close()


def handle_client_tty(conn, request, sock_file):
"""Handle a client connection with PTY allocation."""
lock = threading.Lock()
master_fd = -1
try:
argv = request.get("argv") or ["sh"]
cwd = request.get("cwd")
env = os.environ.copy()
env.update(validate_env(request.get("env") or []))
env.setdefault("TERM", "xterm-256color")

master_fd, slave_fd = pty.openpty()

# Consume any resize frame sent right after the ExecRequest.
# The host sends it before starting the stdin pump, so it
# should arrive quickly. Use a short socket timeout.
conn.settimeout(0.5)
try:
pending = sock_file.readline()
if pending:
frame = json.loads(pending.decode("utf-8"))
if frame.get("type") == "resize":
set_winsize(
slave_fd,
frame.get("cols", 80),
frame.get("rows", 24),
)
except (socket.timeout, ValueError, OSError):
pass
finally:
conn.settimeout(None)

proc = subprocess.Popen(
argv,
cwd=cwd or "/",
env=env,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
preexec_fn=os.setsid,
)
os.close(slave_fd)

def pty_reader():
try:
while True:
try:
chunk = os.read(master_fd, 8192)
except OSError:
break
if not chunk:
break
send_frame(
sock_file,
lock,
{
"type": "stdout",
"data": base64.b64encode(chunk).decode("ascii"),
},
)
except Exception:
pass

def pty_stdin_writer():
try:
while True:
frame = recv_line(sock_file)
if frame is None:
break
kind = frame.get("type")
if kind == "stdin":
payload = base64.b64decode(frame.get("data", ""))
try:
os.write(master_fd, payload)
except OSError:
break
elif kind == "resize":
try:
set_winsize(
master_fd,
frame.get("cols", 80),
frame.get("rows", 24),
)
except OSError:
pass
elif kind == "stdin_close":
break
else:
send_frame(
sock_file,
lock,
{"type": "error", "message": f"unknown frame type: {kind}"},
)
break
except (BrokenPipeError, OSError):
pass

reader_thread = threading.Thread(target=pty_reader, daemon=True)
stdin_thread = threading.Thread(target=pty_stdin_writer, daemon=True)
reader_thread.start()
stdin_thread.start()

code = proc.wait()
reader_thread.join(timeout=2)
send_frame(sock_file, lock, {"type": "exit", "code": code})
except Exception as exc:
try:
send_frame(sock_file, lock, {"type": "error", "message": str(exc)})
except Exception:
pass
finally:
if master_fd >= 0:
try:
os.close(master_fd)
except OSError:
pass
try:
sock_file.close()
except Exception:
pass
conn.close()


def handle_client(conn):
sock_file = conn.makefile("rwb", buffering=0)
try:
request = recv_line(sock_file)
if request is None:
sock_file.close()
conn.close()
return
except Exception:
sock_file.close()
conn.close()
return

if request.get("tty"):
handle_client_tty(conn, request, sock_file)
else:
handle_client_pipe(conn, request, sock_file)


def main():
if not hasattr(socket, "AF_VSOCK"):
print("AF_VSOCK is not available", file=sys.stderr)
Expand Down
Loading
Loading