diff --git a/.gitignore b/.gitignore index 68905fbf5..c525d8208 100644 --- a/.gitignore +++ b/.gitignore @@ -44,11 +44,17 @@ playground.xcworkspace # # Xcode automatically generates this directory with a hierarchical structure of links to # temporary directories for debugging Swift packages. -.swiftpm/xcode +.swiftpm/ .build/ LocalPackages/*/Package.resolved +# Coverage profile output +*.profraw + +# Claude Code worktrees and per-session state +.claude/worktrees/ + # CocoaPods # # We recommend against adding the Pods directory to your .gitignore. However @@ -144,6 +150,3 @@ Libs/*.a Libs/.downloaded Libs/dylibs/ Libs/ios/ - -# Local refactor scratchpad (per chore: untrack docs/refactor scratchpad) -docs/refactor/ diff --git a/CHANGELOG.md b/CHANGELOG.md index a1e57dcdf..db2c0ccb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,19 +9,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- MCP: support for protocol versions `2025-06-18` and `2025-11-25` in addition to `2025-03-26`. Clients on the latest spec no longer downgrade. The server advertises the latest version it supports (`2025-11-25`) and falls back when a client requests an unknown version. +- MCP: structured tool output (`structuredContent`) on every tool. The serialized JSON still appears in `content[].text` for backward compatibility, while 2025-11-25 clients can read the parsed object directly. +- MCP: tool annotations (`title`, `readOnlyHint`, `destructiveHint`, `idempotentHint`, `openWorldHint`) on every tool, plus `serverInfo.title` in `initialize` responses. Read tools advertise `readOnlyHint=true`; `confirm_destructive_operation` advertises `destructiveHint=true`. +- MCP: `completions` capability advertised in `initialize` (the `completion/complete` handler was already wired). +- MCP: streaming progress notifications. Long-running tool calls (e.g. `execute_query`) now emit `notifications/progress` events to clients that pass a `_meta.progressToken` in their request. +- MCP: pairing redirect carries an explicit `error=denied` parameter when the user clicks Deny so extensions can show a clear error instead of hanging. +- MCP: re-pairing the same client name automatically revokes the previous token instead of leaving it active. - Oracle 10G password verifier authentication. Accounts whose `password_versions` includes a 10G hash now connect successfully, matching DBeaver/JDBC/sqlplus behavior. The 10G hash is documented as legacy; rotating to a modern verifier is still recommended (#483) - Oracle Test Connection now opens a focused diagnostic sheet for auth failures with copy-able diagnostic info, suggested actions, and a link to file an issue - Oracle connection negotiation now matches python-oracledb's 23ai compile-capability advertisement, including TTC4 explicit boundary, TTC5 token/pipelining/sessionless flags, OCI3 sync, dequeue selectors, and sparse vector features ### Removed -- Keychain: the legacy-keychain migration (`migrateFromLegacyKeychainIfNeeded`) and the password-sync-state migration (`migratePasswordSyncState`). The first violated Apple's Data Protection keychain contract on sandboxed macOS apps and corrupted user credentials; the second toggled `kSecAttrSynchronizable` at runtime, which Apple does not document as safe. The Sync Passwords settings toggle now applies to new saves only — existing keychain items keep their original sync state, matching Apple's documented behavior. Users with stale items in the legacy keychain can clean them via Keychain Access; the running app no longer touches them. +- Keychain: the legacy-keychain migration (`migrateFromLegacyKeychainIfNeeded`) and the password-sync-state migration (`migratePasswordSyncState`). The first violated Apple's Data Protection keychain contract on sandboxed macOS apps and corrupted user credentials; the second toggled `kSecAttrSynchronizable` at runtime, which Apple does not document as safe. The Sync Passwords settings toggle now applies to new saves only, existing keychain items keep their original sync state, matching Apple's documented behavior. Users with stale items in the legacy keychain can clean them via Keychain Access; the running app no longer touches them. ### Changed +- MCP: idle session timeout raised from 5 to 15 minutes. +- MCP: complete internal rewrite of the server, stdio bridge, and protocol dispatcher for spec compliance. Public API of `MCPServerManager` and the on-disk handshake format are unchanged; clients do not need to re-pair. - Internal: introduce `TabSession` as the foundation type for the editor tab/window subsystem rewrite. Currently a parallel structure mirroring `QueryTab`; subsequent PRs migrate state ownership and lifecycle hooks per `docs/architecture/tab-subsystem-rewrite.md`. No user-visible behavior change in this PR. - Internal: row data and load epoch now live on `TabSession`. `TabSessionRegistry` exposes the row-access methods directly (`tableRows(for:)`, `setTableRows(_:for:)`, `evict(for:)`, etc.); the intermediate `TableRowsStore` facade is gone. All consumers (coordinator, extensions, views, command actions) now read row data from the registry. No user-visible behavior change. -- Internal: hidden-column state moves from the per-window `ColumnVisibilityManager` into each tab's `columnLayout.hiddenColumns`. The shared manager is removed; `MainContentCoordinator` exposes `hideColumn`, `showColumn`, `toggleColumnVisibility`, `showAllColumns`, `hideAllColumns`, and `pruneHiddenColumns` that mutate the active tab directly. Per-table UserDefaults persistence moves into a small `ColumnVisibilityPersistence` service. Tab-switch save/restore swap is gone — each tab is its own source of truth. No user-visible behavior change. +- Internal: hidden-column state moves from the per-window `ColumnVisibilityManager` into each tab's `columnLayout.hiddenColumns`. The shared manager is removed; `MainContentCoordinator` exposes `hideColumn`, `showColumn`, `toggleColumnVisibility`, `showAllColumns`, `hideAllColumns`, and `pruneHiddenColumns` that mutate the active tab directly. Per-table UserDefaults persistence moves into a small `ColumnVisibilityPersistence` service. Tab-switch save/restore swap is gone, each tab is its own source of truth. No user-visible behavior change. - Internal: filter state collapses from three places (the per-window `FilterStateManager`, the `TabFilterState` snapshot on `QueryTab`, and the per-table file-based restore) to a single source: `tab.filterState`. The shared manager is removed; `MainContentCoordinator` now exposes the full filter API (`addFilter`, `applyAllFilters`, `clearFilterState`, `toggleFilterPanel`, `setFKFilter`, `saveLastFilters(for:)`, `restoreLastFilters(for:)`, `saveFilterPreset`, `loadFilterPreset`, `generateFilterPreviewSQL`, etc.) that mutates the active tab. The file-based "restore last filters" persistence in `FilterSettingsStorage` is unchanged. `FilterPanelView`, `MainStatusBarView`, `MainContentCommandActions`, `MainContentView`, and `MainEditorContentView` read filter state directly off the active tab. No user-visible behavior change. - Internal: extract `QueryExecutor` service from `MainContentCoordinator`. Query data fetch, parallel schema fetch, schema parsing, parameter detection, row-cap policy, and DDL detection now live in `TablePro/Core/Services/Query/QueryExecutor.swift`. SQL parsing helpers (`extractTableName`, `stripTrailingOrderBy`, `parseSQLiteCheckConstraintValues`) move into `QuerySqlParser`. Coordinator methods become thin wrappers; behavior unchanged. No user-visible behavior change. - Security: non-syncing keychain items now use `kSecAttrAccessibleAfterFirstUnlockThisDeviceOnly`. This keeps local-only secrets out of unencrypted device backups (the pairing Apple recommends for local secrets). Syncing items still use `kSecAttrAccessibleAfterFirstUnlock` because iCloud Keychain requires it. Existing items keep their accessibility class until you save them again. @@ -30,6 +39,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- MCP: GET `/mcp` now opens a real SSE notification stream. Previously the GET path was routed through the request dispatcher, which had no handler for it, so the connection was closed immediately and `notifications/progress` events were dropped. +- MCP: concurrent tool calls no longer serialize at the dispatcher loop. Each exchange is dispatched in its own child task while session-state guards still serialize per-session work. +- MCP: server validates the `protocolVersion` requested in `initialize` against a supported set and rejects unknown versions with `-32600 invalid_request` instead of silently echoing back whatever the client sent. +- MCP: server validates `MCP-Protocol-Version` on follow-up requests against the negotiated version on the session. +- MCP: 429 responses now include a real `Retry-After` header derived from the rate-limiter lockout time. The audit log records the same value. +- MCP: token revocation cancels every in-flight request issued by that token and terminates its sessions. +- MCP: CORS reflects the request `Origin` against an allowlist (`localhost`, `127.0.0.1`, `claude.ai`, `app.cursor.com`) instead of unconditionally returning `Access-Control-Allow-Origin: http://localhost`. Requests without an `Origin` header (native clients) get no CORS headers. +- MCP: duplicate `initialize` on the same session now returns `invalid_request` instead of silently overwriting `clientInfo`. +- MCP: `xcodebuild test` no longer leaves an orphan `TablePro.app` running. The app delegate skips its normal startup when launched under XCTest. +- MCP: server start removes a stale handshake file written by a crashed previous PID before writing a fresh one. +- MCP: settings activity log refreshes automatically when new audit entries are written. +- MCP: stale `Mcp-Session-Id` after idle timeout now produces a JSON-RPC `-32001 "Session not found"` envelope with HTTP 404, matching the spec and letting clients re-initialize cleanly. Previously the bridge forwarded a plain `{"error":"Session not found"}` body that Claude Desktop's parser rejected, hanging the request until a 4-minute client-side timeout fired. +- MCP: stdio bridge no longer exits silently when stdin is briefly empty (was reading via `availableData`, which can't tell EOF from "no bytes right now"). Now uses `FileHandle.bytes` AsyncBytes. +- MCP: SSE responses now stream incrementally instead of buffering the entire body before delivering events. +- MCP: localhost auth-DoS surface closed. The rate limiter now keys on `(client_address, principal_fingerprint)` so failed attempts from one bridge can't lock out another. +- MCP: in-app "Setup for Claude Desktop / Cursor" snippets now use the stdio command form pointing at the bundled `tablepro-mcp` binary. The previous `"url"` form was rejected by Claude Desktop entirely. - Saved connection passwords no longer disappear after quitting and relaunching the app. The legacy-keychain migration that ran on every launch was destructive on sandboxed macOS configurations: queries without `kSecUseDataProtectionKeychain` returned items that had been written *with* the flag, and the migration's "delete legacy entry" step then removed the only copy. Removed the legacy keychain migration entirely; `KeychainHelper` now exclusively reads and writes through the Data Protection keychain on every launch. - Tab switching: rapid Cmd+Number presses no longer leave a tail of tab transitions playing after the user releases the keys. The tab-selection setter (`NSWindowTabGroup.selectedWindow`) is now wrapped in `NSAnimationContext.runAnimationGroup` with `duration = 0`, so AppKit applies each switch synchronously without queuing a CAAnimation. Lazy-load also moved out of `windowDidBecomeKey` into `.task(id:)` view-appearance lifecycle per Apple's documentation. Note: extreme Cmd+Number bursts (e.g. holding the key for key-repeat) still incur per-switch AppKit window-focus overhead; this is platform-inherent to native NSWindow tabs and documented in `docs/architecture/tab-subsystem-rewrite.md` D2 - Oracle TIMESTAMP, TIMESTAMP WITH TIME ZONE, TIMESTAMP WITH LOCAL TIME ZONE, INTERVAL DAY TO SECOND, INTERVAL YEAR TO MONTH, DATE, RAW, and BLOB columns now render through typed decoders instead of garbled text. Tables containing INTERVAL YEAR TO MONTH or BFILE columns no longer crash the app on row fetch. Unknown column types display `` instead of crashing (#965) diff --git a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift index d5f56e1d0..4fd8d3a3c 100644 --- a/Plugins/TableProPluginKit/PluginDatabaseDriver.swift +++ b/Plugins/TableProPluginKit/PluginDatabaseDriver.swift @@ -279,6 +279,27 @@ public extension PluginDatabaseDriver { return "\"\(escaped)\"" } + func streamRows(query: String) -> AsyncThrowingStream { + AsyncThrowingStream { continuation in + Task { + do { + let result = try await self.execute(query: query) + let header = PluginStreamHeader( + columns: result.columns, + columnTypeNames: result.columnTypeNames + ) + continuation.yield(.header(header)) + if !result.rows.isEmpty { + continuation.yield(.rows(result.rows)) + } + continuation.finish() + } catch { + continuation.finish(throwing: error) + } + } + } + } + func escapeStringLiteral(_ value: String) -> String { var result = value result = result.replacingOccurrences(of: "'", with: "''") diff --git a/TablePro.xcodeproj/project.pbxproj b/TablePro.xcodeproj/project.pbxproj index af10b330e..bb254edf6 100644 --- a/TablePro.xcodeproj/project.pbxproj +++ b/TablePro.xcodeproj/project.pbxproj @@ -319,8 +319,26 @@ 5A32BC082F9D5FC900BAEB5F /* Exceptions for "TablePro" folder in "mcp-server" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; membershipExceptions = ( - CLI/main.swift, - CLI/MCPBridgeProxy.swift, + CLI/BridgeMain.swift, + CLI/BridgeProxy.swift, + CLI/Handshake.swift, + Core/MCP/Transport/MCPBridgeLogger.swift, + Core/MCP/Transport/MCPMessageTransport.swift, + Core/MCP/Transport/MCPProtocolError.swift, + Core/MCP/Transport/MCPStdioMessageTransport.swift, + Core/MCP/Transport/MCPStreamableHttpClientTransport.swift, + Core/MCP/Wire/HttpRequestHead.swift, + Core/MCP/Wire/HttpResponseHead.swift, + Core/MCP/Wire/JsonRpcCodec.swift, + Core/MCP/Wire/JsonRpcError.swift, + Core/MCP/Wire/JsonRpcErrorCode.swift, + Core/MCP/Wire/JsonRpcId.swift, + Core/MCP/Wire/JsonRpcMessage.swift, + Core/MCP/Wire/JsonRpcVersion.swift, + Core/MCP/Wire/JsonValue.swift, + Core/MCP/Wire/SseDecoder.swift, + Core/MCP/Wire/SseEncoder.swift, + Core/MCP/Wire/SseFrame.swift, ); target = 5A32BBFF2F9D5F1300BAEB5F /* mcp-server */; }; @@ -460,8 +478,9 @@ 5AF312BE2F36FF7500E86682 /* Exceptions for "TablePro" folder in "TablePro" target */ = { isa = PBXFileSystemSynchronizedBuildFileExceptionSet; membershipExceptions = ( - CLI/main.swift, - CLI/MCPBridgeProxy.swift, + CLI/BridgeMain.swift, + CLI/BridgeProxy.swift, + CLI/Handshake.swift, Info.plist, ); target = 5A1091C62EF17EDC0055EA7C /* TablePro */; diff --git a/TablePro/AppDelegate.swift b/TablePro/AppDelegate.swift index 370f9240e..5ca87cde4 100644 --- a/TablePro/AppDelegate.swift +++ b/TablePro/AppDelegate.swift @@ -31,6 +31,11 @@ class AppDelegate: NSObject, NSApplicationDelegate { // MARK: - Lifecycle func applicationDidFinishLaunching(_ notification: Notification) { + if ProcessInfo.processInfo.environment["XCTestConfigurationFilePath"] != nil { + Self.logger.info("Running under XCTest, skipping normal app startup") + return + } + let appearanceSettings = AppSettingsManager.shared.appearance ThemeEngine.shared.updateAppearanceAndTheme( mode: appearanceSettings.appearanceMode, diff --git a/TablePro/CLI/BridgeMain.swift b/TablePro/CLI/BridgeMain.swift new file mode 100644 index 000000000..7e727addd --- /dev/null +++ b/TablePro/CLI/BridgeMain.swift @@ -0,0 +1,58 @@ +import Foundation + +@main +struct TableProMcpBridge { + static func main() async { + let logger: any MCPBridgeLogger = MCPCompositeBridgeLogger([ + MCPOSBridgeLogger(category: "MCP.Bridge"), + MCPStderrBridgeLogger() + ]) + + let acquirer = MCPHandshakeAcquirer(logger: logger) + let handshake: MCPBridgeHandshake + do { + handshake = try await acquirer.acquire() + } catch { + logger.log(.error, "Handshake failed: \(error.localizedDescription)") + emitFatalJsonRpcError(message: "TablePro is not running. Launch the app and enable the MCP server.") + exit(1) + } + + guard let endpoint = handshake.endpoint() else { + logger.log(.error, "Handshake produced invalid endpoint") + emitFatalJsonRpcError(message: "Invalid MCP server endpoint") + exit(1) + } + + let upstream = MCPStreamableHttpClientTransport( + configuration: MCPStreamableHttpClientConfiguration( + endpoint: endpoint, + bearerToken: handshake.token, + tlsCertFingerprint: handshake.tlsCertFingerprint, + requestTimeout: .seconds(60), + serverInitiatedStream: false + ), + errorLogger: logger + ) + + let host = MCPStdioMessageTransport(errorLogger: logger) + + let proxy = BridgeProxy(host: host, upstream: upstream, logger: logger) + await proxy.run() + } + + private static func emitFatalJsonRpcError(message: String) { + let envelope = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse( + id: nil, + error: JsonRpcError( + code: JsonRpcErrorCode.serverError, + message: message, + data: nil + ) + ) + ) + guard let data = try? JsonRpcCodec.encodeLine(envelope) else { return } + FileHandle.standardOutput.write(data) + } +} diff --git a/TablePro/CLI/BridgeProxy.swift b/TablePro/CLI/BridgeProxy.swift new file mode 100644 index 000000000..1537aed30 --- /dev/null +++ b/TablePro/CLI/BridgeProxy.swift @@ -0,0 +1,43 @@ +import Foundation + +actor BridgeProxy { + private let host: any MCPMessageTransport + private let upstream: any MCPMessageTransport + private let logger: any MCPBridgeLogger + + init(host: any MCPMessageTransport, upstream: any MCPMessageTransport, logger: any MCPBridgeLogger) { + self.host = host + self.upstream = upstream + self.logger = logger + } + + func run() async { + await withTaskGroup(of: Void.self) { [host, upstream, logger] group in + group.addTask { await Self.forward(from: host, to: upstream, direction: "host→upstream", logger: logger) } + group.addTask { await Self.forward(from: upstream, to: host, direction: "upstream→host", logger: logger) } + await group.waitForAll() + } + } + + private static func forward( + from source: any MCPMessageTransport, + to destination: any MCPMessageTransport, + direction: String, + logger: any MCPBridgeLogger + ) async { + do { + for try await message in source.inbound { + do { + try await destination.send(message) + } catch { + logger.log(.warning, "[\(direction)] send failed: \(error.localizedDescription)") + } + } + logger.log(.info, "[\(direction)] inbound stream closed") + } catch { + logger.log(.error, "[\(direction)] inbound failed: \(error.localizedDescription)") + } + + await destination.close() + } +} diff --git a/TablePro/CLI/Handshake.swift b/TablePro/CLI/Handshake.swift new file mode 100644 index 000000000..acefbb292 --- /dev/null +++ b/TablePro/CLI/Handshake.swift @@ -0,0 +1,105 @@ +import CryptoKit +import Foundation +import Security + +struct MCPBridgeHandshake: Codable, Sendable { + let port: Int + let token: String + let pid: Int32 + let protocolVersion: String + let tls: Bool? + let tlsCertFingerprint: String? +} + +enum MCPHandshakeError: Error, LocalizedError { + case launchFailed(status: Int32) + case timeout + case fileNotFound + + var errorDescription: String? { + switch self { + case .launchFailed(let status): + return "Failed to launch TablePro (open exit \(status))" + case .timeout: + return "Timed out waiting for TablePro MCP server to start" + case .fileNotFound: + return "Handshake file not found" + } + } +} + +struct MCPHandshakeAcquirer: Sendable { + private static let pollInterval: Duration = .milliseconds(200) + private static let pollTimeout: Duration = .seconds(10) + private static let launchUrl = "tablepro://integrations/start-mcp" + + let handshakePath: String + let logger: any MCPBridgeLogger + + init(logger: any MCPBridgeLogger) { + let home = FileManager.default.homeDirectoryForCurrentUser.path + self.handshakePath = "\(home)/Library/Application Support/TablePro/mcp-handshake.json" + self.logger = logger + } + + func acquire() async throws -> MCPBridgeHandshake { + if let existing = try? load(), isProcessRunning(pid: existing.pid) { + return existing + } + + if (try? load()) != nil { + logger.log(.warning, "Stale handshake detected; relaunching TablePro") + removeHandshake() + } + + try launchHostApp() + return try await pollForHandshake() + } + + private func load() throws -> MCPBridgeHandshake { + let url = URL(fileURLWithPath: handshakePath) + guard FileManager.default.fileExists(atPath: handshakePath) else { + throw MCPHandshakeError.fileNotFound + } + let data = try Data(contentsOf: url) + return try JSONDecoder().decode(MCPBridgeHandshake.self, from: data) + } + + private func removeHandshake() { + try? FileManager.default.removeItem(atPath: handshakePath) + } + + private func isProcessRunning(pid: Int32) -> Bool { + kill(pid, 0) == 0 + } + + private func launchHostApp() throws { + logger.log(.info, "TablePro not running; launching via \(Self.launchUrl)") + let process = Process() + process.executableURL = URL(fileURLWithPath: "/usr/bin/open") + process.arguments = ["-g", Self.launchUrl] + try process.run() + process.waitUntilExit() + if process.terminationStatus != 0 { + throw MCPHandshakeError.launchFailed(status: process.terminationStatus) + } + } + + private func pollForHandshake() async throws -> MCPBridgeHandshake { + let deadline = ContinuousClock().now.advanced(by: Self.pollTimeout) + while ContinuousClock().now < deadline { + if let handshake = try? load(), isProcessRunning(pid: handshake.pid) { + return handshake + } + try? await Task.sleep(for: Self.pollInterval) + } + throw MCPHandshakeError.timeout + } +} + +extension MCPBridgeHandshake { + func endpoint() -> URL? { + let scheme = (tls ?? false) ? "https" : "http" + return URL(string: "\(scheme)://127.0.0.1:\(port)/mcp") + } +} diff --git a/TablePro/CLI/MCPBridgeProxy.swift b/TablePro/CLI/MCPBridgeProxy.swift deleted file mode 100644 index f52bae633..000000000 --- a/TablePro/CLI/MCPBridgeProxy.swift +++ /dev/null @@ -1,328 +0,0 @@ -import CryptoKit -import Foundation -import Security - -struct MCPHandshake: Codable { - let port: Int - let token: String - let pid: Int32 - let protocolVersion: String - let tls: Bool? - let tlsCertFingerprint: String? -} - -private final class CertificatePinningDelegate: NSObject, URLSessionDelegate { - private let expectedFingerprint: String - - init(expectedFingerprint: String) { - self.expectedFingerprint = expectedFingerprint - } - - func urlSession( - _ session: URLSession, - didReceive challenge: URLAuthenticationChallenge - ) async -> (URLSession.AuthChallengeDisposition, URLCredential?) { - guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust, - let trust = challenge.protectionSpace.serverTrust else { - return (.performDefaultHandling, nil) - } - - guard let chain = SecTrustCopyCertificateChain(trust) as? [SecCertificate], - let serverCert = chain.first else { - return (.cancelAuthenticationChallenge, nil) - } - - let serverFingerprint = sha256Fingerprint(of: serverCert) - guard serverFingerprint == expectedFingerprint else { - return (.cancelAuthenticationChallenge, nil) - } - - return (.useCredential, URLCredential(trust: trust)) - } - - private func sha256Fingerprint(of certificate: SecCertificate) -> String { - let data = SecCertificateCopyData(certificate) as Data - return SHA256.hash(data: data) - .map { String(format: "%02X", $0) } - .joined(separator: ":") - } -} - -final class MCPBridgeProxy { - private static let pollInterval: TimeInterval = 0.2 - private static let pollTimeout: TimeInterval = 10.0 - private static let launchURL = "tablepro://integrations/start-mcp" - - private let handshakePath: String - private var sessionId: String? - - init() { - let home = FileManager.default.homeDirectoryForCurrentUser.path - self.handshakePath = "\(home)/Library/Application Support/TablePro/mcp-handshake.json" - } - - func run() async { - let handshake: MCPHandshake - do { - handshake = try await acquireHandshake() - } catch { - writeStderr("Error: \(error.localizedDescription)\n") - writeJsonRpcError( - id: .null, - code: -32_000, - message: "TablePro is not running. Launch the app and enable the MCP server." - ) - exit(1) - } - - let urlSession = makeSession(handshake: handshake) - let scheme = (handshake.tls ?? false) ? "https" : "http" - let baseUrl = "\(scheme)://127.0.0.1:\(handshake.port)/mcp" - await readLoop(baseUrl: baseUrl, bearerToken: handshake.token, urlSession: urlSession) - } - - private func acquireHandshake() async throws -> MCPHandshake { - if let handshake = try? loadHandshake(), isProcessRunning(pid: handshake.pid) { - return handshake - } - - if (try? loadHandshake()) != nil { - writeStderr("Stale handshake detected; relaunching TablePro\n") - removeHandshake() - } - - try launchHostApp() - return try await pollForHandshake() - } - - private func loadHandshake() throws -> MCPHandshake { - let data = try Data(contentsOf: URL(fileURLWithPath: handshakePath)) - return try JSONDecoder().decode(MCPHandshake.self, from: data) - } - - private func removeHandshake() { - try? FileManager.default.removeItem(atPath: handshakePath) - } - - private func isProcessRunning(pid: Int32) -> Bool { - kill(pid, 0) == 0 - } - - private func launchHostApp() throws { - writeStderr("TablePro not running; launching via \(Self.launchURL)\n") - let process = Process() - process.executableURL = URL(fileURLWithPath: "/usr/bin/open") - process.arguments = ["-g", Self.launchURL] - try process.run() - process.waitUntilExit() - if process.terminationStatus != 0 { - throw BridgeError.launchFailed(status: process.terminationStatus) - } - } - - private func pollForHandshake() async throws -> MCPHandshake { - let deadline = Date().addingTimeInterval(Self.pollTimeout) - while Date() < deadline { - if let handshake = try? loadHandshake(), isProcessRunning(pid: handshake.pid) { - return handshake - } - try? await Task.sleep(nanoseconds: UInt64(Self.pollInterval * 1_000_000_000)) - } - throw BridgeError.handshakeTimeout - } - - private func makeSession(handshake: MCPHandshake) -> URLSession { - let config = URLSessionConfiguration.ephemeral - config.timeoutIntervalForRequest = 60 - config.timeoutIntervalForResource = 60 - - let delegate: URLSessionDelegate? - if handshake.tls ?? false, let fingerprint = handshake.tlsCertFingerprint { - delegate = CertificatePinningDelegate(expectedFingerprint: fingerprint) - } else { - delegate = nil - } - return URLSession(configuration: config, delegate: delegate, delegateQueue: nil) - } - - private func readLoop(baseUrl: String, bearerToken: String, urlSession: URLSession) async { - let stdin = FileHandle.standardInput - var buffer = Data() - - while true { - let chunk = stdin.availableData - guard !chunk.isEmpty else { - break - } - - buffer.append(chunk) - - while let newlineIndex = buffer.firstIndex(of: 0x0A) { - let lineData = buffer[buffer.startIndex.. String? { - for (rawKey, rawValue) in response.allHeaderFields { - guard let keyString = rawKey as? String, - keyString.lowercased() == key.lowercased(), - let valueString = rawValue as? String else { continue } - return valueString - } - return nil - } - - private func captureSessionId(from response: HTTPURLResponse) { - guard let value = headerValue(response, forKey: "mcp-session-id") else { return } - if sessionId == nil { - writeStderr("Session established: \(value)\n") - } - sessionId = value - } - - private func extractRequestId(from data: Data) -> JsonRpcId { - guard let object = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { - return .null - } - - guard let id = object["id"] else { - return .null - } - - if let intId = id as? Int { - return .int(intId) - } - if let stringId = id as? String { - return .string(stringId) - } - - return .null - } - - private func writeJsonRpcError(id: JsonRpcId, code: Int, message: String) { - var errorResponse: [String: Any] = [ - "jsonrpc": "2.0", - "error": [ - "code": code, - "message": message - ] as [String: Any] - ] - - switch id { - case .null: - errorResponse["id"] = NSNull() - case .int(let value): - errorResponse["id"] = value - case .string(let value): - errorResponse["id"] = value - } - - guard let data = try? JSONSerialization.data(withJSONObject: errorResponse) else { return } - writeStdout(data) - writeStdout(Data([0x0A])) - } - - private func writeStdout(_ data: Data) { - FileHandle.standardOutput.write(data) - } - - private func writeStderr(_ message: String) { - guard let data = message.data(using: .utf8) else { return } - FileHandle.standardError.write(data) - } -} - -private enum JsonRpcId { - case null - case int(Int) - case string(String) -} - -private enum BridgeError: LocalizedError { - case invalidUrl - case launchFailed(status: Int32) - case handshakeTimeout - - var errorDescription: String? { - switch self { - case .invalidUrl: - "Invalid MCP server URL" - case .launchFailed(let status): - "Failed to launch TablePro (open exit \(status))" - case .handshakeTimeout: - "Timed out waiting for TablePro MCP server to start" - } - } -} diff --git a/TablePro/CLI/main.swift b/TablePro/CLI/main.swift deleted file mode 100644 index cfc9ef571..000000000 --- a/TablePro/CLI/main.swift +++ /dev/null @@ -1,9 +0,0 @@ -import Foundation - -let proxy = MCPBridgeProxy() - -Task { - await proxy.run() -} - -RunLoop.main.run() diff --git a/TablePro/Core/Database/DatabaseDriver.swift b/TablePro/Core/Database/DatabaseDriver.swift index 09880e787..9708f8eee 100644 --- a/TablePro/Core/Database/DatabaseDriver.swift +++ b/TablePro/Core/Database/DatabaseDriver.swift @@ -237,6 +237,16 @@ extension DatabaseDriver { userInfo: [NSLocalizedDescriptionKey: "Drop database is not supported by this driver"]) } + func createDatabaseFormSpec() async throws -> CreateDatabaseFormSpec? { nil } + + func createDatabase(_ request: CreateDatabaseRequest) async throws { + throw NSError( + domain: "DatabaseDriver", + code: -1, + userInfo: [NSLocalizedDescriptionKey: "Create database is not supported by this driver"] + ) + } + /// Default fetchAllDatabaseMetadata: falls back to per-database calls (N+1). /// Drivers should override with a single bulk query where possible. func fetchAllDatabaseMetadata() async throws -> [DatabaseMetadata] { diff --git a/TablePro/Core/MCP/Auth/MCPAuthDecision.swift b/TablePro/Core/MCP/Auth/MCPAuthDecision.swift new file mode 100644 index 000000000..1d4833afd --- /dev/null +++ b/TablePro/Core/MCP/Auth/MCPAuthDecision.swift @@ -0,0 +1,66 @@ +import Foundation + +public enum MCPAuthDecision: Sendable { + case allow(MCPPrincipal) + case deny(MCPAuthDenialReason) +} + +public struct MCPAuthDenialReason: Sendable, Equatable { + public let httpStatus: Int + public let challenge: String? + public let logMessage: String + public let retryAfterSeconds: Int? + + public init( + httpStatus: Int, + challenge: String?, + logMessage: String, + retryAfterSeconds: Int? = nil + ) { + self.httpStatus = httpStatus + self.challenge = challenge + self.logMessage = logMessage + self.retryAfterSeconds = retryAfterSeconds + } + + public static func unauthenticated(reason: String) -> Self { + Self( + httpStatus: 401, + challenge: "Bearer realm=\"TablePro MCP\"", + logMessage: reason + ) + } + + public static func tokenExpired() -> Self { + Self( + httpStatus: 401, + challenge: "Bearer realm=\"TablePro MCP\", error=\"invalid_token\", error_description=\"token_expired\"", + logMessage: "token_expired" + ) + } + + public static func tokenInvalid(reason: String) -> Self { + Self( + httpStatus: 401, + challenge: "Bearer realm=\"TablePro MCP\", error=\"invalid_token\"", + logMessage: reason + ) + } + + public static func forbidden(reason: String) -> Self { + Self( + httpStatus: 403, + challenge: nil, + logMessage: reason + ) + } + + public static func rateLimited(retryAfterSeconds: Int? = nil) -> Self { + Self( + httpStatus: 429, + challenge: nil, + logMessage: "rate_limited", + retryAfterSeconds: retryAfterSeconds + ) + } +} diff --git a/TablePro/Core/MCP/Auth/MCPAuthenticator.swift b/TablePro/Core/MCP/Auth/MCPAuthenticator.swift new file mode 100644 index 000000000..4aa761315 --- /dev/null +++ b/TablePro/Core/MCP/Auth/MCPAuthenticator.swift @@ -0,0 +1,13 @@ +import Foundation + +public enum MCPClientAddress: Sendable, Equatable, Hashable { + case loopback + case remote(String) +} + +public protocol MCPAuthenticator: Sendable { + func authenticate( + authorizationHeader: String?, + clientAddress: MCPClientAddress + ) async -> MCPAuthDecision +} diff --git a/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift new file mode 100644 index 000000000..5177af407 --- /dev/null +++ b/TablePro/Core/MCP/Auth/MCPBearerTokenAuthenticator.swift @@ -0,0 +1,213 @@ +import CryptoKit +import Foundation +import os + +public struct MCPValidatedToken: Sendable, Equatable { + public let tokenId: UUID + public let label: String? + public let scopes: Set + public let issuedAt: Date + public let expiresAt: Date? + + public init( + tokenId: UUID, + label: String?, + scopes: Set, + issuedAt: Date, + expiresAt: Date? + ) { + self.tokenId = tokenId + self.label = label + self.scopes = scopes + self.issuedAt = issuedAt + self.expiresAt = expiresAt + } +} + +public enum MCPTokenValidationError: Error, Sendable, Equatable { + case unknownToken + case expired + case revoked +} + +public protocol MCPTokenStoreProtocol: Sendable { + func validateBearerToken(_ token: String) async -> Result +} + +extension MCPTokenStore: MCPTokenStoreProtocol {} + +internal extension MCPTokenStore { + func validateBearerToken(_ bearerToken: String) async -> Result { + guard let authToken = self.validate(bearerToken: bearerToken) else { + return .failure(.unknownToken) + } + if authToken.isExpired { + return .failure(.expired) + } + if !authToken.isActive { + return .failure(.revoked) + } + let validated = MCPValidatedToken( + tokenId: authToken.id, + label: authToken.name, + scopes: Self.mcpScopes(for: authToken.permissions), + issuedAt: authToken.createdAt, + expiresAt: authToken.expiresAt + ) + return .success(validated) + } + + private static func mcpScopes(for permissions: TokenPermissions) -> Set { + switch permissions { + case .readOnly: + return [.toolsRead, .resourcesRead] + case .readWrite: + return [.toolsRead, .toolsWrite, .resourcesRead] + case .fullAccess: + return [.toolsRead, .toolsWrite, .resourcesRead, .admin] + } + } +} + +public actor MCPBearerTokenAuthenticator: MCPAuthenticator { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Auth") + + private let tokenStore: any MCPTokenStoreProtocol + private let rateLimiter: MCPRateLimiter + private let clock: any MCPClock + + public init( + tokenStore: any MCPTokenStoreProtocol, + rateLimiter: MCPRateLimiter, + clock: any MCPClock = MCPSystemClock() + ) { + self.tokenStore = tokenStore + self.rateLimiter = rateLimiter + self.clock = clock + } + + public func authenticate( + authorizationHeader: String?, + clientAddress: MCPClientAddress + ) async -> MCPAuthDecision { + let ipString = Self.ipString(for: clientAddress) + + guard let header = authorizationHeader, !header.isEmpty else { + let key = MCPRateLimitKey(clientAddress: clientAddress, principalFingerprint: nil) + if let retry = await rateLimitedRetryAfter(key: key) { + Self.logger.warning("Auth rejected (rate limited, missing header)") + MCPAuditLogger.logRateLimited(ip: ipString, retryAfterSeconds: retry) + return .deny(.rateLimited(retryAfterSeconds: retry)) + } + Self.logger.info("Auth missing Authorization header") + MCPAuditLogger.logAuthFailure(reason: "missing_authorization_header", ip: ipString) + return .deny(.unauthenticated(reason: "missing_authorization_header")) + } + + guard let token = Self.parseBearerToken(header) else { + let key = MCPRateLimitKey(clientAddress: clientAddress, principalFingerprint: nil) + if let retry = await rateLimitedRetryAfter(key: key) { + MCPAuditLogger.logRateLimited(ip: ipString, retryAfterSeconds: retry) + return .deny(.rateLimited(retryAfterSeconds: retry)) + } + _ = await rateLimiter.recordAttempt(key: key, success: false) + Self.logger.info("Auth invalid Authorization scheme") + MCPAuditLogger.logAuthFailure(reason: "invalid_authorization_scheme", ip: ipString) + return .deny(.unauthenticated(reason: "invalid_authorization_scheme")) + } + + let fingerprint = Self.fingerprint(of: token) + let principalKey = MCPRateLimitKey( + clientAddress: clientAddress, + principalFingerprint: fingerprint + ) + + if let retry = await rateLimitedRetryAfter(key: principalKey) { + Self.logger.warning( + "Auth rate limited fingerprint=\(fingerprint, privacy: .public)" + ) + MCPAuditLogger.logRateLimited(ip: ipString, retryAfterSeconds: retry) + return .deny(.rateLimited(retryAfterSeconds: retry)) + } + + let validation = await tokenStore.validateBearerToken(token) + switch validation { + case .failure(let error): + let verdict = await rateLimiter.recordAttempt(key: principalKey, success: false) + if case .lockedUntil(let unlockDate) = verdict { + let retry = await retryAfter(unlockDate: unlockDate) + MCPAuditLogger.logRateLimited(ip: ipString, retryAfterSeconds: retry) + return .deny(.rateLimited(retryAfterSeconds: retry)) + } + switch error { + case .unknownToken: + Self.logger.info("Auth unknown token fingerprint=\(fingerprint, privacy: .public)") + MCPAuditLogger.logAuthFailure(reason: "unknown_token", ip: ipString) + return .deny(.tokenInvalid(reason: "unknown_token")) + case .expired: + Self.logger.info("Auth expired token fingerprint=\(fingerprint, privacy: .public)") + MCPAuditLogger.logAuthFailure(reason: "expired_token", ip: ipString) + return .deny(.tokenExpired()) + case .revoked: + Self.logger.info("Auth revoked token fingerprint=\(fingerprint, privacy: .public)") + MCPAuditLogger.logAuthFailure(reason: "revoked_token", ip: ipString) + return .deny(.tokenInvalid(reason: "token_revoked")) + } + + case .success(let validated): + _ = await rateLimiter.recordAttempt(key: principalKey, success: true) + let principal = MCPPrincipal( + tokenFingerprint: fingerprint, + tokenId: validated.tokenId, + scopes: validated.scopes, + metadata: MCPPrincipalMetadata( + label: validated.label, + issuedAt: validated.issuedAt, + expiresAt: validated.expiresAt + ) + ) + Self.logger.info("Auth allowed fingerprint=\(fingerprint, privacy: .public)") + MCPAuditLogger.logAuthSuccess(tokenName: validated.label ?? "-", ip: ipString) + return .allow(principal) + } + } + + private func rateLimitedRetryAfter(key: MCPRateLimitKey) async -> Int? { + guard await rateLimiter.isLocked(key: key) else { return nil } + guard let unlockDate = await rateLimiter.lockedUntil(key: key) else { return nil } + return await retryAfter(unlockDate: unlockDate) + } + + private func retryAfter(unlockDate: Date) async -> Int { + let now = await clock.now() + let delta = unlockDate.timeIntervalSince(now) + if delta <= 0 { return 1 } + return max(1, Int(delta.rounded(.up))) + } + + private static func ipString(for address: MCPClientAddress) -> String { + switch address { + case .loopback: + return "127.0.0.1" + case .remote(let host): + return host + } + } + + internal static func parseBearerToken(_ header: String) -> String? { + let trimmed = header.trimmingCharacters(in: .whitespacesAndNewlines) + guard let spaceIndex = trimmed.firstIndex(of: " ") else { return nil } + let scheme = trimmed[trimmed.startIndex.. String { + guard let data = token.data(using: .utf8) else { return "" } + let digest = SHA256.hash(data: data) + let hex = digest.map { String(format: "%02x", $0) }.joined() + return String(hex.prefix(16)) + } +} diff --git a/TablePro/Core/MCP/Auth/MCPPrincipal.swift b/TablePro/Core/MCP/Auth/MCPPrincipal.swift new file mode 100644 index 000000000..d4de96724 --- /dev/null +++ b/TablePro/Core/MCP/Auth/MCPPrincipal.swift @@ -0,0 +1,51 @@ +import Foundation + +public enum MCPScope: String, Sendable, Equatable, Hashable, CaseIterable { + case toolsRead = "tools:read" + case toolsWrite = "tools:write" + case resourcesRead = "resources:read" + case admin +} + +public struct MCPPrincipalMetadata: Sendable, Equatable { + public let label: String? + public let issuedAt: Date + public let expiresAt: Date? + + public init(label: String?, issuedAt: Date, expiresAt: Date?) { + self.label = label + self.issuedAt = issuedAt + self.expiresAt = expiresAt + } +} + +public struct MCPPrincipal: Sendable, Equatable, Hashable { + public let tokenFingerprint: String + public let tokenId: UUID? + public let scopes: Set + public let metadata: MCPPrincipalMetadata + + public init( + tokenFingerprint: String, + tokenId: UUID? = nil, + scopes: Set, + metadata: MCPPrincipalMetadata + ) { + self.tokenFingerprint = tokenFingerprint + self.tokenId = tokenId + self.scopes = scopes + self.metadata = metadata + } + + public static func == (lhs: MCPPrincipal, rhs: MCPPrincipal) -> Bool { + lhs.tokenFingerprint == rhs.tokenFingerprint + && lhs.tokenId == rhs.tokenId + && lhs.scopes == rhs.scopes + && lhs.metadata == rhs.metadata + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(tokenFingerprint) + hasher.combine(tokenId) + } +} diff --git a/TablePro/Core/MCP/MCPAuditLogStorage.swift b/TablePro/Core/MCP/MCPAuditLogStorage.swift index c42e4e7c0..3aaf13e6d 100644 --- a/TablePro/Core/MCP/MCPAuditLogStorage.swift +++ b/TablePro/Core/MCP/MCPAuditLogStorage.swift @@ -1,12 +1,11 @@ -// -// MCPAuditLogStorage.swift -// TablePro -// - import Foundation import os import SQLite3 +extension Notification.Name { + static let mcpAuditLogChanged = Notification.Name("com.TablePro.mcp.auditLogChanged") +} + actor MCPAuditLogStorage { static let shared = MCPAuditLogStorage() private static let logger = Logger(subsystem: "com.TablePro", category: "MCPAuditLogStorage") @@ -158,7 +157,13 @@ actor MCPAuditLogStorage { sqlite3_bind_null(statement, 9) } - return sqlite3_step(statement) == SQLITE_DONE + let inserted = sqlite3_step(statement) == SQLITE_DONE + if inserted { + Task { @MainActor in + NotificationCenter.default.post(name: .mcpAuditLogChanged, object: nil) + } + } + return inserted } func query( diff --git a/TablePro/Core/MCP/MCPAuditLogger.swift b/TablePro/Core/MCP/MCPAuditLogger.swift index f5f53a95f..f92f3e5a2 100644 --- a/TablePro/Core/MCP/MCPAuditLogger.swift +++ b/TablePro/Core/MCP/MCPAuditLogger.swift @@ -44,6 +44,45 @@ enum MCPAuditLogger { ) } + static func logPairingExchange( + outcome: AuditOutcome, + tokenName: String? = nil, + ip: String, + details: String? = nil + ) { + let resolvedDetails = Self.composePairingDetails(ip: ip, extra: details) + switch outcome { + case .success: + serverAuth.info( + "Pairing exchange success: token=\(tokenName ?? "-", privacy: .public) ip=\(ip, privacy: .public)" + ) + case .denied: + serverAuth.warning( + "Pairing exchange denied: ip=\(ip, privacy: .public) details=\(details ?? "-", privacy: .public)" + ) + case .rateLimited: + serverAuth.warning("Pairing exchange rate limited: ip=\(ip, privacy: .public)") + case .error: + serverAuth.error( + "Pairing exchange error: ip=\(ip, privacy: .public) details=\(details ?? "-", privacy: .public)" + ) + } + record( + category: .auth, + tokenName: tokenName, + action: "pairing.exchange", + outcome: outcome, + details: resolvedDetails + ) + } + + private static func composePairingDetails(ip: String, extra: String?) -> String { + guard let extra, !extra.isEmpty else { + return "ip=\(ip)" + } + return "ip=\(ip) \(extra)" + } + static func logTokenCreated(tokenName: String) { serverAdmin.info("Token created: \(tokenName, privacy: .public)") record( diff --git a/TablePro/Core/MCP/MCPAuthPolicy.swift b/TablePro/Core/MCP/MCPAuthPolicy.swift index 2fa2eee60..d1344d9b7 100644 --- a/TablePro/Core/MCP/MCPAuthPolicy.swift +++ b/TablePro/Core/MCP/MCPAuthPolicy.swift @@ -20,9 +20,11 @@ enum AuthDecision: Sendable { case denied(reason: String) } -actor MCPAuthPolicy { +public actor MCPAuthPolicy { private static let logger = Logger(subsystem: "com.TablePro", category: "MCPAuthPolicy") + public init() {} + private var sessionApprovals: [String: Set] = [:] private let approvalDedup = OnceTask() @@ -114,11 +116,11 @@ actor MCPAuthPolicy { return case .denied(let reason): - throw MCPError.forbidden(reason) + throw MCPDataLayerError.forbidden(reason) case .requiresUserApproval(let reason): guard let connectionId else { - throw MCPError.forbidden(reason) + throw MCPDataLayerError.forbidden(reason) } let approved = try await runApprovalDedup( sessionId: sessionId, @@ -128,7 +130,7 @@ actor MCPAuthPolicy { if approved { recordApproval(sessionId: sessionId, connectionId: connectionId) } else { - throw MCPError.forbidden( + throw MCPDataLayerError.forbidden( String(localized: "User denied MCP access to this connection") ) } @@ -171,7 +173,7 @@ actor MCPAuthPolicy { ) if case .blocked(let reason) = permission { - throw MCPError.forbidden(reason) + throw MCPDataLayerError.forbidden(reason) } } @@ -226,12 +228,12 @@ actor MCPAuthPolicy { } group.addTask { try await Task.sleep(for: .seconds(30)) - throw MCPError.timeout( + throw MCPDataLayerError.timeout( String(localized: "User approval timed out after 30 seconds") ) } guard let result = try await group.next() else { - throw MCPError.internalError("No result from approval prompt") + throw MCPDataLayerError.dataSourceError("No result from approval prompt") } return result } diff --git a/TablePro/Core/MCP/MCPConnectionBridge.swift b/TablePro/Core/MCP/MCPConnectionBridge.swift index f60ff6e29..4d7970431 100644 --- a/TablePro/Core/MCP/MCPConnectionBridge.swift +++ b/TablePro/Core/MCP/MCPConnectionBridge.swift @@ -1,17 +1,12 @@ -// -// MCPConnectionBridge.swift -// TablePro -// -// Bridges MCP tool/resource handlers to DatabaseManager and driver APIs. -// - import Foundation import os -actor MCPConnectionBridge { +public actor MCPConnectionBridge { private static let logger = Logger(subsystem: "com.TablePro", category: "MCPConnectionBridge") - func listConnections() async -> JSONValue { + public init() {} + + func listConnections() async -> JsonValue { let (connections, activeSessions) = await MainActor.run { let conns = ConnectionStorage.shared.loadConnections() .filter { $0.externalAccess != .blocked } @@ -19,7 +14,7 @@ actor MCPConnectionBridge { return (conns, sessions) } - let items: [JSONValue] = connections.map { conn in + let items: [JsonValue] = connections.map { conn in let session = activeSessions[conn.id] let isConnected = session?.status.isConnected ?? false let policy = conn.aiPolicy ?? AIConnectionPolicy.askEachTime @@ -41,21 +36,19 @@ actor MCPConnectionBridge { return .object(["connections": .array(items)]) } - func connect(connectionId: UUID) async throws -> JSONValue { + func connect(connectionId: UUID) async throws -> JsonValue { let connection = try await resolveConnection(connectionId) - // Check if session already exists and is connected -- reuse without switching UI let existingSession = await MainActor.run { DatabaseManager.shared.activeSessions[connectionId] } if let existing = existingSession, existing.driver != nil { - // Already connected, return current state without switching the UI's active session let serverVersion = existing.driver?.serverVersion let currentDatabase = existing.activeDatabase let currentSchema = existing.currentSchema - var result: [String: JSONValue] = [ + var result: [String: JsonValue] = [ "status": "connected", "current_database": .string(currentDatabase) ] @@ -79,7 +72,7 @@ actor MCPConnectionBridge { ) } - var result: [String: JSONValue] = [ + var result: [String: JsonValue] = [ "status": "connected", "current_database": .string(currentDatabase ?? "") ] @@ -98,12 +91,12 @@ actor MCPConnectionBridge { DatabaseManager.shared.activeSessions[connectionId] != nil } guard sessionExists else { - throw MCPError.notConnected(connectionId) + throw MCPDataLayerError.notConnected(connectionId) } await DatabaseManager.shared.disconnectSession(connectionId) } - func getConnectionStatus(connectionId: UUID) async throws -> JSONValue { + func getConnectionStatus(connectionId: UUID) async throws -> JsonValue { let core = await MainActor.run { () -> (status: ConnectionStatus, database: String, schema: String?)? in guard let session = DatabaseManager.shared.activeSessions[connectionId] else { @@ -113,7 +106,7 @@ actor MCPConnectionBridge { } guard let core else { - throw MCPError.notConnected(connectionId) + throw MCPDataLayerError.notConnected(connectionId) } let meta = await MainActor.run { @@ -127,7 +120,7 @@ actor MCPConnectionBridge { } let statusString: String - var errorDetail: JSONValue? + var errorDetail: JsonValue? switch core.status { case .connected: statusString = "connected" case .connecting: statusString = "connecting" @@ -139,7 +132,7 @@ actor MCPConnectionBridge { ]) } - var result: [String: JSONValue] = [ + var result: [String: JsonValue] = [ "status": .string(statusString), "current_database": .string(core.database), "connected_at": .string(ISO8601DateFormatter().string(from: meta.connectedAt)), @@ -163,7 +156,7 @@ actor MCPConnectionBridge { query: String, maxRows: Int, timeoutSeconds: Int - ) async throws -> JSONValue { + ) async throws -> JsonValue { let (driver, databaseType) = try await resolveDriver(connectionId) let normalizedQuery = Self.stripTrailingSemicolons(query) let isWrite = QueryClassifier.isWriteQuery(normalizedQuery, databaseType: databaseType) @@ -189,10 +182,10 @@ actor MCPConnectionBridge { group.addTask { try await Task.sleep(for: .seconds(timeoutSeconds)) try? driver.cancelQuery() - throw MCPError.timeout("Query timed out after \(timeoutSeconds) seconds") + throw MCPDataLayerError.timeout("Query timed out after \(timeoutSeconds) seconds") } guard let first = try await group.next() else { - throw MCPError.internalError("No result from query execution") + throw MCPDataLayerError.dataSourceError("No result from query execution") } group.cancelAll() return first @@ -202,8 +195,8 @@ actor MCPConnectionBridge { let executionTimeMs = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000 let isTruncated = result.isTruncated - let jsonColumns: [JSONValue] = result.columns.map { .string($0) } - let jsonRows: [JSONValue] = result.rows.map { row in + let jsonColumns: [JsonValue] = result.columns.map { .string($0) } + let jsonRows: [JsonValue] = result.rows.map { row in .array(row.map { cell in if let value = cell { return .string(value) @@ -212,7 +205,7 @@ actor MCPConnectionBridge { }) } - var response: [String: JSONValue] = [ + var response: [String: JsonValue] = [ "columns": .array(jsonColumns), "rows": .array(jsonRows), "row_count": .int(result.rows.count), @@ -227,7 +220,7 @@ actor MCPConnectionBridge { return .object(response) } - func listTables(connectionId: UUID, includeRowCounts: Bool) async throws -> JSONValue { + func listTables(connectionId: UUID, includeRowCounts: Bool) async throws -> JsonValue { let cachedTables = await MainActor.run { SchemaService.shared.tables(for: connectionId) } @@ -242,8 +235,8 @@ actor MCPConnectionBridge { } } - let jsonTables: [JSONValue] = tables.map { table in - var obj: [String: JSONValue] = [ + let jsonTables: [JsonValue] = tables.map { table in + var obj: [String: JsonValue] = [ "name": .string(table.name), "type": .string(table.type.rawValue) ] @@ -256,11 +249,9 @@ actor MCPConnectionBridge { return .object(["tables": .array(jsonTables)]) } - func describeTable(connectionId: UUID, table: String, schema: String?) async throws -> JSONValue { + func describeTable(connectionId: UUID, table: String, schema: String?) async throws -> JsonValue { let (driver, _) = try await resolveDriver(connectionId) - // Sequential fetches: driver connections are NOT thread-safe, - // so concurrent calls on the same driver would race. return try await DatabaseManager.shared.trackOperation(sessionId: connectionId) { let columns = try await driver.fetchColumns(table: table, schema: schema) let indexes = try await driver.fetchIndexes(table: table) @@ -268,8 +259,8 @@ actor MCPConnectionBridge { let approxRowCount = try await driver.fetchApproximateRowCount(table: table) let ddl = try? await driver.fetchTableDDL(table: table) - let jsonColumns: [JSONValue] = columns.map { col in - var obj: [String: JSONValue] = [ + let jsonColumns: [JsonValue] = columns.map { col in + var obj: [String: JsonValue] = [ "name": .string(col.name), "data_type": .string(col.dataType), "is_nullable": .bool(col.isNullable), @@ -281,7 +272,7 @@ actor MCPConnectionBridge { return .object(obj) } - let jsonIndexes: [JSONValue] = indexes.map { idx in + let jsonIndexes: [JsonValue] = indexes.map { idx in .object([ "name": .string(idx.name), "columns": .array(idx.columns.map { .string($0) }), @@ -291,8 +282,8 @@ actor MCPConnectionBridge { ]) } - let jsonFKs: [JSONValue] = foreignKeys.map { fk in - var obj: [String: JSONValue] = [ + let jsonFKs: [JsonValue] = foreignKeys.map { fk in + var obj: [String: JsonValue] = [ "name": .string(fk.name), "column": .string(fk.column), "referenced_table": .string(fk.referencedTable), @@ -306,7 +297,7 @@ actor MCPConnectionBridge { return .object(obj) } - var result: [String: JSONValue] = [ + var result: [String: JsonValue] = [ "columns": .array(jsonColumns), "indexes": .array(jsonIndexes), "foreign_keys": .array(jsonFKs) @@ -322,7 +313,7 @@ actor MCPConnectionBridge { } } - func listDatabases(connectionId: UUID) async throws -> JSONValue { + func listDatabases(connectionId: UUID) async throws -> JsonValue { let (driver, _) = try await resolveDriver(connectionId) let databases = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) { try await driver.fetchDatabases() @@ -330,7 +321,7 @@ actor MCPConnectionBridge { return .object(["databases": .array(databases.map { .string($0) })]) } - func listSchemas(connectionId: UUID) async throws -> JSONValue { + func listSchemas(connectionId: UUID) async throws -> JsonValue { let (driver, _) = try await resolveDriver(connectionId) let schemas = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) { try await driver.fetchSchemas() @@ -338,7 +329,7 @@ actor MCPConnectionBridge { return .object(["schemas": .array(schemas.map { .string($0) })]) } - func getTableDDL(connectionId: UUID, table: String, schema: String?) async throws -> JSONValue { + func getTableDDL(connectionId: UUID, table: String, schema: String?) async throws -> JsonValue { let (driver, _) = try await resolveDriver(connectionId) let ddl = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) { try await driver.fetchTableDDL(table: table) @@ -346,8 +337,7 @@ actor MCPConnectionBridge { return .object(["ddl": .string(ddl)]) } - func switchDatabase(connectionId: UUID, database: String) async throws -> JSONValue { - // switchDatabase is @MainActor; Swift hops automatically for async calls. + func switchDatabase(connectionId: UUID, database: String) async throws -> JsonValue { try await DatabaseManager.shared.switchDatabase(to: database, for: connectionId) return .object([ "status": "switched", @@ -355,8 +345,7 @@ actor MCPConnectionBridge { ]) } - func switchSchema(connectionId: UUID, schema: String) async throws -> JSONValue { - // switchSchema is @MainActor; Swift hops automatically for async calls. + func switchSchema(connectionId: UUID, schema: String) async throws -> JsonValue { try await DatabaseManager.shared.switchSchema(to: schema, for: connectionId) return .object([ "status": "switched", @@ -364,7 +353,7 @@ actor MCPConnectionBridge { ]) } - func fetchSchemaResource(connectionId: UUID) async throws -> JSONValue { + func fetchSchemaResource(connectionId: UUID) async throws -> JsonValue { let cachedTables = await MainActor.run { SchemaService.shared.tables(for: connectionId) } @@ -382,13 +371,13 @@ actor MCPConnectionBridge { let limitedTables = Array(tables.prefix(100)) - var tableSchemas: [JSONValue] = [] + var tableSchemas: [JsonValue] = [] for table in limitedTables { let columns = try await DatabaseManager.shared.trackOperation(sessionId: connectionId) { try await driver.fetchColumns(table: table.name) } - let jsonCols: [JSONValue] = columns.map { col in + let jsonCols: [JsonValue] = columns.map { col in .object([ "name": .string(col.name), "data_type": .string(col.dataType), @@ -404,7 +393,7 @@ actor MCPConnectionBridge { ])) } - var result: [String: JSONValue] = ["tables": .array(tableSchemas)] + var result: [String: JsonValue] = ["tables": .array(tableSchemas)] if tables.count > 100 { result["truncated"] = .bool(true) result["total_tables"] = .int(tables.count) @@ -418,7 +407,7 @@ actor MCPConnectionBridge { limit: Int, search: String?, dateFilter: String? - ) async throws -> JSONValue { + ) async throws -> JsonValue { let filter: DateFilter switch dateFilter { case "today": filter = .today @@ -434,8 +423,8 @@ actor MCPConnectionBridge { dateFilter: filter ) - let jsonEntries: [JSONValue] = entries.map { entry in - var obj: [String: JSONValue] = [ + let jsonEntries: [JsonValue] = entries.map { entry in + var obj: [String: JsonValue] = [ "id": .string(entry.id.uuidString), "query": .string(entry.query), "database_name": .string(entry.databaseName), @@ -469,7 +458,7 @@ actor MCPConnectionBridge { case .live(let driver, let session): return (driver, session.connection.type) case .stored, .unknown: - throw MCPError.notConnected(connectionId) + throw MCPDataLayerError.notConnected(connectionId) } } } @@ -481,7 +470,7 @@ actor MCPConnectionBridge { private func resolveSession(_ connectionId: UUID) async throws -> ConnectionSession { try await MainActor.run { guard let session = DatabaseManager.shared.activeSessions[connectionId] else { - throw MCPError.notConnected(connectionId) + throw MCPDataLayerError.notConnected(connectionId) } return session } @@ -491,7 +480,7 @@ actor MCPConnectionBridge { try await MainActor.run { let connections = ConnectionStorage.shared.loadConnections() guard let connection = connections.first(where: { $0.id == connectionId }) else { - throw MCPError.invalidParams("Connection not found: \(connectionId)") + throw MCPDataLayerError.invalidArgument("Connection not found: \(connectionId)") } return connection } diff --git a/TablePro/Core/MCP/MCPDataLayerError.swift b/TablePro/Core/MCP/MCPDataLayerError.swift new file mode 100644 index 000000000..dfaf68c90 --- /dev/null +++ b/TablePro/Core/MCP/MCPDataLayerError.swift @@ -0,0 +1,42 @@ +import Foundation + +enum MCPDataLayerError: Error, Sendable { + case notConnected(UUID) + case invalidArgument(String) + case forbidden(String, context: [String: String]? = nil) + case timeout(String, context: [String: String]? = nil) + case notFound(String) + case expired(String) + case userCancelled + case dataSourceError(String) + + var message: String { + switch self { + case .notConnected(let connectionId): + "Not connected: \(connectionId)" + case .invalidArgument(let detail): + "Invalid argument: \(detail)" + case .forbidden(let detail, _): + "Forbidden: \(detail)" + case .timeout(let detail, _): + "Timeout: \(detail)" + case .notFound(let detail): + "Not found: \(detail)" + case .expired(let detail): + "Expired: \(detail)" + case .userCancelled: + "User cancelled" + case .dataSourceError(let detail): + "Data source error: \(detail)" + } + } + + var isUserCancelled: Bool { + if case .userCancelled = self { return true } + return false + } +} + +extension MCPDataLayerError: LocalizedError { + var errorDescription: String? { message } +} diff --git a/TablePro/Core/MCP/MCPHTTPParser.swift b/TablePro/Core/MCP/MCPHTTPParser.swift deleted file mode 100644 index 5662aa73e..000000000 --- a/TablePro/Core/MCP/MCPHTTPParser.swift +++ /dev/null @@ -1,253 +0,0 @@ -import Foundation -import os - -struct HTTPRequest: Sendable { - enum Method: String, Sendable { - case get = "GET" - case post = "POST" - case delete = "DELETE" - case options = "OPTIONS" - } - - let method: Method - let path: String - let headers: [String: String] - let body: Data? - var remoteIP: String? - - init(method: Method, path: String, headers: [String: String], body: Data?, remoteIP: String? = nil) { - self.method = method - self.path = path - self.headers = headers - self.body = body - self.remoteIP = remoteIP - } - - func withRemoteIP(_ remoteIP: String?) -> HTTPRequest { - HTTPRequest(method: method, path: path, headers: headers, body: body, remoteIP: remoteIP) - } -} - -enum HTTPParseError: Error, Sendable { - case incomplete - case malformedRequestLine - case malformedHeaders - case unsupportedMethod(String) - case bodyTooLarge - case malformedChunkedEncoding -} - -enum MCPHTTPParser { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPHTTPParser") - - static let maxBodySize = 10 * 1_024 * 1_024 - - static func parse(_ data: Data) -> Result { - let crlfcrlf = Data([0x0D, 0x0A, 0x0D, 0x0A]) - let lflf = Data([0x0A, 0x0A]) - - let headerEndRange: Range - if let range = data.range(of: crlfcrlf) { - headerEndRange = range - } else if let range = data.range(of: lflf) { - headerEndRange = range - } else { - return .failure(.incomplete) - } - - let headerData = data[data.startIndex..= 2 else { - return .failure(.malformedRequestLine) - } - - let methodString = String(requestParts[0]) - guard let method = HTTPRequest.Method(rawValue: methodString) else { - return .failure(.unsupportedMethod(methodString)) - } - - let path = String(requestParts[1]) - - var headers: [String: String] = [:] - for i in 1.. maxBodySize { - return .failure(.bodyTooLarge) - } - body = decoded - case .failure(let error): - return .failure(error) - } - } else if let contentLengthStr = headers["content-length"], - let contentLength = Int(contentLengthStr) - { - if contentLength > maxBodySize { - return .failure(.bodyTooLarge) - } - - let availableBytes = data.count - bodyStartIndex - if availableBytes < contentLength { - return .failure(.incomplete) - } - - body = data[bodyStartIndex..<(bodyStartIndex + contentLength)] - } - } - - return .success(HTTPRequest( - method: method, - path: path, - headers: headers, - body: body - )) - } - - private static func decodeChunkedBody(_ data: Data) -> Result { - var result = Data() - var offset = data.startIndex - - while offset < data.endIndex { - guard let lineEnd = findCRLF(in: data, from: offset) else { - return .failure(.incomplete) - } - - let sizeData = data[offset.. maxBodySize { - return .failure(.bodyTooLarge) - } - - result.append(data[chunkDataStart.. Data.Index? { - var i = start - while i < data.endIndex - 1 { - if data[i] == 0x0D, data[i + 1] == 0x0A { - return i - } - i += 1 - } - return nil - } - - static func buildResponse( - status: Int, - statusText: String, - headers: [(String, String)], - body: Data? - ) -> Data { - var response = "HTTP/1.1 \(status) \(statusText)\r\n" - for (key, value) in headers { - response += "\(key): \(value)\r\n" - } - if let body { - response += "Content-Length: \(body.count)\r\n" - } - response += "\r\n" - var data = Data(response.utf8) - if let body { - data.append(body) - } - return data - } - - static func buildSSEHeaders(sessionId: String, corsHeaders: [(String, String)] = []) -> Data { - var response = "HTTP/1.1 200 OK\r\n" - + "Content-Type: text/event-stream\r\n" - + "Cache-Control: no-cache\r\n" - + "Connection: keep-alive\r\n" - + "Mcp-Session-Id: \(sessionId)\r\n" - for (key, value) in corsHeaders { - response += "\(key): \(value)\r\n" - } - response += "\r\n" - return Data(response.utf8) - } - - static func buildSSEEvent(data: Data, id: String? = nil) -> Data { - var event = Data() - if let id { - event.append(Data("id: \(id)\n".utf8)) - } - event.append(Data("data: ".utf8)) - event.append(data) - event.append(Data("\n\n".utf8)) - return event - } - - static func statusText(for code: Int) -> String { - switch code { - case 200: return "OK" - case 202: return "Accepted" - case 204: return "No Content" - case 400: return "Bad Request" - case 401: return "Unauthorized" - case 403: return "Forbidden" - case 404: return "Not Found" - case 405: return "Method Not Allowed" - case 406: return "Not Acceptable" - case 413: return "Content Too Large" - case 429: return "Too Many Requests" - case 500: return "Internal Server Error" - default: return "Unknown" - } - } -} diff --git a/TablePro/Core/MCP/MCPMessageTypes.swift b/TablePro/Core/MCP/MCPMessageTypes.swift deleted file mode 100644 index 01ae77a47..000000000 --- a/TablePro/Core/MCP/MCPMessageTypes.swift +++ /dev/null @@ -1,428 +0,0 @@ -// -// MCPMessageTypes.swift -// TablePro -// - -import Foundation - -enum JSONValue: Codable, Equatable, Sendable { - case null - case bool(Bool) - case int(Int) - case double(Double) - case string(String) - case array([JSONValue]) - case object([String: JSONValue]) - - init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - - if container.decodeNil() { - self = .null - return - } - - if let boolValue = try? container.decode(Bool.self) { - self = .bool(boolValue) - return - } - - if let intValue = try? container.decode(Int.self) { - self = .int(intValue) - return - } - - if let doubleValue = try? container.decode(Double.self) { - self = .double(doubleValue) - return - } - - if let stringValue = try? container.decode(String.self) { - self = .string(stringValue) - return - } - - if let arrayValue = try? container.decode([JSONValue].self) { - self = .array(arrayValue) - return - } - - if let objectValue = try? container.decode([String: JSONValue].self) { - self = .object(objectValue) - return - } - - throw DecodingError.dataCorruptedError(in: container, debugDescription: "Cannot decode JSONValue") - } - - func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .null: - try container.encodeNil() - case .bool(let value): - try container.encode(value) - case .int(let value): - try container.encode(value) - case .double(let value): - try container.encode(value) - case .string(let value): - try container.encode(value) - case .array(let value): - try container.encode(value) - case .object(let value): - try container.encode(value) - } - } -} - -extension JSONValue: ExpressibleByStringLiteral { - init(stringLiteral value: String) { - self = .string(value) - } -} - -extension JSONValue: ExpressibleByIntegerLiteral { - init(integerLiteral value: Int) { - self = .int(value) - } -} - -extension JSONValue: ExpressibleByBooleanLiteral { - init(booleanLiteral value: Bool) { - self = .bool(value) - } -} - -extension JSONValue: ExpressibleByNilLiteral { - init(nilLiteral: ()) { - self = .null - } -} - -extension JSONValue: ExpressibleByArrayLiteral { - init(arrayLiteral elements: JSONValue...) { - self = .array(elements) - } -} - -extension JSONValue: ExpressibleByDictionaryLiteral { - init(dictionaryLiteral elements: (String, JSONValue)...) { - self = .object(Dictionary(uniqueKeysWithValues: elements)) - } -} - -extension JSONValue { - subscript(key: String) -> JSONValue? { - guard case .object(let dict) = self else { return nil } - return dict[key] - } - - var stringValue: String? { - guard case .string(let value) = self else { return nil } - return value - } - - var intValue: Int? { - guard case .int(let value) = self else { return nil } - return value - } - - var boolValue: Bool? { - guard case .bool(let value) = self else { return nil } - return value - } - - var doubleValue: Double? { - switch self { - case .double(let value): - return value - case .int(let value): - return Double(value) - default: - return nil - } - } - - var arrayValue: [JSONValue]? { - guard case .array(let value) = self else { return nil } - return value - } - - var objectValue: [String: JSONValue]? { - guard case .object(let value) = self else { return nil } - return value - } -} - -enum JSONRPCId: Codable, Equatable, Hashable, Sendable { - case string(String) - case int(Int) - - init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - - if let intValue = try? container.decode(Int.self) { - self = .int(intValue) - return - } - - if let stringValue = try? container.decode(String.self) { - self = .string(stringValue) - return - } - - throw DecodingError.dataCorruptedError(in: container, debugDescription: "JSONRPCId must be a string or integer") - } - - func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .string(let value): - try container.encode(value) - case .int(let value): - try container.encode(value) - } - } -} - -struct JSONRPCRequest: Codable, Sendable { - let jsonrpc: String - let id: JSONRPCId? - let method: String - let params: JSONValue? -} - -struct JSONRPCResponse: Codable, Sendable { - let id: JSONRPCId - let result: JSONValue - - var jsonrpc: String { "2.0" } - - enum CodingKeys: String, CodingKey { - case jsonrpc - case id - case result - } - - init(id: JSONRPCId, result: JSONValue) { - self.id = id - self.result = result - } - - init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - _ = try container.decode(String.self, forKey: .jsonrpc) - id = try container.decode(JSONRPCId.self, forKey: .id) - result = try container.decode(JSONValue.self, forKey: .result) - } - - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode("2.0", forKey: .jsonrpc) - try container.encode(id, forKey: .id) - try container.encode(result, forKey: .result) - } -} - -struct JSONRPCErrorResponse: Codable, Sendable { - let id: JSONRPCId? - let error: JSONRPCErrorDetail - - var jsonrpc: String { "2.0" } - - enum CodingKeys: String, CodingKey { - case jsonrpc - case id - case error - } - - init(id: JSONRPCId?, error: JSONRPCErrorDetail) { - self.id = id - self.error = error - } - - init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - _ = try container.decode(String.self, forKey: .jsonrpc) - id = try container.decodeIfPresent(JSONRPCId.self, forKey: .id) - error = try container.decode(JSONRPCErrorDetail.self, forKey: .error) - } - - func encode(to encoder: Encoder) throws { - var container = encoder.container(keyedBy: CodingKeys.self) - try container.encode("2.0", forKey: .jsonrpc) - try container.encode(id, forKey: .id) - try container.encode(error, forKey: .error) - } -} - -struct JSONRPCErrorDetail: Codable, Sendable { - let code: Int - let message: String - let data: JSONValue? -} - -enum MCPError: Error, Sendable { - case parseError - case invalidRequest(String) - case methodNotFound(String) - case invalidParams(String) - case internalError(String) - case notConnected(UUID) - case forbidden(String, context: [String: String]? = nil) - case timeout(String, context: [String: String]? = nil) - case resultTooLarge - case serverDisabled - case notFound(String) - case expired(String) - case userCancelled - - var code: Int { - switch self { - case .parseError: -32_700 - case .invalidRequest: -32_600 - case .methodNotFound: -32_601 - case .invalidParams: -32_602 - case .internalError: -32_603 - case .notConnected: -32_000 - case .forbidden: -32_001 - case .timeout: -32_002 - case .resultTooLarge: -32_003 - case .serverDisabled: -32_004 - case .notFound: -32_005 - case .expired: -32_006 - case .userCancelled: -32_007 - } - } - - var message: String { - switch self { - case .parseError: - "Parse error" - case .invalidRequest(let detail): - "Invalid request: \(detail)" - case .methodNotFound(let method): - "Method not found: \(method)" - case .invalidParams(let detail): - "Invalid params: \(detail)" - case .internalError(let detail): - "Internal error: \(detail)" - case .notConnected(let connectionId): - "Not connected: \(connectionId)" - case .forbidden(let detail, _): - "Forbidden: \(detail)" - case .timeout(let detail, _): - "Timeout: \(detail)" - case .resultTooLarge: - "Result too large" - case .serverDisabled: - "MCP server is disabled" - case .notFound(let detail): - "Not found: \(detail)" - case .expired(let detail): - "Expired: \(detail)" - case .userCancelled: - "User cancelled" - } - } - - private var contextData: JSONValue? { - switch self { - case .forbidden(_, let context), .timeout(_, let context): - guard let context, !context.isEmpty else { return nil } - var dict: [String: JSONValue] = [:] - for (key, value) in context { - dict[key] = .string(value) - } - return .object(dict) - case .notConnected(let connectionId): - return .object(["connection_id": .string(connectionId.uuidString)]) - default: - return nil - } - } - - func toJsonRpcError(id: JSONRPCId?) -> JSONRPCErrorResponse { - JSONRPCErrorResponse( - id: id, - error: JSONRPCErrorDetail(code: code, message: message, data: contextData) - ) - } - - var isUserCancelled: Bool { - if case .userCancelled = self { return true } - return false - } -} - -extension MCPError: LocalizedError { - var errorDescription: String? { message } -} - -struct MCPClientInfo: Codable, Sendable { - let name: String - let version: String? -} - -struct MCPInitializeResult: Codable, Sendable { - let protocolVersion: String - let capabilities: MCPServerCapabilities - let serverInfo: MCPServerInfo -} - -struct MCPServerCapabilities: Codable, Sendable { - let tools: ToolCapability? - let resources: ResourceCapability? - - struct ToolCapability: Codable, Sendable { - let listChanged: Bool - } - - struct ResourceCapability: Codable, Sendable { - let subscribe: Bool - let listChanged: Bool - } -} - -struct MCPServerInfo: Codable, Sendable { - let name: String - let version: String -} - -struct MCPToolDefinition: Codable, Sendable { - let name: String - let description: String - let inputSchema: JSONValue -} - -struct MCPToolResult: Codable, Sendable { - let content: [MCPContent] - let isError: Bool? -} - -struct MCPContent: Codable, Sendable { - let type: String - let text: String - - static func text(_ value: String) -> MCPContent { - MCPContent(type: "text", text: value) - } -} - -struct MCPResourceDefinition: Codable, Sendable { - let uri: String - let name: String - let description: String? - let mimeType: String? -} - -struct MCPResourceContent: Codable, Sendable { - let uri: String - let mimeType: String? - let text: String? -} - -struct MCPResourceReadResult: Codable, Sendable { - let contents: [MCPResourceContent] -} diff --git a/TablePro/Core/MCP/MCPPairingService.swift b/TablePro/Core/MCP/MCPPairingService.swift index 3201bf579..1b6729767 100644 --- a/TablePro/Core/MCP/MCPPairingService.swift +++ b/TablePro/Core/MCP/MCPPairingService.swift @@ -1,8 +1,3 @@ -// -// MCPPairingService.swift -// TablePro -// - import AppKit import CryptoKit import Foundation @@ -26,7 +21,7 @@ final class PairingExchangeStore: @unchecked Sendable { defer { lock.unlock() } prune(now: Date.now) guard pending.count < Self.maxPendingCodes else { - throw MCPError.forbidden( + throw MCPDataLayerError.forbidden( String(localized: "Too many pending pairing codes. Try again later.") ) } @@ -39,17 +34,17 @@ final class PairingExchangeStore: @unchecked Sendable { prune(now: now) guard let entry = pending[code] else { - throw MCPError.notFound("pairing code") + throw MCPDataLayerError.notFound("pairing code") } guard entry.expiresAt > now else { pending.removeValue(forKey: code) - throw MCPError.expired("pairing code") + throw MCPDataLayerError.expired("pairing code") } let computed = Self.sha256Base64Url(of: verifier) guard Self.constantTimeEqual(entry.challenge, computed) else { - throw MCPError.forbidden("challenge mismatch") + throw MCPDataLayerError.forbidden("challenge mismatch") } let token = entry.plaintextToken @@ -123,10 +118,25 @@ final class MCPPairingService { guard let tokenStore = MCPServerManager.shared.tokenStore else { Self.logger.error("Token store unavailable after lazyStart") - throw MCPError.internalError("Token store unavailable") + throw MCPDataLayerError.dataSourceError("Token store unavailable") } - let approval = try await AlertHelper.runPairingApproval(request: request) + let approval: PairingApproval + do { + approval = try await AlertHelper.runPairingApproval(request: request) + } catch let error as MCPDataLayerError where error.isUserCancelled { + Self.logger.info("Pairing denied for client '\(request.clientName, privacy: .public)'") + if let redirect = buildErrorRedirect( + base: request.redirectURL, + error: "denied", + description: "user_denied" + ) { + NSWorkspace.shared.open(redirect) + } + throw error + } + + await Self.revokeExistingTokens(named: request.clientName, in: tokenStore) let connectionAccess: ConnectionAccess = approval.allowedConnectionIds.map { .limited($0) } ?? .all let result = await tokenStore.generate( @@ -154,7 +164,7 @@ final class MCPPairingService { guard let redirect = buildRedirectURL(base: request.redirectURL, code: code) else { Self.logger.error("Failed to build pairing redirect URL") await tokenStore.delete(tokenId: result.token.id) - throw MCPError.invalidParams("redirect URL") + throw MCPDataLayerError.invalidArgument("redirect URL") } Self.logger.info("Pairing approved for client '\(request.clientName, privacy: .public)'") @@ -165,6 +175,14 @@ final class MCPPairingService { try store.consume(code: exchange.code, verifier: exchange.verifier) } + private static func revokeExistingTokens(named name: String, in store: MCPTokenStore) async { + let active = await store.activeTokens() + for token in active where token.name == name { + await store.revoke(tokenId: token.id) + Self.logger.info("Revoked previous token '\(name, privacy: .public)' before re-pairing") + } + } + private func startPruneLoop() { pruneTask = Task { [store] in while !Task.isCancelled { @@ -175,6 +193,26 @@ final class MCPPairingService { } } + private func buildErrorRedirect(base: URL, error: String, description: String) -> URL? { + guard var components = URLComponents(url: base, resolvingAgainstBaseURL: false) else { + return nil + } + var items = components.queryItems ?? [] + if base.scheme == "raycast" { + let payload: [String: String] = ["error": error, "error_description": description] + guard let data = try? JSONSerialization.data(withJSONObject: payload, options: [.sortedKeys]), + let json = String(data: data, encoding: .utf8) else { + return nil + } + items.append(URLQueryItem(name: "context", value: json)) + } else { + items.append(URLQueryItem(name: "error", value: error)) + items.append(URLQueryItem(name: "error_description", value: description)) + } + components.queryItems = items + return components.url + } + private func buildRedirectURL(base: URL, code: String) -> URL? { guard var components = URLComponents(url: base, resolvingAgainstBaseURL: false) else { return nil diff --git a/TablePro/Core/MCP/MCPPortAllocator.swift b/TablePro/Core/MCP/MCPPortAllocator.swift index 9a354665a..402349cec 100644 --- a/TablePro/Core/MCP/MCPPortAllocator.swift +++ b/TablePro/Core/MCP/MCPPortAllocator.swift @@ -1,8 +1,3 @@ -// -// MCPPortAllocator.swift -// TablePro -// - import Darwin import Foundation import os diff --git a/TablePro/Core/MCP/MCPRateLimiter.swift b/TablePro/Core/MCP/MCPRateLimiter.swift deleted file mode 100644 index edf5ac0d6..000000000 --- a/TablePro/Core/MCP/MCPRateLimiter.swift +++ /dev/null @@ -1,94 +0,0 @@ -import Foundation -import os - -actor MCPRateLimiter { - enum AuthRateResult: Sendable { - case allowed - case rateLimited(retryAfter: Duration) - } - - private struct FailureRecord { - var consecutiveFailures: Int - var lockedUntil: ContinuousClock.Instant? - var lastUpdated: ContinuousClock.Instant - } - - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPRateLimiter") - - private static let staleEntryThreshold: Duration = .seconds(600) - private static let cleanupInterval: Duration = .seconds(300) - - private var records: [String: FailureRecord] = [:] - private var lastCleanup: ContinuousClock.Instant = .now - - func checkAndRecord(ip: String, success: Bool) -> AuthRateResult { - cleanupStaleEntriesIfNeeded() - - let now = ContinuousClock.now - - if let record = records[ip], let lockedUntil = record.lockedUntil, now < lockedUntil { - let remaining = lockedUntil - now - return .rateLimited(retryAfter: remaining) - } - - guard !success else { - records.removeValue(forKey: ip) - return .allowed - } - - var record = records[ip] ?? FailureRecord(consecutiveFailures: 0, lockedUntil: nil, lastUpdated: now) - record.consecutiveFailures += 1 - record.lastUpdated = now - - let lockoutDuration = lockoutDuration(forFailureCount: record.consecutiveFailures) - if let lockout = lockoutDuration { - record.lockedUntil = now + lockout - records[ip] = record - return .rateLimited(retryAfter: lockout) - } - - record.lockedUntil = nil - records[ip] = record - return .allowed - } - - func isLockedOut(ip: String) -> AuthRateResult { - let now = ContinuousClock.now - guard let record = records[ip], let lockedUntil = record.lockedUntil, now < lockedUntil else { - return .allowed - } - return .rateLimited(retryAfter: lockedUntil - now) - } - - private func lockoutDuration(forFailureCount count: Int) -> Duration? { - switch count { - case 1: - return nil - case 2: - return .seconds(1) - case 3: - return .seconds(5) - case 4: - return .seconds(30) - default: - return .seconds(300) - } - } - - private func cleanupStaleEntriesIfNeeded() { - let now = ContinuousClock.now - guard now - lastCleanup > Self.cleanupInterval else { return } - - lastCleanup = now - let threshold = now - Self.staleEntryThreshold - - let staleKeys = records.filter { $0.value.lastUpdated < threshold }.map(\.key) - for key in staleKeys { - records.removeValue(forKey: key) - } - - if !staleKeys.isEmpty { - Self.logger.info("Cleaned up \(staleKeys.count) stale rate limit entries") - } - } -} diff --git a/TablePro/Core/MCP/MCPResourceHandler.swift b/TablePro/Core/MCP/MCPResourceHandler.swift deleted file mode 100644 index aa4b36226..000000000 --- a/TablePro/Core/MCP/MCPResourceHandler.swift +++ /dev/null @@ -1,137 +0,0 @@ -import Foundation -import os - -final class MCPResourceHandler: Sendable { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPResourceHandler") - - private let bridge: MCPConnectionBridge - private let authPolicy: MCPAuthPolicy - - init(bridge: MCPConnectionBridge, authPolicy: MCPAuthPolicy) { - self.bridge = bridge - self.authPolicy = authPolicy - } - - func handleResourceRead(uri: String, sessionId: String) async throws -> MCPResourceReadResult { - guard let components = URLComponents(string: uri) else { - throw MCPError.invalidParams("Malformed URI: \(uri)") - } - - guard components.scheme == "tablepro" else { - throw MCPError.invalidParams("Unsupported URI scheme: \(components.scheme ?? "nil")") - } - - let pathSegments = parsePathSegments(from: uri) - - if pathSegments == ["connections"] { - return try await handleConnectionsList(uri: uri) - } - - if pathSegments.count == 3, - pathSegments[0] == "connections", - pathSegments[2] == "schema" - { - guard let connectionId = UUID(uuidString: pathSegments[1]) else { - throw MCPError.invalidParams("Invalid connection UUID in URI") - } - return try await handleSchemaResource(uri: uri, connectionId: connectionId, sessionId: sessionId) - } - - if pathSegments.count == 3, - pathSegments[0] == "connections", - pathSegments[2] == "history" - { - guard let connectionId = UUID(uuidString: pathSegments[1]) else { - throw MCPError.invalidParams("Invalid connection UUID in URI") - } - let queryItems = components.queryItems ?? [] - return try await handleHistoryResource( - uri: uri, - connectionId: connectionId, - queryItems: queryItems, - sessionId: sessionId - ) - } - - throw MCPError.invalidParams("Unknown resource URI: \(uri)") - } - - private func handleConnectionsList(uri: String) async throws -> MCPResourceReadResult { - let result = await bridge.listConnections() - let jsonString = encodeJSON(result) - return MCPResourceReadResult(contents: [ - MCPResourceContent(uri: uri, mimeType: "application/json", text: jsonString) - ]) - } - - private func handleSchemaResource(uri: String, connectionId: UUID, sessionId: String) async throws -> MCPResourceReadResult { - try await authPolicy.resolveAndAuthorize( - token: MCPToolHandler.anonymousFullAccessToken, - tool: "describe_table", - connectionId: connectionId, - sessionId: sessionId - ) - let result = try await bridge.fetchSchemaResource(connectionId: connectionId) - let jsonString = encodeJSON(result) - return MCPResourceReadResult(contents: [ - MCPResourceContent(uri: uri, mimeType: "application/json", text: jsonString) - ]) - } - - private func handleHistoryResource( - uri: String, - connectionId: UUID, - queryItems: [URLQueryItem], - sessionId: String - ) async throws -> MCPResourceReadResult { - try await authPolicy.resolveAndAuthorize( - token: MCPToolHandler.anonymousFullAccessToken, - tool: "search_query_history", - connectionId: connectionId, - sessionId: sessionId - ) - let limit = queryItems.first(where: { $0.name == "limit" }) - .flatMap { $0.value } - .flatMap { Int($0) } - ?? 50 - - let clampedLimit = min(max(limit, 1), 500) - let search = queryItems.first(where: { $0.name == "search" })?.value - let dateFilter = queryItems.first(where: { $0.name == "date_filter" })?.value - - let result = try await bridge.fetchHistoryResource( - connectionId: connectionId, - limit: clampedLimit, - search: search, - dateFilter: dateFilter - ) - let jsonString = encodeJSON(result) - return MCPResourceReadResult(contents: [ - MCPResourceContent(uri: uri, mimeType: "application/json", text: jsonString) - ]) - } - - private func parsePathSegments(from uri: String) -> [String] { - guard let range = uri.range(of: "://") else { return [] } - let afterScheme = String(uri[range.upperBound...]) - let pathOnly: String - if let queryStart = afterScheme.firstIndex(of: "?") { - pathOnly = String(afterScheme[.. String { - let encoder = JSONEncoder() - encoder.outputFormatting = [.sortedKeys] - guard let data = try? encoder.encode(value), - let string = String(data: data, encoding: .utf8) - else { - Self.logger.warning("Failed to encode JSON value") - return "{}" - } - return string - } -} diff --git a/TablePro/Core/MCP/MCPRouteHandler.swift b/TablePro/Core/MCP/MCPRouteHandler.swift deleted file mode 100644 index 3b7a39c12..000000000 --- a/TablePro/Core/MCP/MCPRouteHandler.swift +++ /dev/null @@ -1,7 +0,0 @@ -import Foundation - -protocol MCPRouteHandler: Sendable { - var methods: [HTTPRequest.Method] { get } - var path: String { get } - func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult -} diff --git a/TablePro/Core/MCP/MCPRouter.swift b/TablePro/Core/MCP/MCPRouter.swift deleted file mode 100644 index 1561e27dd..000000000 --- a/TablePro/Core/MCP/MCPRouter.swift +++ /dev/null @@ -1,485 +0,0 @@ -import Foundation - -final class MCPRouter: Sendable { - enum RouteResult: Sendable { - case json(Data, sessionId: String?) - case sseStream(sessionId: String) - case accepted - case noContent - case httpError(status: Int, message: String) - case httpErrorWithHeaders(status: Int, message: String, extraHeaders: [(String, String)]) - } - - private let routes: [any MCPRouteHandler] - - init(routes: [any MCPRouteHandler]) { - self.routes = routes - } - - func handle(_ request: HTTPRequest) async -> RouteResult { - if request.path.hasPrefix("/.well-known/") { - return .httpError(status: 404, message: "Not found") - } - - if request.method == .options { - return .noContent - } - - guard let route = match(request) else { - return .httpError(status: 404, message: "Not found") - } - - return await route.handle(request) - } - - private func match(_ request: HTTPRequest) -> (any MCPRouteHandler)? { - let normalizedPath = Self.canonicalPath(request.path) - return routes.first { route in - route.path == normalizedPath && route.methods.contains(request.method) - } - } - - private static func canonicalPath(_ path: String) -> String { - if let queryIndex = path.firstIndex(of: "?") { - return String(path[.. [MCPToolDefinition] { - connectionTools() + schemaTools() + queryAndExportTools() + integrationTools() - } - - private static func connectionTools() -> [MCPToolDefinition] { - [ - MCPToolDefinition( - name: "list_connections", - description: "List all saved database connections with their status", - inputSchema: .object([ - "type": "object", - "properties": .object([:]), - "required": .array([]) - ]) - ), - MCPToolDefinition( - name: "connect", - description: "Connect to a saved database", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the saved connection" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "disconnect", - description: "Disconnect from a database", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection to disconnect" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "get_connection_status", - description: "Get detailed status of a database connection", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "switch_database", - description: "Switch the active database on a connection", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "database": .object([ - "type": "string", - "description": "Database name to switch to" - ]) - ]), - "required": .array([.string("connection_id"), .string("database")]) - ]) - ), - MCPToolDefinition( - name: "switch_schema", - description: "Switch the active schema on a connection", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "schema": .object([ - "type": "string", - "description": "Schema name to switch to" - ]) - ]), - "required": .array([.string("connection_id"), .string("schema")]) - ]) - ) - ] - } - - private static func schemaTools() -> [MCPToolDefinition] { - [ - MCPToolDefinition( - name: "list_databases", - description: "List all databases on the server", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "list_schemas", - description: "List schemas in a database", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "database": .object([ - "type": "string", - "description": "Database name (uses current if omitted)" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "list_tables", - description: "List tables and views in a database", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "database": .object([ - "type": "string", - "description": "Database name (uses current if omitted)" - ]), - "schema": .object([ - "type": "string", - "description": "Schema name (uses current if omitted)" - ]), - "include_row_counts": .object([ - "type": "boolean", - "description": "Include approximate row counts (default false)" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "describe_table", - description: "Get detailed table structure: columns, indexes, foreign keys, and DDL", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "table": .object([ - "type": "string", - "description": "Table name" - ]), - "schema": .object([ - "type": "string", - "description": "Schema name (uses current if omitted)" - ]) - ]), - "required": .array([.string("connection_id"), .string("table")]) - ]) - ), - MCPToolDefinition( - name: "get_table_ddl", - description: "Get the CREATE TABLE DDL statement for a table", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "table": .object([ - "type": "string", - "description": "Table name" - ]), - "schema": .object([ - "type": "string", - "description": "Schema name (uses current if omitted)" - ]) - ]), - "required": .array([.string("connection_id"), .string("table")]) - ]) - ) - ] - } - - private static func queryAndExportTools() -> [MCPToolDefinition] { - [ - MCPToolDefinition( - name: "execute_query", - description: "Execute a SQL query. All queries are subject to the connection's safe mode policy. " - + "DROP/TRUNCATE/ALTER...DROP must use the confirm_destructive_operation tool.", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "query": .object([ - "type": "string", - "description": "SQL or NoSQL query text" - ]), - "max_rows": .object([ - "type": "integer", - "description": "Maximum rows to return (default 500, max 10000)" - ]), - "timeout_seconds": .object([ - "type": "integer", - "description": "Query timeout in seconds (default 30, max 300)" - ]), - "database": .object([ - "type": "string", - "description": "Switch to this database before executing" - ]), - "schema": .object([ - "type": "string", - "description": "Switch to this schema before executing" - ]) - ]), - "required": .array([.string("connection_id"), .string("query")]) - ]) - ), - MCPToolDefinition( - name: "export_data", - description: "Export query results or table data to CSV, JSON, or SQL", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "format": .object([ - "type": "string", - "description": "Export format: csv, json, or sql", - "enum": .array([.string("csv"), .string("json"), .string("sql")]) - ]), - "query": .object([ - "type": "string", - "description": "SQL query to export results from" - ]), - "tables": .object([ - "type": "array", - "description": "Table names to export (alternative to query)", - "items": .object(["type": "string"]) - ]), - "output_path": .object([ - "type": "string", - "description": "File path inside the user's Downloads directory (returns inline data if omitted). Paths outside Downloads are rejected." - ]), - "max_rows": .object([ - "type": "integer", - "description": "Maximum rows to export (default 50000)" - ]) - ]), - "required": .array([.string("connection_id"), .string("format")]) - ]) - ), - MCPToolDefinition( - name: "confirm_destructive_operation", - description: "Execute a destructive DDL query (DROP, TRUNCATE, ALTER...DROP) after explicit confirmation.", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the active connection" - ]), - "query": .object([ - "type": "string", - "description": "The destructive query to execute" - ]), - "confirmation_phrase": .object([ - "type": "string", - "description": "Must be exactly: I understand this is irreversible" - ]) - ]), - "required": .array([ - .string("connection_id"), - .string("query"), - .string("confirmation_phrase") - ]) - ]) - ) - ] - } - - private static func integrationTools() -> [MCPToolDefinition] { - [ - MCPToolDefinition( - name: "list_recent_tabs", - description: "List currently open tabs across all TablePro windows. " - + "Returns connection, tab type, table name, and titles for each tab.", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "limit": .object([ - "type": "integer", - "description": "Maximum number of tabs to return (default 20, max 500)" - ]) - ]), - "required": .array([]) - ]) - ), - MCPToolDefinition( - name: "search_query_history", - description: "Search saved query history. " - + "Returns matching entries with execution time, row count, and outcome.", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "query": .object([ - "type": "string", - "description": "Search text (full-text matched against the query column)" - ]), - "connection_id": .object([ - "type": "string", - "description": "Restrict to a specific connection (UUID, optional)" - ]), - "limit": .object([ - "type": "integer", - "description": "Maximum number of entries to return (default 50, max 500)" - ]), - "since": .object([ - "type": "number", - "description": "Earliest executed_at to include, Unix epoch seconds (inclusive, optional)" - ]), - "until": .object([ - "type": "number", - "description": "Latest executed_at to include, Unix epoch seconds (inclusive, optional)" - ]) - ]), - "required": .array([.string("query")]) - ]) - ), - MCPToolDefinition( - name: "open_connection_window", - description: "Open a TablePro window for a saved connection (focuses if already open).", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the saved connection" - ]) - ]), - "required": .array([.string("connection_id")]) - ]) - ), - MCPToolDefinition( - name: "open_table_tab", - description: "Open a table tab in TablePro for the given connection.", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "connection_id": .object([ - "type": "string", - "description": "UUID of the connection" - ]), - "table_name": .object([ - "type": "string", - "description": "Table name to open" - ]), - "database_name": .object([ - "type": "string", - "description": "Database name (uses connection's current database if omitted)" - ]), - "schema_name": .object([ - "type": "string", - "description": "Schema name (for multi-schema databases)" - ]) - ]), - "required": .array([.string("connection_id"), .string("table_name")]) - ]) - ), - MCPToolDefinition( - name: "focus_query_tab", - description: "Focus an already-open tab by id (returned from list_recent_tabs).", - inputSchema: .object([ - "type": "object", - "properties": .object([ - "tab_id": .object([ - "type": "string", - "description": "UUID of the tab to focus" - ]) - ]), - "required": .array([.string("tab_id")]) - ]) - ) - ] - } -} - -extension MCPRouter { - static func resourceDefinitions() -> [MCPResourceDefinition] { - [ - MCPResourceDefinition( - uri: "tablepro://connections", - name: "Saved Connections", - description: "List of all saved database connections with metadata", - mimeType: "application/json" - ), - MCPResourceDefinition( - uri: "tablepro://connections/{id}/schema", - name: "Database Schema", - description: "Tables, columns, indexes, and foreign keys for a connected database", - mimeType: "application/json" - ), - MCPResourceDefinition( - uri: "tablepro://connections/{id}/history", - name: "Query History", - description: "Recent query history for a connection (supports ?limit=, ?search=, ?date_filter=)", - mimeType: "application/json" - ) - ] - } -} diff --git a/TablePro/Core/MCP/MCPServer.swift b/TablePro/Core/MCP/MCPServer.swift deleted file mode 100644 index a499d6581..000000000 --- a/TablePro/Core/MCP/MCPServer.swift +++ /dev/null @@ -1,481 +0,0 @@ -import Foundation -import Network -import os -import Security - -actor MCPServer { - struct SessionSnapshot: Sendable, Identifiable { - let id: String - let clientName: String - let clientVersion: String? - let connectedSince: Date - let lastActivityAt: Date - let tokenName: String? - let remoteAddress: String? - } - - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPServer") - - private static let maxSessions = 10 - private static let idleTimeout: TimeInterval = 300 - private static let cleanupInterval: TimeInterval = 60 - private static let maxReadSize = 1_048_576 - private static let maxBufferSize = 10 * 1_024 * 1_024 - - private var allowRemoteAccess: Bool = false - private var listener: NWListener? - private var sessions: [String: MCPSession] = [:] - private var cleanupTask: Task? - private let stateCallback: @Sendable (MCPServerState) -> Void - private var router: MCPRouter? - - private(set) var tokenStore: MCPTokenStore? - private(set) var rateLimiter: MCPRateLimiter? - - private(set) var toolCallHandler: (@Sendable (String, JSONValue?, String, MCPAuthToken?) async throws -> MCPToolResult)? - private(set) var resourceReadHandler: (@Sendable (String, String) async throws -> MCPResourceReadResult)? - private(set) var sessionCleanupHandler: (@Sendable (String) async -> Void)? - - init(stateCallback: @escaping @Sendable (MCPServerState) -> Void) { - self.stateCallback = stateCallback - } - - func setRouter(_ router: MCPRouter) { - self.router = router - } - - func setTokenStore(_ store: MCPTokenStore) { - self.tokenStore = store - } - - func setRateLimiter(_ limiter: MCPRateLimiter) { - self.rateLimiter = limiter - } - - func setToolCallHandler(_ handler: @escaping @Sendable (String, JSONValue?, String, MCPAuthToken?) async throws -> MCPToolResult) { - self.toolCallHandler = handler - } - - func setResourceReadHandler(_ handler: @escaping @Sendable (String, String) async throws -> MCPResourceReadResult) { - self.resourceReadHandler = handler - } - - func setSessionCleanupHandler(_ handler: @escaping @Sendable (String) async -> Void) { - self.sessionCleanupHandler = handler - } - - func start(port: UInt16, allowRemoteAccess: Bool = false, tlsIdentity: SecIdentity? = nil) throws { - guard listener == nil else { - Self.logger.warning("Server already running, ignoring start request") - return - } - - stateCallback(.starting) - self.allowRemoteAccess = allowRemoteAccess - - let params: NWParameters - - if allowRemoteAccess, let identity = tlsIdentity { - let tlsOptions = NWProtocolTLS.Options() - guard let secIdentity = sec_identity_create(identity) else { - stateCallback(.failed("Failed to create TLS identity")) - return - } - sec_protocol_options_set_local_identity(tlsOptions.securityProtocolOptions, secIdentity) - sec_protocol_options_set_min_tls_protocol_version(tlsOptions.securityProtocolOptions, .TLSv12) - params = NWParameters(tls: tlsOptions, tcp: NWProtocolTCP.Options()) - params.requiredLocalEndpoint = NWEndpoint.hostPort( - host: .ipv4(.any), - port: NWEndpoint.Port(rawValue: port) ?? 23_508 - ) - params.allowLocalEndpointReuse = true - } else if allowRemoteAccess { - params = NWParameters.tcp - params.requiredLocalEndpoint = NWEndpoint.hostPort( - host: .ipv4(.any), - port: NWEndpoint.Port(rawValue: port) ?? 23_508 - ) - params.allowLocalEndpointReuse = true - } else { - params = NWParameters.tcp - params.requiredLocalEndpoint = NWEndpoint.hostPort( - host: .ipv4(.loopback), - port: NWEndpoint.Port(rawValue: port) ?? 23_508 - ) - params.allowLocalEndpointReuse = true - } - - let newListener = try NWListener(using: params) - self.listener = newListener - - newListener.stateUpdateHandler = { [weak self] state in - guard let self else { return } - Task { - await self.handleListenerState(state, listener: newListener) - } - } - - newListener.newConnectionHandler = { [weak self] connection in - guard let self else { return } - Task { - await self.handleNewConnection(connection) - } - } - - newListener.start(queue: .global(qos: .userInitiated)) - startCleanupTimer() - } - - func stop() async { - Self.logger.info("Stopping MCP server") - - cleanupTask?.cancel() - cleanupTask = nil - - let sessionIds = Array(sessions.keys) - for (_, session) in sessions { - await session.cancelAllTasks() - await session.cancelSSEConnection() - } - - if let cleanupHandler = sessionCleanupHandler { - for id in sessionIds { - await cleanupHandler(id) - } - } - - sessions.removeAll() - - if let currentListener = listener { - listener = nil - await withCheckedContinuation { (continuation: CheckedContinuation) in - currentListener.stateUpdateHandler = { state in - if case .cancelled = state { - continuation.resume() - } - } - currentListener.cancel() - } - } - } - - var sessionCount: Int { - sessions.count - } - - func sessionSnapshots() async -> [SessionSnapshot] { - let now = ContinuousClock.now - var snapshots: [SessionSnapshot] = [] - for (_, session) in sessions { - let info = await session.clientInfo - let created = await session.createdAt - let lastActive = await session.lastActivityAt - let sessionTokenName = await session.tokenName - let sessionRemoteAddress = await session.remoteAddress - let connectedElapsed = now - created - let activeElapsed = now - lastActive - snapshots.append(SessionSnapshot( - id: session.id, - clientName: info?.name ?? String(localized: "Unknown"), - clientVersion: info?.version, - connectedSince: Date.now - TimeInterval(connectedElapsed.components.seconds), - lastActivityAt: Date.now - TimeInterval(activeElapsed.components.seconds), - tokenName: sessionTokenName, - remoteAddress: sessionRemoteAddress - )) - } - return snapshots - } - - private func handleListenerState(_ state: NWListener.State, listener: NWListener) { - switch state { - case .ready: - let port = listener.port?.rawValue ?? 0 - let bindAddress = allowRemoteAccess ? "0.0.0.0" : "127.0.0.1" - Self.logger.info("MCP server listening on \(bindAddress):\(port)") - stateCallback(.running(port: port)) - - case .failed(let error): - Self.logger.error("MCP server listener failed: \(error.localizedDescription)") - stateCallback(.failed(error.localizedDescription)) - self.listener = nil - listener.cancel() - - case .cancelled: - Self.logger.debug("MCP server listener cancelled") - - default: - break - } - } - - private func handleNewConnection(_ connection: NWConnection) { - connection.stateUpdateHandler = { [weak self] state in - guard let self else { return } - switch state { - case .ready: - Task { - await self.readRequest(from: connection, buffer: Data()) - } - case .failed(let error): - Self.logger.debug("Connection failed: \(error.localizedDescription)") - connection.cancel() - default: - break - } - } - connection.start(queue: .global(qos: .userInitiated)) - } - - private func readRequest(from connection: NWConnection, buffer: Data) { - connection.receive(minimumIncompleteLength: 1, maximumLength: Self.maxReadSize) { [weak self] content, _, isComplete, error in - guard let self else { return } - - Task { - await self.processReceivedData( - connection: connection, - existingBuffer: buffer, - content: content, - isComplete: isComplete, - error: error - ) - } - } - } - - private func processReceivedData( - connection: NWConnection, - existingBuffer: Data, - content: Data?, - isComplete: Bool, - error: NWError? - ) { - if let error { - Self.logger.debug("Read error: \(error.localizedDescription)") - connection.cancel() - return - } - - var buffer = existingBuffer - if let content { - buffer.append(content) - } - - if buffer.count > Self.maxBufferSize { - Self.logger.warning("Request buffer exceeds \(Self.maxBufferSize) bytes, rejecting") - sendHTTPError(connection: connection, status: 413, message: "Request entity too large") - return - } - - let parseResult = MCPHTTPParser.parse(buffer) - - switch parseResult { - case .success(let request): - Task { - await self.handleHTTPRequest(request, connection: connection) - } - - case .failure(.incomplete): - if isComplete { - sendHTTPError(connection: connection, status: 400, message: "Incomplete request") - } else { - readRequest(from: connection, buffer: buffer) - } - - case .failure(.bodyTooLarge): - sendHTTPError(connection: connection, status: 400, message: "Request body too large") - - case .failure(let parseError): - Self.logger.warning("Parse error: \(String(describing: parseError))") - sendHTTPError(connection: connection, status: 400, message: "Malformed HTTP request") - } - } - - static let corsHeaders: [(String, String)] = [ - ("Access-Control-Allow-Origin", "http://localhost"), - ("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS"), - ("Access-Control-Allow-Headers", "Content-Type, Mcp-Session-Id, mcp-protocol-version, Authorization"), - ("Access-Control-Expose-Headers", "Mcp-Session-Id"), - ("Access-Control-Max-Age", "86400") - ] - - private func handleHTTPRequest(_ request: HTTPRequest, connection: NWConnection) async { - let remoteIP: String? = { - guard let endpoint = connection.currentPath?.remoteEndpoint, - case .hostPort(let host, _) = endpoint else { return nil } - return "\(host)" - }() - - guard let router else { - sendHTTPError(connection: connection, status: 503, message: "Server not configured") - return - } - - let routedRequest = request.withRemoteIP(remoteIP) - let result = await router.handle(routedRequest) - - switch result { - case .json(let data, let sessionId): - sendJsonResponse(connection: connection, data: data, sessionId: sessionId) - - case .sseStream(let sessionId): - if let session = sessions[sessionId] { - await session.cancelSSEConnection() - await session.setSSEConnection(connection) - } - sendSseHeaders(connection: connection, sessionId: sessionId) - - case .accepted: - sendResponse(connection: connection, status: 202, headers: Self.corsHeaders, body: nil) - - case .noContent: - sendResponse(connection: connection, status: 204, headers: Self.corsHeaders, body: nil) - - case .httpError(let status, let message): - sendHTTPError(connection: connection, status: status, message: message) - - case .httpErrorWithHeaders(let status, let message, let extraHeaders): - sendHTTPErrorWithHeaders(connection: connection, status: status, message: message, extraHeaders: extraHeaders) - } - } - - func createSession() -> MCPSession? { - guard sessions.count < Self.maxSessions else { - Self.logger.warning("Maximum session limit reached (\(Self.maxSessions))") - return nil - } - - let session = MCPSession() - sessions[session.id] = session - Self.logger.info("Created session \(session.id) (total: \(self.sessions.count))") - return session - } - - func session(for sessionId: String) -> MCPSession? { - sessions[sessionId] - } - - func removeSession(_ sessionId: String) async { - guard let session = sessions.removeValue(forKey: sessionId) else { return } - await session.cancelAllTasks() - await session.cancelSSEConnection() - try? await session.transition(to: .terminated(reason: .removed)) - - if let cleanupHandler = sessionCleanupHandler { - await cleanupHandler(sessionId) - } - - Self.logger.info("Removed session \(sessionId) (total: \(self.sessions.count))") - } - - private func startCleanupTimer() { - cleanupTask?.cancel() - cleanupTask = Task { [weak self] in - while !Task.isCancelled { - try? await Task.sleep(for: .seconds(Self.cleanupInterval)) - guard !Task.isCancelled else { break } - await self?.cleanupIdleSessions() - } - } - } - - private func cleanupIdleSessions() async { - let now = ContinuousClock.now - var removed: [String] = [] - - for (id, session) in sessions { - let lastActivity = await session.lastActivityAt - let idle = now - lastActivity - if idle > .seconds(Self.idleTimeout) { - await session.cancelAllTasks() - await session.cancelSSEConnection() - try? await session.transition(to: .terminated(reason: .idleTimeout)) - sessions.removeValue(forKey: id) - - if let cleanupHandler = sessionCleanupHandler { - await cleanupHandler(id) - } - - removed.append(id) - } - } - - if !removed.isEmpty { - Self.logger.info("Cleaned up \(removed.count) idle session(s)") - } - } - - func sendResponse(connection: NWConnection, status: Int, headers: [(String, String)], body: Data?) { - let statusText = MCPHTTPParser.statusText(for: status) - let responseData = MCPHTTPParser.buildResponse( - status: status, - statusText: statusText, - headers: headers, - body: body - ) - connection.send(content: responseData, completion: .contentProcessed { error in - if let error { - Self.logger.debug("Send error: \(error.localizedDescription)") - } - if status != 200 || headers.contains(where: { $0.0.lowercased() == "connection" && $0.1.lowercased() == "close" }) { - connection.cancel() - } - }) - } - - func sendJsonResponse(connection: NWConnection, data: Data, sessionId: String?) { - var headers: [(String, String)] = [ - ("Content-Type", "application/json"), - ("Connection", "close") - ] - headers.append(contentsOf: Self.corsHeaders) - if let sessionId { - headers.append(("Mcp-Session-Id", sessionId)) - } - sendResponse(connection: connection, status: 200, headers: headers, body: data) - } - - func sendSseHeaders(connection: NWConnection, sessionId: String) { - let headerData = MCPHTTPParser.buildSSEHeaders( - sessionId: sessionId, - corsHeaders: Self.corsHeaders - ) - connection.send(content: headerData, completion: .contentProcessed { error in - if let error { - Self.logger.debug("SSE header send error: \(error.localizedDescription)") - } - }) - } - - func sendSseEvent(connection: NWConnection, data: Data, eventId: String? = nil) { - let eventData = MCPHTTPParser.buildSSEEvent(data: data, id: eventId) - connection.send(content: eventData, completion: .contentProcessed { error in - if let error { - Self.logger.debug("SSE event send error: \(error.localizedDescription)") - } - }) - } - - func sendHTTPError(connection: NWConnection, status: Int, message: String) { - let body: [String: String] = ["error": message] - let data = (try? JSONEncoder().encode(body)) ?? Data() - var headers: [(String, String)] = [ - ("Content-Type", "application/json"), - ("Connection", "close") - ] - headers.append(contentsOf: Self.corsHeaders) - sendResponse(connection: connection, status: status, headers: headers, body: data) - } - - func sendHTTPErrorWithHeaders(connection: NWConnection, status: Int, message: String, extraHeaders: [(String, String)]) { - let body: [String: String] = ["error": message] - let data = (try? JSONEncoder().encode(body)) ?? Data() - var headers: [(String, String)] = [ - ("Content-Type", "application/json"), - ("Connection", "close") - ] - headers.append(contentsOf: extraHeaders) - headers.append(contentsOf: Self.corsHeaders) - sendResponse(connection: connection, status: status, headers: headers, body: data) - } -} diff --git a/TablePro/Core/MCP/MCPServerManager.swift b/TablePro/Core/MCP/MCPServerManager.swift index e0616748b..3a6b90516 100644 --- a/TablePro/Core/MCP/MCPServerManager.swift +++ b/TablePro/Core/MCP/MCPServerManager.swift @@ -11,19 +11,37 @@ enum MCPServerState: Sendable, Equatable { @MainActor @Observable final class MCPServerManager { + struct SessionSnapshot: Sendable, Identifiable { + let id: String + let clientName: String + let clientVersion: String? + let connectedSince: Date + let lastActivityAt: Date + let tokenName: String? + let remoteAddress: String? + } + private static let logger = Logger(subsystem: "com.TablePro", category: "MCPServerManager") static let shared = MCPServerManager() private(set) var state: MCPServerState = .stopped - private(set) var connectedClients: [MCPServer.SessionSnapshot] = [] - private var server: MCPServer? - private var clientRefreshTask: Task? - private var serverGeneration: Int = 0 + private(set) var connectedClients: [SessionSnapshot] = [] private(set) var tokenStore: MCPTokenStore? + + private var transport: MCPHttpServerTransport? + private var dispatcher: MCPProtocolDispatcher? + private var sessionStore: MCPSessionStore? + private var rateLimiter: MCPRateLimiter? + private var dispatchTask: Task? + private var stateTask: Task? + private var sessionEventsTask: Task? + private var clientRefreshTask: Task? private var tlsManager: MCPTLSManager? private var bridgeTokenId: UUID? private var internalBridgeToken: String? + private var serverGeneration: Int = 0 + private var revocationObserverId: UUID? var isRunning: Bool { if case .running = state { return true } else { return false } @@ -31,109 +49,115 @@ final class MCPServerManager { var connectedClientCount: Int { get async { - guard let server else { return 0 } - return await server.sessionCount + guard let sessionStore else { return 0 } + return await sessionStore.count() } } private init() {} func start(port: UInt16) async { - if server != nil { + if transport != nil { await stop() } + Self.removeStaleHandshakeFileIfNeeded() + serverGeneration += 1 let generation = serverGeneration - let newServer = MCPServer { [weak self] newState in - Task { @MainActor in - guard let self, self.serverGeneration == generation else { return } - self.state = newState - } - } - - self.server = newServer + state = .starting let newTokenStore = MCPTokenStore() await newTokenStore.loadFromDisk() - self.tokenStore = newTokenStore + tokenStore = newTokenStore - let rateLimiter = MCPRateLimiter() + let bridgeResult = await newTokenStore.generate( + name: MCPTokenStore.stdioBridgeTokenName, + permissions: .fullAccess + ) + bridgeTokenId = bridgeResult.token.id + internalBridgeToken = bridgeResult.plaintext - let bridge = MCPConnectionBridge() - let authPolicy = MCPAuthPolicy() - let toolHandler = MCPToolHandler(bridge: bridge, authPolicy: authPolicy) - let resourceHandler = MCPResourceHandler(bridge: bridge, authPolicy: authPolicy) + let settings = AppSettingsManager.shared.mcp + let configuration: MCPHttpServerConfiguration + do { + configuration = try await makeConfiguration(port: port, settings: settings) + } catch { + Self.logger.error("MCP TLS configuration failed: \(error.localizedDescription, privacy: .public)") + state = .failed("TLS certificate generation failed") + await cleanupBridgeToken() + tokenStore = nil + return + } - await newServer.setTokenStore(newTokenStore) - await newServer.setRateLimiter(rateLimiter) + let newSessionStore = MCPSessionStore(policy: .standard) + await newSessionStore.startCleanup() + sessionStore = newSessionStore - await newServer.setToolCallHandler { name, arguments, sessionId, token in - try await toolHandler.handleToolCall(name: name, arguments: arguments, sessionId: sessionId, token: token) - } - await newServer.setResourceReadHandler { uri, sessionId in - try await resourceHandler.handleResourceRead(uri: uri, sessionId: sessionId) - } - await newServer.setSessionCleanupHandler { sessionId in - await authPolicy.clearSession(sessionId) - } + let newRateLimiter = MCPRateLimiter() + rateLimiter = newRateLimiter - let protocolHandler = MCPProtocolHandler( - server: newServer, + let authenticator = MCPBearerTokenAuthenticator( tokenStore: newTokenStore, - rateLimiter: rateLimiter + rateLimiter: newRateLimiter ) - let exchangeHandler = IntegrationsExchangeHandler.live() - let router = MCPRouter(routes: [protocolHandler, exchangeHandler]) - await newServer.setRouter(router) - let bridgeResult = await newTokenStore.generate( - name: MCPTokenStore.stdioBridgeTokenName, - permissions: .fullAccess + let newTransport = MCPHttpServerTransport( + configuration: configuration, + sessionStore: newSessionStore, + authenticator: authenticator ) - self.bridgeTokenId = bridgeResult.token.id - self.internalBridgeToken = bridgeResult.plaintext + transport = newTransport - do { - let settings = AppSettingsManager.shared.mcp - - var tlsIdentity: SecIdentity? - if settings.allowRemoteConnections { - let manager = MCPTLSManager() - self.tlsManager = manager - do { - tlsIdentity = try await manager.loadOrGenerate() - } catch { - Self.logger.error("Failed to generate TLS certificate: \(error.localizedDescription)") - state = .failed("TLS certificate generation failed") - return - } - } + let progressSink = TransportProgressSink(transport: newTransport) + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) - try await newServer.start( - port: port, - allowRemoteAccess: settings.allowRemoteConnections, - tlsIdentity: tlsIdentity - ) - let certFingerprint = await tlsManager?.fingerprint - writeHandshakeFile(port: port, tlsCertFingerprint: certFingerprint) + let handlers: [any MCPMethodHandler] = [ + InitializeHandler(), + PingHandler(), + ToolsListHandler(), + ToolsCallHandler(services: services), + ResourcesListHandler(services: services), + ResourcesReadHandler(services: services), + ResourcesTemplatesListHandler(), + PromptsListHandler(), + PromptsGetHandler(), + LoggingSetLevelHandler(), + CompletionCompleteHandler() + ] + + let newDispatcher = MCPProtocolDispatcher( + handlers: handlers, + sessionStore: newSessionStore, + progressSink: progressSink + ) + dispatcher = newDispatcher + + startDispatchLoop(transport: newTransport, dispatcher: newDispatcher, generation: generation) + startStateLoop(transport: newTransport, generation: generation) + startSessionEventsLoop(sessionStore: newSessionStore, generation: generation) + await registerRevocationObserver( + tokenStore: newTokenStore, + sessionStore: newSessionStore, + dispatcher: newDispatcher, + generation: generation + ) + + do { + try await newTransport.start() startClientRefresh() MCPAuditLogger.logServerStarted( port: port, remoteAccess: settings.allowRemoteConnections, - tlsEnabled: tlsIdentity != nil + tlsEnabled: configuration.tls != nil ) } catch { - Self.logger.error("Failed to start MCP server: \(error.localizedDescription)") + Self.logger.error("Failed to start MCP server: \(error.localizedDescription, privacy: .public)") state = .failed(error.localizedDescription) - if let bridgeId = bridgeTokenId { - await tokenStore?.delete(tokenId: bridgeId) - bridgeTokenId = nil - } - server = nil - self.tokenStore = nil - self.tlsManager = nil - self.internalBridgeToken = nil + await teardown() } } @@ -141,16 +165,7 @@ final class MCPServerManager { stopClientRefresh() deleteHandshakeFile() MCPAuditLogger.logServerStopped() - guard let server else { return } - await server.stop() - if let bridgeId = bridgeTokenId { - await tokenStore?.delete(tokenId: bridgeId) - bridgeTokenId = nil - } - self.server = nil - self.tokenStore = nil - self.tlsManager = nil - self.internalBridgeToken = nil + await teardown() state = .stopped } @@ -173,7 +188,7 @@ final class MCPServerManager { do { chosenPort = try MCPPortAllocator.findFreePort(in: 51_000...52_000) } catch { - Self.logger.error("Lazy start failed to allocate port: \(error.localizedDescription)") + Self.logger.error("Lazy start failed to allocate port: \(error.localizedDescription, privacy: .public)") state = .failed(error.localizedDescription) return } @@ -183,10 +198,166 @@ final class MCPServerManager { } func disconnectClient(_ sessionId: String) async { - await server?.removeSession(sessionId) + guard let sessionStore else { return } + await sessionStore.terminate(id: MCPSessionId(sessionId), reason: .clientRequested) await refreshClients() } + private func makeConfiguration( + port: UInt16, + settings: MCPSettings + ) async throws -> MCPHttpServerConfiguration { + if settings.allowRemoteConnections { + let manager = MCPTLSManager() + tlsManager = manager + let identity = try await manager.loadOrGenerate() + let tls = MCPTLSConfiguration(identity: identity) + return .remote(port: port, tls: tls) + } + return .loopback(port: port) + } + + private func startDispatchLoop( + transport: MCPHttpServerTransport, + dispatcher: MCPProtocolDispatcher, + generation: Int + ) { + dispatchTask?.cancel() + dispatchTask = Task { [weak self] in + for await exchange in transport.exchanges { + guard let self else { return } + guard await self.isCurrentGeneration(generation) else { return } + Task { await dispatcher.dispatch(exchange) } + } + } + } + + private func startStateLoop(transport: MCPHttpServerTransport, generation: Int) { + stateTask?.cancel() + stateTask = Task { [weak self] in + for await transportState in transport.listenerState { + guard let self else { return } + await self.applyTransportState(transportState, generation: generation) + } + } + } + + private func startSessionEventsLoop(sessionStore: MCPSessionStore, generation: Int) { + sessionEventsTask?.cancel() + sessionEventsTask = Task { [weak self] in + let stream = await sessionStore.events + for await event in stream { + guard let self else { return } + guard await self.isCurrentGeneration(generation) else { return } + Self.logger.debug("Session event: \(String(describing: event), privacy: .public)") + await self.refreshClients() + } + } + } + + private func isCurrentGeneration(_ generation: Int) -> Bool { + serverGeneration == generation + } + + private func registerRevocationObserver( + tokenStore: MCPTokenStore, + sessionStore: MCPSessionStore, + dispatcher: MCPProtocolDispatcher, + generation: Int + ) async { + let observerId = await tokenStore.addRevocationObserver { [weak self] tokenIdString in + guard let tokenId = UUID(uuidString: tokenIdString) else { return } + guard let self else { return } + await self.handleTokenRevoked( + tokenId: tokenId, + sessionStore: sessionStore, + dispatcher: dispatcher, + generation: generation + ) + } + revocationObserverId = observerId + } + + private func handleTokenRevoked( + tokenId: UUID, + sessionStore: MCPSessionStore, + dispatcher: MCPProtocolDispatcher, + generation: Int + ) async { + guard isCurrentGeneration(generation) else { return } + let cancelledSessions = await dispatcher.cancelInflight(matchingTokenId: tokenId) + let extraSessions = await sessionStore.sessionIds(forPrincipalTokenId: tokenId) + let toTerminate = Set(cancelledSessions + extraSessions) + for sessionId in toTerminate { + await sessionStore.terminate(id: sessionId, reason: .tokenRevoked) + } + if !toTerminate.isEmpty { + Self.logger.info( + "Token \(tokenId.uuidString, privacy: .public) revoked: cancelled \(toTerminate.count, privacy: .public) session(s)" + ) + } + } + + private func applyTransportState(_ transportState: MCPHttpServerState, generation: Int) { + guard isCurrentGeneration(generation) else { return } + switch transportState { + case .idle: + state = .stopped + case .starting: + state = .starting + case .running(let port): + state = .running(port: port) + Task { [weak self] in + guard let self else { return } + let fingerprint = await self.tlsManager?.fingerprint + self.writeHandshakeFile(port: port, tlsCertFingerprint: fingerprint) + } + case .stopped: + state = .stopped + case .failed(let reason): + state = .failed(reason) + } + } + + private func teardown() async { + dispatchTask?.cancel() + dispatchTask = nil + stateTask?.cancel() + stateTask = nil + sessionEventsTask?.cancel() + sessionEventsTask = nil + + if let transport { + await transport.stop() + } + transport = nil + + if let sessionStore { + await sessionStore.shutdown(reason: .serverShutdown) + } + sessionStore = nil + + dispatcher = nil + rateLimiter = nil + tlsManager = nil + + if let observerId = revocationObserverId, let store = tokenStore { + await store.removeRevocationObserver(observerId) + revocationObserverId = nil + } + await cleanupBridgeToken() + tokenStore = nil + connectedClients = [] + } + + private func cleanupBridgeToken() async { + if let bridgeId = bridgeTokenId { + await tokenStore?.delete(tokenId: bridgeId) + bridgeTokenId = nil + } + internalBridgeToken = nil + } + private func startClientRefresh() { clientRefreshTask = Task { [weak self] in while !Task.isCancelled { @@ -203,11 +374,16 @@ final class MCPServerManager { } private func refreshClients() async { - guard let server else { + guard let sessionStore else { connectedClients = [] return } - connectedClients = await server.sessionSnapshots() + let snapshots = await collectSessionSnapshots(from: sessionStore) + connectedClients = snapshots + } + + private func collectSessionSnapshots(from store: MCPSessionStore) async -> [SessionSnapshot] { + await store.snapshotsForUI() } private static let handshakeDirectoryPath: String = { @@ -219,20 +395,27 @@ final class MCPServerManager { "\(handshakeDirectoryPath)/mcp-handshake.json" }() + private struct HandshakeFilePayload: Codable { + let port: Int + let token: String + let pid: Int32 + let protocolVersion: String + let tls: Bool + let tlsCertFingerprint: String? + } + private func writeHandshakeFile(port: UInt16, tlsCertFingerprint: String? = nil) { guard let bridgeToken = internalBridgeToken else { return } let settings = AppSettingsManager.shared.mcp - var handshake: [String: Any] = [ - "port": Int(port), - "token": bridgeToken, - "pid": ProcessInfo.processInfo.processIdentifier, - "protocolVersion": "2025-03-26", - "tls": settings.allowRemoteConnections - ] - if let tlsCertFingerprint { - handshake["tlsCertFingerprint"] = tlsCertFingerprint - } + let payload = HandshakeFilePayload( + port: Int(port), + token: bridgeToken, + pid: ProcessInfo.processInfo.processIdentifier, + protocolVersion: InitializeHandler.supportedProtocolVersion, + tls: settings.allowRemoteConnections, + tlsCertFingerprint: tlsCertFingerprint + ) let fileManager = FileManager.default let directory = Self.handshakeDirectoryPath @@ -249,7 +432,9 @@ final class MCPServerManager { ) } - let data = try JSONSerialization.data(withJSONObject: handshake, options: [.sortedKeys]) + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + let data = try encoder.encode(payload) let url = URL(fileURLWithPath: Self.handshakeFilePath) try data.write(to: url, options: [.atomic]) try fileManager.setAttributes( @@ -257,12 +442,24 @@ final class MCPServerManager { ofItemAtPath: Self.handshakeFilePath ) - Self.logger.info("Wrote MCP handshake file at \(Self.handshakeFilePath)") + Self.logger.info("Wrote MCP handshake file at \(Self.handshakeFilePath, privacy: .public)") } catch { - Self.logger.error("Failed to write MCP handshake file: \(error.localizedDescription)") + Self.logger.error("Failed to write MCP handshake file: \(error.localizedDescription, privacy: .public)") } } + private static func removeStaleHandshakeFileIfNeeded() { + let path = handshakeFilePath + guard FileManager.default.fileExists(atPath: path) else { return } + guard let data = try? Data(contentsOf: URL(fileURLWithPath: path)) else { return } + guard let payload = try? JSONDecoder().decode(HandshakeFilePayload.self, from: data) else { return } + let currentPid = ProcessInfo.processInfo.processIdentifier + if payload.pid == currentPid { return } + if kill(payload.pid, 0) == 0 { return } + try? FileManager.default.removeItem(atPath: path) + Self.logger.info("Removed stale MCP handshake from PID \(payload.pid, privacy: .public)") + } + private func deleteHandshakeFile() { let fileManager = FileManager.default guard fileManager.fileExists(atPath: Self.handshakeFilePath) else { return } @@ -271,7 +468,35 @@ final class MCPServerManager { try fileManager.removeItem(atPath: Self.handshakeFilePath) Self.logger.info("Deleted MCP handshake file") } catch { - Self.logger.error("Failed to delete MCP handshake file: \(error.localizedDescription)") + Self.logger.error("Failed to delete MCP handshake file: \(error.localizedDescription, privacy: .public)") + } + } +} + +private struct TransportProgressSink: MCPProgressSink { + let transport: MCPHttpServerTransport + + func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async { + await transport.sendNotification(notification, toSession: sessionId) + } +} + +private extension MCPSessionStore { + func snapshotsForUI() async -> [MCPServerManager.SessionSnapshot] { + var result: [MCPServerManager.SessionSnapshot] = [] + for session in await allSessions() { + let snapshot = await session.snapshot() + let info = snapshot.clientInfo + result.append(MCPServerManager.SessionSnapshot( + id: snapshot.id.rawValue, + clientName: info?.name ?? String(localized: "Unknown"), + clientVersion: info?.version, + connectedSince: snapshot.createdAt, + lastActivityAt: snapshot.lastActivityAt, + tokenName: nil, + remoteAddress: nil + )) } + return result } } diff --git a/TablePro/Core/MCP/MCPSession.swift b/TablePro/Core/MCP/MCPSession.swift deleted file mode 100644 index a52295cad..000000000 --- a/TablePro/Core/MCP/MCPSession.swift +++ /dev/null @@ -1,95 +0,0 @@ -import Foundation -import Network - -actor MCPSession { - let id: String - let createdAt: ContinuousClock.Instant - - var lastActivityAt: ContinuousClock.Instant - private(set) var phase: MCPSessionPhase = .created - var clientInfo: MCPClientInfo? - var sseConnection: NWConnection? - var runningTasks: [JSONRPCId: Task] = [:] - private(set) var eventCounter: Int = 0 - private(set) var remoteAddress: String? - - var authenticatedTokenId: UUID? { - if case .active(let tokenId, _) = phase { return tokenId } - return nil - } - - var tokenName: String? { - if case .active(_, let tokenName) = phase { return tokenName } - return nil - } - - init() { - self.id = UUID().uuidString - let now = ContinuousClock.now - self.createdAt = now - self.lastActivityAt = now - } - - func markActive() { - lastActivityAt = .now - } - - func cancelAllTasks() { - for (_, task) in runningTasks { - task.cancel() - } - runningTasks.removeAll() - } - - func transition(to next: MCPSessionPhase) throws { - guard isValidTransition(from: phase, to: next) else { - throw MCPError.invalidRequest( - "Invalid session phase transition from \(phase) to \(next)" - ) - } - phase = next - } - - private func isValidTransition(from current: MCPSessionPhase, to next: MCPSessionPhase) -> Bool { - switch (current, next) { - case (.created, .initializing), - (.created, .active), - (.created, .terminated), - (.initializing, .active), - (.initializing, .terminated), - (.active, .terminated): - return true - default: - return false - } - } - - func setClientInfo(_ info: MCPClientInfo?) { - clientInfo = info - } - - func setRemoteAddress(_ address: String?) { - remoteAddress = address - } - - func setSSEConnection(_ connection: NWConnection?) { - sseConnection = connection - } - - func cancelSSEConnection() { - sseConnection?.cancel() - } - - func addRunningTask(_ id: JSONRPCId, task: Task) { - runningTasks[id] = task - } - - func removeRunningTask(_ id: JSONRPCId) -> Task? { - runningTasks.removeValue(forKey: id) - } - - func nextEventId() -> String { - eventCounter += 1 - return String(eventCounter) - } -} diff --git a/TablePro/Core/MCP/MCPSessionPhase.swift b/TablePro/Core/MCP/MCPSessionPhase.swift deleted file mode 100644 index fa502c9f5..000000000 --- a/TablePro/Core/MCP/MCPSessionPhase.swift +++ /dev/null @@ -1,20 +0,0 @@ -import Foundation - -enum MCPSessionTerminationReason: Sendable, Equatable { - case removed - case idleTimeout - case serverStopped - case clientDisconnected -} - -enum MCPSessionPhase: Sendable, Equatable { - case created - case initializing - case active(tokenId: UUID?, tokenName: String?) - case terminated(reason: MCPSessionTerminationReason) - - var isActive: Bool { - if case .active = self { return true } - return false - } -} diff --git a/TablePro/Core/MCP/MCPTLSManager.swift b/TablePro/Core/MCP/MCPTLSManager.swift index f876fca1a..f0fcf74c6 100644 --- a/TablePro/Core/MCP/MCPTLSManager.swift +++ b/TablePro/Core/MCP/MCPTLSManager.swift @@ -1,8 +1,3 @@ -// -// MCPTLSManager.swift -// TablePro -// - import CryptoKit import Foundation import os diff --git a/TablePro/Core/MCP/MCPTokenStore.swift b/TablePro/Core/MCP/MCPTokenStore.swift index dc71eaab8..f1b848848 100644 --- a/TablePro/Core/MCP/MCPTokenStore.swift +++ b/TablePro/Core/MCP/MCPTokenStore.swift @@ -76,7 +76,6 @@ struct MCPAuthToken: Codable, Identifiable, Sendable { case salt case permissions case connectionAccess - case allowedConnectionIds case createdAt case lastUsedAt case expiresAt @@ -95,14 +94,7 @@ struct MCPAuthToken: Codable, Identifiable, Sendable { self.lastUsedAt = try container.decodeIfPresent(Date.self, forKey: .lastUsedAt) self.expiresAt = try container.decodeIfPresent(Date.self, forKey: .expiresAt) self.isActive = try container.decode(Bool.self, forKey: .isActive) - - if let access = try container.decodeIfPresent(ConnectionAccess.self, forKey: .connectionAccess) { - self.connectionAccess = access - } else if let legacyIds = try container.decodeIfPresent(Set.self, forKey: .allowedConnectionIds) { - self.connectionAccess = .limited(legacyIds) - } else { - self.connectionAccess = .all - } + self.connectionAccess = try container.decodeIfPresent(ConnectionAccess.self, forKey: .connectionAccess) ?? .all } func encode(to encoder: Encoder) throws { @@ -161,6 +153,8 @@ actor MCPTokenStore { private var lastSavedAt: ContinuousClock.Instant = .now private static let saveCooldown: Duration = .seconds(60) + private var revocationObservers: [UUID: @Sendable (String) async -> Void] = [:] + init() { let appSupportUrl = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first ?? URL(fileURLWithPath: NSHomeDirectory()).appendingPathComponent("Library/Application Support") @@ -168,6 +162,17 @@ actor MCPTokenStore { self.storageUrl = directory.appendingPathComponent("mcp-tokens.json") } + @discardableResult + func addRevocationObserver(_ handler: @escaping @Sendable (String) async -> Void) -> UUID { + let id = UUID() + revocationObservers[id] = handler + return id + } + + func removeRevocationObserver(_ id: UUID) { + revocationObservers.removeValue(forKey: id) + } + func generate( name: String, permissions: TokenPermissions, @@ -234,6 +239,7 @@ actor MCPTokenStore { tokens[index].isActive = false save() + notifyRevocationObservers(tokenId: tokenId) let revokedName = tokens[index].name Self.logger.info("Revoked MCP token '\(revokedName, privacy: .public)'") @@ -249,10 +255,19 @@ actor MCPTokenStore { let name = tokens[index].name tokens.remove(at: index) save() + notifyRevocationObservers(tokenId: tokenId) Self.logger.info("Deleted MCP token '\(name, privacy: .public)'") } + private func notifyRevocationObservers(tokenId: UUID) { + let observers = Array(revocationObservers.values) + let key = tokenId.uuidString + for observer in observers { + Task { await observer(key) } + } + } + func list() -> [MCPAuthToken] { tokens } diff --git a/TablePro/Core/MCP/MCPToolHandler+Integrations.swift b/TablePro/Core/MCP/MCPToolHandler+Integrations.swift deleted file mode 100644 index df53c0431..000000000 --- a/TablePro/Core/MCP/MCPToolHandler+Integrations.swift +++ /dev/null @@ -1,334 +0,0 @@ -// -// MCPToolHandler+Integrations.swift -// TablePro -// - -import AppKit -import Foundation - -extension MCPToolHandler { - func handleListRecentTabs(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let limit = optionalInt(args, key: "limit", default: 20, clamp: 1...500) - - if let token, !token.permissions.satisfies(.readOnly) { - throw MCPError.forbidden( - "Token '\(token.name)' with permission '\(token.permissions.displayName)' cannot access 'list_recent_tabs'" - ) - } - - let snapshots = await MainActor.run { Self.collectTabSnapshots() } - let blockedConnectionIds = await MainActor.run { Self.blockedExternalConnectionIds() } - let access = token?.connectionAccess ?? .all - let filtered = snapshots.filter { snapshot in - guard !blockedConnectionIds.contains(snapshot.connectionId) else { return false } - return access.allows(snapshot.connectionId) - } - - let trimmed = Array(filtered.prefix(limit)) - let payload = trimmed.map { snapshot -> JSONValue in - var dict: [String: JSONValue] = [ - "connection_id": .string(snapshot.connectionId.uuidString), - "connection_name": .string(snapshot.connectionName), - "tab_id": .string(snapshot.tabId.uuidString), - "tab_type": .string(snapshot.tabType), - "display_title": .string(snapshot.displayTitle), - "is_active": .bool(snapshot.isActive) - ] - if let table = snapshot.tableName { - dict["table_name"] = .string(table) - } - if let database = snapshot.databaseName { - dict["database_name"] = .string(database) - } - if let schema = snapshot.schemaName { - dict["schema_name"] = .string(schema) - } - if let windowId = snapshot.windowId { - dict["window_id"] = .string(windowId.uuidString) - } - return .object(dict) - } - - return MCPToolResult(content: [.text(encodeJSON(.object(["tabs": .array(payload)])))], isError: nil) - } - - func handleSearchQueryHistory(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let query = try requireString(args, key: "query") - let connectionIdString = optionalString(args, key: "connection_id") - let limit = optionalInt(args, key: "limit", default: 50, clamp: 1...500) - let since = args?["since"]?.doubleValue.map { Date(timeIntervalSince1970: $0) } - let until = args?["until"]?.doubleValue.map { Date(timeIntervalSince1970: $0) } - - if let since, let until, since > until { - throw MCPError.invalidParams("'since' must be less than or equal to 'until'") - } - - if let token, !token.permissions.satisfies(.readOnly) { - throw MCPError.forbidden( - "Token '\(token.name)' with permission '\(token.permissions.displayName)' cannot access 'search_query_history'" - ) - } - - let blockedConnectionIds = await MainActor.run { Self.blockedExternalConnectionIds() } - - let connectionId: UUID? - if let connectionIdString { - guard let parsed = UUID(uuidString: connectionIdString) else { - throw MCPError.invalidParams("Invalid UUID for parameter: connection_id") - } - if let token, !token.connectionAccess.allows(parsed) { - throw MCPError.forbidden("Token does not have access to this connection") - } - if blockedConnectionIds.contains(parsed) { - throw MCPError.forbidden( - String(localized: "External access is disabled for this connection") - ) - } - connectionId = parsed - } else { - connectionId = nil - } - - let tokenScopedAllowlist = await resolveHistoryAllowlist( - token: token, - scopedConnectionId: connectionId, - blockedConnectionIds: blockedConnectionIds - ) - - let entries = await QueryHistoryStorage.shared.fetchHistory( - limit: limit, - offset: 0, - connectionId: connectionId, - searchText: query.isEmpty ? nil : query, - dateFilter: .all, - since: since, - until: until, - allowedConnectionIds: tokenScopedAllowlist - ) - - let payload = entries.map { entry -> JSONValue in - var dict: [String: JSONValue] = [ - "id": .string(entry.id.uuidString), - "query": .string(entry.query), - "connection_id": .string(entry.connectionId.uuidString), - "database_name": .string(entry.databaseName), - "executed_at": .double(entry.executedAt.timeIntervalSince1970), - "execution_time_ms": .double(entry.executionTime * 1_000), - "row_count": .int(entry.rowCount), - "was_successful": .bool(entry.wasSuccessful) - ] - if let error = entry.errorMessage { - dict["error_message"] = .string(error) - } - return .object(dict) - } - - return MCPToolResult(content: [.text(encodeJSON(.object(["entries": .array(payload)])))], isError: nil) - } - - func handleOpenConnectionWindow(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - try await ensureConnectionExists(connectionId) - try await authPolicy.resolveAndAuthorize( - token: token ?? Self.anonymousFullAccessToken, - tool: "open_connection_window", - connectionId: connectionId, - sessionId: sessionId - ) - - let windowId = await MainActor.run { () -> UUID in - let payload = EditorTabPayload( - connectionId: connectionId, - tabType: .query, - intent: .restoreOrDefault - ) - WindowManager.shared.openTab(payload: payload) - NSApp.activate(ignoringOtherApps: true) - return payload.id - } - - let result: JSONValue = .object([ - "status": "opened", - "connection_id": .string(connectionId.uuidString), - "window_id": .string(windowId.uuidString) - ]) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - func handleOpenTableTab(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let tableName = try requireString(args, key: "table_name") - let databaseName = optionalString(args, key: "database_name") - let schemaName = optionalString(args, key: "schema_name") - - try await ensureConnectionExists(connectionId) - try await authPolicy.resolveAndAuthorize( - token: token ?? Self.anonymousFullAccessToken, - tool: "open_table_tab", - connectionId: connectionId, - sessionId: sessionId - ) - - let windowId = await MainActor.run { () -> UUID in - let payload = EditorTabPayload( - connectionId: connectionId, - tabType: .table, - tableName: tableName, - databaseName: databaseName, - schemaName: schemaName, - intent: .openContent - ) - WindowManager.shared.openTab(payload: payload) - NSApp.activate(ignoringOtherApps: true) - return payload.id - } - - let result: JSONValue = .object([ - "status": "opened", - "connection_id": .string(connectionId.uuidString), - "table_name": .string(tableName), - "window_id": .string(windowId.uuidString) - ]) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - func handleFocusQueryTab(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let tabId = try requireUUID(args, key: "tab_id") - - let resolved = await MainActor.run { () -> (hasWindow: Bool, windowId: UUID?, connectionId: UUID?)? in - for snapshot in Self.collectTabSnapshots() where snapshot.tabId == tabId { - return (snapshot.window != nil, snapshot.windowId, snapshot.connectionId) - } - return nil - } - - guard let resolved, resolved.hasWindow else { - throw MCPError.notFound("tab") - } - - guard let connectionId = resolved.connectionId else { - throw MCPError.notFound("connection") - } - try await authPolicy.resolveAndAuthorize( - token: token ?? Self.anonymousFullAccessToken, - tool: "focus_query_tab", - connectionId: connectionId, - sessionId: sessionId - ) - - let raised = await MainActor.run { () -> Bool in - for snapshot in Self.collectTabSnapshots() where snapshot.tabId == tabId { - guard snapshot.connectionId == connectionId else { return false } - guard let window = snapshot.window else { return false } - NSApp.activate(ignoringOtherApps: true) - window.makeKeyAndOrderFront(nil) - return true - } - return false - } - - guard raised else { - throw MCPError.notFound("tab") - } - - var dict: [String: JSONValue] = [ - "status": "focused", - "tab_id": .string(tabId.uuidString), - "connection_id": .string(connectionId.uuidString) - ] - if let windowId = resolved.windowId { - dict["window_id"] = .string(windowId.uuidString) - } - - return MCPToolResult(content: [.text(encodeJSON(.object(dict)))], isError: nil) - } - - private func resolveHistoryAllowlist( - token: MCPAuthToken?, - scopedConnectionId: UUID?, - blockedConnectionIds: Set - ) async -> Set? { - if scopedConnectionId != nil { - return nil - } - if let access = token?.connectionAccess, case .limited(let allowed) = access { - return allowed.subtracting(blockedConnectionIds) - } - guard !blockedConnectionIds.isEmpty else { return nil } - let allConnectionIds = await MainActor.run { - Set(ConnectionStorage.shared.loadConnections().map(\.id)) - } - return allConnectionIds.subtracting(blockedConnectionIds) - } - - private func ensureConnectionExists(_ connectionId: UUID) async throws { - let exists = await MainActor.run { - ConnectionStorage.shared.loadConnections().contains { $0.id == connectionId } - } - guard exists else { - throw MCPError.notFound("connection") - } - } - - @MainActor - static func collectTabSnapshots() -> [TabSnapshot] { - let connections = ConnectionStorage.shared.loadConnections() - let connectionsById = Dictionary(uniqueKeysWithValues: connections.map { ($0.id, $0) }) - - var snapshots: [TabSnapshot] = [] - for coordinator in MainContentCoordinator.allActiveCoordinators() { - let connectionName = connectionsById[coordinator.connectionId]?.name - ?? coordinator.connection.name - let selectedId = coordinator.tabManager.selectedTabId - for tab in coordinator.tabManager.tabs { - snapshots.append(TabSnapshot( - tabId: tab.id, - connectionId: coordinator.connectionId, - connectionName: connectionName, - tabType: tab.tabType.snapshotName, - tableName: tab.tableContext.tableName, - databaseName: tab.tableContext.databaseName, - schemaName: tab.tableContext.schemaName, - displayTitle: tab.title, - windowId: coordinator.windowId, - isActive: tab.id == selectedId, - window: coordinator.contentWindow - )) - } - } - return snapshots - } - - @MainActor - static func blockedExternalConnectionIds() -> Set { - let connections = ConnectionStorage.shared.loadConnections() - return Set(connections.filter { $0.externalAccess == .blocked }.map(\.id)) - } -} - -struct TabSnapshot { - let tabId: UUID - let connectionId: UUID - let connectionName: String - let tabType: String - let tableName: String? - let databaseName: String? - let schemaName: String? - let displayTitle: String - let windowId: UUID? - let isActive: Bool - weak var window: NSWindow? -} - -private extension TabType { - var snapshotName: String { - switch self { - case .query: "query" - case .table: "table" - case .createTable: "createTable" - case .erDiagram: "erDiagram" - case .serverDashboard: "serverDashboard" - case .terminal: "terminal" - } - } -} diff --git a/TablePro/Core/MCP/MCPToolHandler.swift b/TablePro/Core/MCP/MCPToolHandler.swift deleted file mode 100644 index 045c8ac0a..000000000 --- a/TablePro/Core/MCP/MCPToolHandler.swift +++ /dev/null @@ -1,796 +0,0 @@ -import Foundation -import os - -final class MCPToolHandler: Sendable { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPToolHandler") - - let bridge: MCPConnectionBridge - let authPolicy: MCPAuthPolicy - - init(bridge: MCPConnectionBridge, authPolicy: MCPAuthPolicy) { - self.bridge = bridge - self.authPolicy = authPolicy - } - - func handleToolCall( - name: String, - arguments: JSONValue?, - sessionId: String, - token: MCPAuthToken? = nil - ) async throws -> MCPToolResult { - do { - let result = try await dispatchTool( - name: name, - arguments: arguments, - sessionId: sessionId, - token: token - ) - logToolOutcome(name: name, token: token, arguments: arguments, outcome: .success, error: nil) - return result - } catch let error as MCPError { - let outcome: AuditOutcome - if case .forbidden = error { - outcome = .denied - } else { - outcome = .error - } - logToolOutcome(name: name, token: token, arguments: arguments, outcome: outcome, error: error.message) - throw error - } catch { - logToolOutcome(name: name, token: token, arguments: arguments, outcome: .error, error: error.localizedDescription) - throw error - } - } - - private func dispatchTool( - name: String, - arguments: JSONValue?, - sessionId: String, - token: MCPAuthToken? - ) async throws -> MCPToolResult { - switch name { - case "list_connections": - return try await handleListConnections(token: token) - case "connect": - return try await handleConnect(arguments, sessionId: sessionId, token: token) - case "disconnect": - return try await handleDisconnect(arguments, sessionId: sessionId, token: token) - case "get_connection_status": - return try await handleGetConnectionStatus(arguments, sessionId: sessionId, token: token) - case "execute_query": - return try await handleExecuteQuery(arguments, sessionId: sessionId, token: token) - case "list_tables": - return try await handleListTables(arguments, sessionId: sessionId, token: token) - case "describe_table": - return try await handleDescribeTable(arguments, sessionId: sessionId, token: token) - case "list_databases": - return try await handleListDatabases(arguments, sessionId: sessionId, token: token) - case "list_schemas": - return try await handleListSchemas(arguments, sessionId: sessionId, token: token) - case "get_table_ddl": - return try await handleGetTableDDL(arguments, sessionId: sessionId, token: token) - case "export_data": - return try await handleExportData(arguments, sessionId: sessionId, token: token) - case "confirm_destructive_operation": - return try await handleConfirmDestructiveOperation(arguments, sessionId: sessionId, token: token) - case "switch_database": - return try await handleSwitchDatabase(arguments, sessionId: sessionId, token: token) - case "switch_schema": - return try await handleSwitchSchema(arguments, sessionId: sessionId, token: token) - case "list_recent_tabs": - return try await handleListRecentTabs(arguments, sessionId: sessionId, token: token) - case "search_query_history": - return try await handleSearchQueryHistory(arguments, sessionId: sessionId, token: token) - case "open_connection_window": - return try await handleOpenConnectionWindow(arguments, sessionId: sessionId, token: token) - case "open_table_tab": - return try await handleOpenTableTab(arguments, sessionId: sessionId, token: token) - case "focus_query_tab": - return try await handleFocusQueryTab(arguments, sessionId: sessionId, token: token) - default: - throw MCPError.methodNotFound(name) - } - } - - private func logToolOutcome( - name: String, - token: MCPAuthToken?, - arguments: JSONValue?, - outcome: AuditOutcome, - error: String? - ) { - let connectionId = arguments?["connection_id"]?.stringValue.flatMap(UUID.init(uuidString:)) - MCPAuditLogger.logToolCalled( - tokenId: token?.id, - tokenName: token?.name, - toolName: name, - connectionId: connectionId, - outcome: outcome, - errorMessage: error - ) - } - - private func authorize( - token: MCPAuthToken?, - tool: String, - connectionId: UUID?, - sql: String? = nil, - sessionId: String - ) async throws { - try await authPolicy.resolveAndAuthorize( - token: token ?? Self.anonymousFullAccessToken, - tool: tool, - connectionId: connectionId, - sql: sql, - sessionId: sessionId - ) - } - - static let anonymousFullAccessToken = MCPAuthToken( - id: UUID(), - name: "__anonymous__", - prefix: "tp_anon", - tokenHash: "", - salt: "", - permissions: .fullAccess, - connectionAccess: .all, - createdAt: Date.now, - lastUsedAt: nil, - expiresAt: nil, - isActive: true - ) - - private func handleListConnections(token: MCPAuthToken?) async throws -> MCPToolResult { - let result = await bridge.listConnections() - let filtered = filterConnectionsByToken(result, token: token) - return MCPToolResult(content: [.text(encodeJSON(filtered))], isError: nil) - } - - private func filterConnectionsByToken(_ value: JSONValue, token: MCPAuthToken?) -> JSONValue { - guard let access = token?.connectionAccess, case .limited(let allowed) = access else { - return value - } - guard case .object(var dict) = value, - let entries = dict["connections"]?.arrayValue - else { - return value - } - let filtered = entries.filter { entry in - guard let idString = entry["id"]?.stringValue, - let id = UUID(uuidString: idString) - else { - return false - } - return allowed.contains(id) - } - dict["connections"] = .array(filtered) - return .object(dict) - } - - private func handleConnect(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - try await authorize(token: token, tool: "connect", connectionId: connectionId, sessionId: sessionId) - let result = try await bridge.connect(connectionId: connectionId) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleDisconnect(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - try await authorize(token: token, tool: "disconnect", connectionId: connectionId, sessionId: sessionId) - try await bridge.disconnect(connectionId: connectionId) - let result: JSONValue = .object(["status": "disconnected"]) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleGetConnectionStatus(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - try await authorize(token: token, tool: "get_connection_status", connectionId: connectionId, sessionId: sessionId) - let result = try await bridge.getConnectionStatus(connectionId: connectionId) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleExecuteQuery(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let query = try requireString(args, key: "query") - let mcpSettings = await MainActor.run { AppSettingsManager.shared.mcp } - let maxRows = optionalInt(args, key: "max_rows", default: mcpSettings.defaultRowLimit, clamp: 1...mcpSettings.maxRowLimit) - let timeoutSeconds = optionalInt(args, key: "timeout_seconds", default: mcpSettings.queryTimeoutSeconds, clamp: 1...300) - let database = optionalString(args, key: "database") - let schema = optionalString(args, key: "schema") - - guard (query as NSString).length <= 102_400 else { - throw MCPError.invalidParams("Query exceeds 100KB limit") - } - - guard !QueryClassifier.isMultiStatement(query) else { - throw MCPError.invalidParams("Multi-statement queries are not supported. Send one statement at a time.") - } - - try await authorize( - token: token, - tool: "execute_query", - connectionId: connectionId, - sql: query, - sessionId: sessionId - ) - - let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId) - - if let database { - _ = try await bridge.switchDatabase(connectionId: connectionId, database: database) - } - if let schema { - _ = try await bridge.switchSchema(connectionId: connectionId, schema: schema) - } - - let tier = QueryClassifier.classifyTier(query, databaseType: databaseType) - - switch tier { - case .destructive: - throw MCPError.forbidden( - "Destructive queries (DROP, TRUNCATE, ALTER...DROP) cannot be executed via execute_query. " - + "Use the confirm_destructive_operation tool instead." - ) - - case .write: - if let token, !token.permissions.satisfies(.readWrite) { - throw MCPError.forbidden( - "Token '\(token.name)' with '\(token.permissions.displayName)' permission cannot execute write queries" - ) - } - try await authPolicy.checkSafeModeDialog( - sql: query, - connectionId: connectionId, - databaseType: databaseType, - safeModeLevel: safeModeLevel - ) - - case .safe: - try await authPolicy.checkSafeModeDialog( - sql: query, - connectionId: connectionId, - databaseType: databaseType, - safeModeLevel: safeModeLevel - ) - } - - let result = try await executeAndLog( - query: query, - connectionId: connectionId, - databaseName: databaseName, - maxRows: maxRows, - timeoutSeconds: timeoutSeconds, - token: token - ) - - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleConfirmDestructiveOperation( - _ args: JSONValue?, - sessionId: String, - token: MCPAuthToken? - ) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let query = try requireString(args, key: "query") - let confirmationPhrase = try requireString(args, key: "confirmation_phrase") - - guard confirmationPhrase == "I understand this is irreversible" else { - throw MCPError.invalidParams( - "confirmation_phrase must be exactly: I understand this is irreversible" - ) - } - - guard !QueryClassifier.isMultiStatement(query) else { - throw MCPError.invalidParams( - "Multi-statement queries are not supported. Send one statement at a time." - ) - } - - try await authorize( - token: token, - tool: "confirm_destructive_operation", - connectionId: connectionId, - sql: query, - sessionId: sessionId - ) - - let (databaseType, safeModeLevel, databaseName) = try await resolveConnectionMeta(connectionId) - - let tier = QueryClassifier.classifyTier(query, databaseType: databaseType) - guard tier == .destructive else { - throw MCPError.invalidParams( - "This tool only accepts destructive queries (DROP, TRUNCATE, ALTER...DROP). " - + "Use execute_query for other queries." - ) - } - - try await authPolicy.checkSafeModeDialog( - sql: query, - connectionId: connectionId, - databaseType: databaseType, - safeModeLevel: safeModeLevel - ) - - let mcpSettings = await MainActor.run { AppSettingsManager.shared.mcp } - let timeoutSeconds = mcpSettings.queryTimeoutSeconds - - let result = try await executeAndLog( - query: query, - connectionId: connectionId, - databaseName: databaseName, - maxRows: 0, - timeoutSeconds: timeoutSeconds, - token: token - ) - - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleListTables(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let includeRowCounts = optionalBool(args, key: "include_row_counts", default: false) - let database = optionalString(args, key: "database") - let schema = optionalString(args, key: "schema") - - try await authorize(token: token, tool: "list_tables", connectionId: connectionId, sessionId: sessionId) - - if let database { - _ = try await bridge.switchDatabase(connectionId: connectionId, database: database) - } - if let schema { - _ = try await bridge.switchSchema(connectionId: connectionId, schema: schema) - } - - let result = try await bridge.listTables(connectionId: connectionId, includeRowCounts: includeRowCounts) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleDescribeTable(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let table = try requireString(args, key: "table") - let schema = optionalString(args, key: "schema") - - try await authorize(token: token, tool: "describe_table", connectionId: connectionId, sessionId: sessionId) - - let result = try await bridge.describeTable(connectionId: connectionId, table: table, schema: schema) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleListDatabases(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - try await authorize(token: token, tool: "list_databases", connectionId: connectionId, sessionId: sessionId) - let result = try await bridge.listDatabases(connectionId: connectionId) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleListSchemas(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let database = optionalString(args, key: "database") - - try await authorize(token: token, tool: "list_schemas", connectionId: connectionId, sessionId: sessionId) - - if let database { - _ = try await bridge.switchDatabase(connectionId: connectionId, database: database) - } - - let result = try await bridge.listSchemas(connectionId: connectionId) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleGetTableDDL(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let table = try requireString(args, key: "table") - let schema = optionalString(args, key: "schema") - - try await authorize(token: token, tool: "get_table_ddl", connectionId: connectionId, sessionId: sessionId) - - let result = try await bridge.getTableDDL(connectionId: connectionId, table: table, schema: schema) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleExportData(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let format = try requireString(args, key: "format") - let query = optionalString(args, key: "query") - let tables = optionalStringArray(args, key: "tables") - let outputPath = optionalString(args, key: "output_path") - let maxRows = optionalInt(args, key: "max_rows", default: 50_000, clamp: 1...100_000) - - guard ["csv", "json", "sql"].contains(format) else { - throw MCPError.invalidParams("Unsupported format: \(format). Must be csv, json, or sql") - } - - guard query != nil || tables != nil else { - throw MCPError.invalidParams("Either 'query' or 'tables' must be provided") - } - - if let tables { - for table in tables { - try Self.validateExportTableName(table) - } - } - - if let outputPath { - _ = try Self.sandboxedDownloadsURL(for: outputPath) - } - - try await authorize( - token: token, - tool: "export_data", - connectionId: connectionId, - sql: query, - sessionId: sessionId - ) - - let (databaseType, safeModeLevel, _) = try await resolveConnectionMeta(connectionId) - var queries: [(label: String, sql: String)] = [] - - if let query { - try await authPolicy.checkSafeModeDialog( - sql: query, - connectionId: connectionId, - databaseType: databaseType, - safeModeLevel: safeModeLevel - ) - queries.append((label: "query", sql: query)) - } else if let tables { - let quoteIdentifier = Self.identifierQuoter(for: databaseType) - for table in tables { - let quoted = try Self.quoteQualifiedIdentifier(table, quoter: quoteIdentifier) - let sql = "SELECT * FROM \(quoted) LIMIT \(maxRows)" - try await authPolicy.checkSafeModeDialog( - sql: sql, - connectionId: connectionId, - databaseType: databaseType, - safeModeLevel: safeModeLevel - ) - queries.append((label: table, sql: sql)) - } - } - - var exportResults: [JSONValue] = [] - var totalRowsExported = 0 - - for (label, sql) in queries { - let result = try await bridge.executeQuery( - connectionId: connectionId, - query: sql, - maxRows: maxRows, - timeoutSeconds: 60 - ) - - guard let columns = result["columns"]?.arrayValue, - let rows = result["rows"]?.arrayValue - else { - throw MCPError.internalError("Unexpected query result structure") - } - - let columnNames = columns.compactMap(\.stringValue) - let formatted: String - - switch format { - case "csv": - formatted = formatCSV(columns: columnNames, rows: rows) - case "json": - formatted = formatJSON(columns: columnNames, rows: rows) - case "sql": - formatted = formatSQL(table: label, columns: columnNames, rows: rows) - default: - formatted = formatCSV(columns: columnNames, rows: rows) - } - - totalRowsExported += rows.count - - exportResults.append(.object([ - "label": .string(label), - "format": .string(format), - "row_count": result["row_count"] ?? .int(0), - "data": .string(formatted) - ])) - } - - if let outputPath { - let fileURL = try Self.sandboxedDownloadsURL(for: outputPath) - - let fullContent: String - if exportResults.count == 1, - let data = exportResults.first?["data"]?.stringValue - { - fullContent = data - } else { - fullContent = exportResults.compactMap { $0["data"]?.stringValue }.joined(separator: "\n\n") - } - - try fullContent.write(to: fileURL, atomically: true, encoding: .utf8) - - let response: JSONValue = .object([ - "path": .string(fileURL.path), - "rows_exported": .int(totalRowsExported) - ]) - return MCPToolResult(content: [.text(encodeJSON(response))], isError: nil) - } - - let response: JSONValue - if exportResults.count == 1, let single = exportResults.first { - response = single - } else { - response = .object(["exports": .array(exportResults)]) - } - - return MCPToolResult(content: [.text(encodeJSON(response))], isError: nil) - } - - private func handleSwitchDatabase(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let database = try requireString(args, key: "database") - - try await authorize(token: token, tool: "switch_database", connectionId: connectionId, sessionId: sessionId) - - let result = try await bridge.switchDatabase(connectionId: connectionId, database: database) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func handleSwitchSchema(_ args: JSONValue?, sessionId: String, token: MCPAuthToken?) async throws -> MCPToolResult { - let connectionId = try requireUUID(args, key: "connection_id") - let schema = try requireString(args, key: "schema") - - try await authorize(token: token, tool: "switch_schema", connectionId: connectionId, sessionId: sessionId) - - let result = try await bridge.switchSchema(connectionId: connectionId, schema: schema) - return MCPToolResult(content: [.text(encodeJSON(result))], isError: nil) - } - - private func executeAndLog( - query: String, - connectionId: UUID, - databaseName: String, - maxRows: Int, - timeoutSeconds: Int, - token: MCPAuthToken? = nil - ) async throws -> JSONValue { - let startTime = Date() - do { - let result = try await bridge.executeQuery( - connectionId: connectionId, - query: query, - maxRows: maxRows, - timeoutSeconds: timeoutSeconds - ) - let elapsed = Date().timeIntervalSince(startTime) - let rowCount = result["row_count"]?.intValue ?? 0 - await authPolicy.logQuery( - sql: query, - connectionId: connectionId, - databaseName: databaseName, - executionTime: elapsed, - rowCount: rowCount, - wasSuccessful: true, - errorMessage: nil - ) - MCPAuditLogger.logQueryExecuted( - tokenId: token?.id, - tokenName: token?.name, - connectionId: connectionId, - sql: query, - durationMs: Int(elapsed * 1_000), - rowCount: rowCount, - outcome: .success - ) - return result - } catch { - let elapsed = Date().timeIntervalSince(startTime) - await authPolicy.logQuery( - sql: query, - connectionId: connectionId, - databaseName: databaseName, - executionTime: elapsed, - rowCount: 0, - wasSuccessful: false, - errorMessage: error.localizedDescription - ) - MCPAuditLogger.logQueryExecuted( - tokenId: token?.id, - tokenName: token?.name, - connectionId: connectionId, - sql: query, - durationMs: Int(elapsed * 1_000), - rowCount: 0, - outcome: .error, - errorMessage: error.localizedDescription - ) - throw error - } - } - - func requireUUID(_ args: JSONValue?, key: String) throws -> UUID { - guard let value = args?[key]?.stringValue else { - throw MCPError.invalidParams("Missing required parameter: \(key)") - } - guard let uuid = UUID(uuidString: value) else { - throw MCPError.invalidParams("Invalid UUID for parameter: \(key)") - } - return uuid - } - - func requireString(_ args: JSONValue?, key: String) throws -> String { - guard let value = args?[key]?.stringValue else { - throw MCPError.invalidParams("Missing required parameter: \(key)") - } - return value - } - - func optionalString(_ args: JSONValue?, key: String) -> String? { - args?[key]?.stringValue - } - - func optionalInt(_ args: JSONValue?, key: String, default defaultValue: Int, clamp range: ClosedRange) -> Int { - guard let value = args?[key]?.intValue else { return defaultValue } - return min(max(value, range.lowerBound), range.upperBound) - } - - private func optionalBool(_ args: JSONValue?, key: String, default defaultValue: Bool) -> Bool { - args?[key]?.boolValue ?? defaultValue - } - - private func optionalStringArray(_ args: JSONValue?, key: String) -> [String]? { - guard let array = args?[key]?.arrayValue else { return nil } - let strings = array.compactMap(\.stringValue) - return strings.isEmpty ? nil : strings - } - - private func resolveConnectionMeta(_ connectionId: UUID) async throws -> (DatabaseType, SafeModeLevel, String) { - try await MainActor.run { - switch DatabaseManager.shared.connectionState(connectionId) { - case .live(_, let session): - return (session.connection.type, session.connection.safeModeLevel, session.activeDatabase) - case .stored(let conn): - return (conn.type, conn.safeModeLevel, conn.database) - case .unknown: - throw MCPError.notConnected(connectionId) - } - } - } - - static func validateExportTableName(_ table: String) throws { - let pattern = "^[A-Za-z0-9_]+(\\.[A-Za-z0-9_]+)*$" - guard table.range(of: pattern, options: .regularExpression) != nil else { - throw MCPError.invalidParams( - "Invalid table name: '\(table)'. Allowed characters: letters, digits, underscore, and '.' for schema-qualified names." - ) - } - } - - static func identifierQuoter(for databaseType: DatabaseType) -> (String) -> String { - if let dialect = try? resolveSQLDialect(for: databaseType) { - return quoteIdentifierFromDialect(dialect) - } - return { "\"\($0.replacingOccurrences(of: "\"", with: "\"\""))\"" } - } - - static func quoteQualifiedIdentifier(_ identifier: String, quoter: (String) -> String) throws -> String { - let segments = identifier.split(separator: ".", omittingEmptySubsequences: true) - guard !segments.isEmpty, segments.count == identifier.split(separator: ".", omittingEmptySubsequences: false).count else { - throw MCPError.invalidParams( - "Invalid qualified identifier: '\(identifier)'. Empty components are not allowed." - ) - } - return segments.map { quoter(String($0)) }.joined(separator: ".") - } - - static func sandboxedDownloadsURL(for path: String) throws -> URL { - guard let downloads = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first else { - throw MCPError.invalidParams("Downloads directory is not available") - } - let downloadsRoot = downloads.standardizedFileURL.resolvingSymlinksInPath().path - let candidate = path.hasPrefix("/") ? URL(fileURLWithPath: path) : downloads.appendingPathComponent(path) - let resolvedPath = candidate.standardizedFileURL.resolvingSymlinksInPath().path - let prefix = downloadsRoot.hasSuffix("/") ? downloadsRoot : downloadsRoot + "/" - guard resolvedPath == downloadsRoot || resolvedPath.hasPrefix(prefix) else { - throw MCPError.invalidParams( - "output_path must be inside the Downloads directory (\(downloadsRoot))" - ) - } - return URL(fileURLWithPath: resolvedPath) - } - - func encodeJSON(_ value: JSONValue) -> String { - let encoder = JSONEncoder() - encoder.outputFormatting = [.sortedKeys] - guard let data = try? encoder.encode(value), - let string = String(data: data, encoding: .utf8) - else { - Self.logger.warning("Failed to encode JSON value") - return "{}" - } - return string - } - - private func formatCSV(columns: [String], rows: [JSONValue]) -> String { - var lines: [String] = [] - lines.append(columns.map { escapeCSVField($0) }.joined(separator: ",")) - - for row in rows { - guard let cells = row.arrayValue else { continue } - let line = cells.map { cell -> String in - switch cell { - case .string(let value): - return escapeCSVField(value) - case .null: - return "" - case .int(let value): - return String(value) - case .double(let value): - return String(value) - case .bool(let value): - return value ? "true" : "false" - default: - return escapeCSVField(encodeJSON(cell)) - } - } - lines.append(line.joined(separator: ",")) - } - - return lines.joined(separator: "\n") - } - - private func escapeCSVField(_ field: String) -> String { - if field.contains(",") || field.contains("\"") || field.contains("\n") { - return "\"" + field.replacingOccurrences(of: "\"", with: "\"\"") + "\"" - } - return field - } - - private func formatJSON(columns: [String], rows: [JSONValue]) -> String { - var objects: [JSONValue] = [] - - for row in rows { - guard let cells = row.arrayValue else { continue } - var dict: [String: JSONValue] = [:] - for (index, column) in columns.enumerated() where index < cells.count { - dict[column] = cells[index] - } - objects.append(.object(dict)) - } - - return encodeJSON(.array(objects)) - } - - private func formatSQL(table: String, columns: [String], rows: [JSONValue]) -> String { - guard !columns.isEmpty else { return "" } - - var statements: [String] = [] - let escapedTable = "`\(table.replacingOccurrences(of: "`", with: "``"))`" - let escapedColumns = columns.map { "`\($0.replacingOccurrences(of: "`", with: "``"))`" } - let columnList = escapedColumns.joined(separator: ", ") - - for row in rows { - guard let cells = row.arrayValue else { continue } - let values = cells.map { cell -> String in - switch cell { - case .null: - return "NULL" - case .string(let value): - let escaped = value - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "'", with: "\\'") - return "'\(escaped)'" - case .int(let value): - return String(value) - case .double(let value): - return String(value) - case .bool(let value): - return value ? "1" : "0" - default: - let escaped = encodeJSON(cell) - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "'", with: "\\'") - return "'\(escaped)'" - } - } - statements.append("INSERT INTO \(escapedTable) (\(columnList)) VALUES (\(values.joined(separator: ", ")));") - } - - return statements.joined(separator: "\n") - } -} diff --git a/TablePro/Core/MCP/Protocol/Handlers/CompletionCompleteHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/CompletionCompleteHandler.swift new file mode 100644 index 000000000..767259ce8 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/CompletionCompleteHandler.swift @@ -0,0 +1,24 @@ +import Foundation +import os + +public struct CompletionCompleteHandler: MCPMethodHandler { + public static let method = "completion/complete" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Completion") + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + Self.logger.debug("completion/complete returning empty result") + let result: JsonValue = .object([ + "completion": .object([ + "values": .array([]), + "total": .int(0), + "hasMore": .bool(false) + ]) + ]) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/InitializeHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/InitializeHandler.swift new file mode 100644 index 000000000..add6fc5dc --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/InitializeHandler.swift @@ -0,0 +1,81 @@ +import Foundation +import os + +public struct InitializeHandler: MCPMethodHandler { + public static let method = "initialize" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.uninitialized] + + public static let supportedProtocolVersion = "2025-11-25" + public static let supportedProtocolVersions: Set = [ + "2025-03-26", + "2025-06-18", + "2025-11-25" + ] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Handler.Initialize") + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + let sessionState = await context.session.state + if case .ready = sessionState { + throw MCPProtocolError.invalidRequest(detail: "Session already initialized") + } + if await context.session.clientInfo != nil { + throw MCPProtocolError.invalidRequest(detail: "initialize already received for this session") + } + + let requestedVersion = params?["protocolVersion"]?.stringValue + let protocolVersion = Self.negotiate(requestedVersion: requestedVersion) + + let clientCapabilities = params?["capabilities"] + let clientName = params?["clientInfo"]?["name"]?.stringValue ?? "unknown" + let clientVersion = params?["clientInfo"]?["version"]?.stringValue + + let info = MCPClientInfo(name: clientName, version: clientVersion) + await context.session.recordInitialize( + clientInfo: info, + protocolVersion: protocolVersion, + capabilities: clientCapabilities + ) + + let result: JsonValue = .object([ + "protocolVersion": .string(protocolVersion), + "capabilities": .object([ + "tools": .object(["listChanged": .bool(false)]), + "resources": .object([ + "listChanged": .bool(false), + "subscribe": .bool(false) + ]), + "prompts": .object(["listChanged": .bool(false)]), + "logging": .object([:]), + "completions": .object([:]) + ]), + "serverInfo": .object([ + "name": .string("tablepro"), + "title": .string("TablePro"), + "version": .string(Self.serverVersion) + ]) + ]) + + Self.logger.info( + "Initialize: client=\(clientName, privacy: .public) version=\(clientVersion ?? "-", privacy: .public) protocol=\(protocolVersion, privacy: .public) requested=\(requestedVersion ?? "-", privacy: .public)" + ) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } + + public static func negotiate(requestedVersion: String?) -> String { + guard let requestedVersion, !requestedVersion.isEmpty else { + return supportedProtocolVersion + } + if supportedProtocolVersions.contains(requestedVersion) { + return requestedVersion + } + return supportedProtocolVersion + } + + private static let serverVersion: String = { + Bundle.main.object(forInfoDictionaryKey: "CFBundleShortVersionString") as? String ?? "0.0.0" + }() +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/LoggingSetLevelHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/LoggingSetLevelHandler.swift new file mode 100644 index 000000000..7526504b0 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/LoggingSetLevelHandler.swift @@ -0,0 +1,30 @@ +import Foundation +import os + +public struct LoggingSetLevelHandler: MCPMethodHandler { + public static let method = "logging/setLevel" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Logging") + + public static let supportedLevels: Set = [ + "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency" + ] + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + guard case .string(let level)? = params?["level"] else { + throw MCPProtocolError.invalidParams(detail: "Missing required parameter: level") + } + + let normalized = level.lowercased() + guard Self.supportedLevels.contains(normalized) else { + throw MCPProtocolError.invalidParams(detail: "Unknown log level: \(level)") + } + + Self.logger.notice("Client requested log level: \(normalized, privacy: .public)") + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: .object([:])) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/PingHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/PingHandler.swift new file mode 100644 index 000000000..b95057b33 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/PingHandler.swift @@ -0,0 +1,17 @@ +import Foundation + +public struct PingHandler: MCPMethodHandler { + public static let method = "ping" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.uninitialized, .ready] + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + await context.session.touch(now: await context.clock.now()) + return MCPMethodHandlerHelpers.successResponse( + id: context.requestId, + result: .object([:]) + ) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/PromptsGetHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/PromptsGetHandler.swift new file mode 100644 index 000000000..c3131efd1 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/PromptsGetHandler.swift @@ -0,0 +1,17 @@ +import Foundation +import os + +public struct PromptsGetHandler: MCPMethodHandler { + public static let method = "prompts/get" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Prompts") + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + Self.logger.debug("prompts/get rejected: server has no prompts") + throw MCPProtocolError.methodNotFound(method: "prompts/get") + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/PromptsListHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/PromptsListHandler.swift new file mode 100644 index 000000000..e92043b57 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/PromptsListHandler.swift @@ -0,0 +1,18 @@ +import Foundation +import os + +public struct PromptsListHandler: MCPMethodHandler { + public static let method = "prompts/list" + public static let requiredScopes: Set = [] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Prompts") + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + Self.logger.debug("prompts/list returning empty list") + let result: JsonValue = .object(["prompts": .array([])]) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/ResourcesListHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/ResourcesListHandler.swift new file mode 100644 index 000000000..b2b42012f --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/ResourcesListHandler.swift @@ -0,0 +1,75 @@ +import Foundation +import os + +public struct ResourcesListHandler: MCPMethodHandler { + public static let method = "resources/list" + public static let requiredScopes: Set = [.resourcesRead] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Resources") + + private let services: MCPToolServices + + public init(services: MCPToolServices) { + self.services = services + } + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + var resources: [JsonValue] = [] + resources.append(Self.staticConnectionsResource()) + + let connectedItems = await Self.connectedConnectionItems(services: services) + for item in connectedItems { + resources.append(Self.schemaResource(for: item)) + resources.append(Self.historyResource(for: item)) + } + + let result: JsonValue = .object(["resources": .array(resources)]) + Self.logger.debug("resources/list returned \(resources.count, privacy: .public) entries") + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } + + private static func staticConnectionsResource() -> JsonValue { + .object([ + "uri": .string("tablepro://connections"), + "name": .string(String(localized: "Saved Connections")), + "description": .string(String(localized: "List of all saved database connections with metadata")), + "mimeType": .string("application/json") + ]) + } + + private struct ConnectedConnectionItem: Sendable { + let id: String + let name: String + } + + private static func connectedConnectionItems(services: MCPToolServices) async -> [ConnectedConnectionItem] { + let value = await services.connectionBridge.listConnections() + guard let connections = value["connections"]?.arrayValue else { return [] } + + return connections.compactMap { entry -> ConnectedConnectionItem? in + guard let id = entry["id"]?.stringValue else { return nil } + guard entry["is_connected"]?.boolValue == true else { return nil } + let name = entry["name"]?.stringValue ?? id + return ConnectedConnectionItem(id: id, name: name) + } + } + + private static func schemaResource(for item: ConnectedConnectionItem) -> JsonValue { + .object([ + "uri": .string("tablepro://connections/\(item.id)/schema"), + "name": .string(String(format: String(localized: "Schema for %@"), item.name)), + "description": .string(String(localized: "Tables, columns, indexes, and foreign keys for the connected database")), + "mimeType": .string("application/json") + ]) + } + + private static func historyResource(for item: ConnectedConnectionItem) -> JsonValue { + .object([ + "uri": .string("tablepro://connections/\(item.id)/history"), + "name": .string(String(format: String(localized: "Query history for %@"), item.name)), + "description": .string(String(localized: "Recent query history for this connection")), + "mimeType": .string("application/json") + ]) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/ResourcesReadHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/ResourcesReadHandler.swift new file mode 100644 index 000000000..d557476fc --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/ResourcesReadHandler.swift @@ -0,0 +1,197 @@ +import Foundation +import os + +public struct ResourcesReadHandler: MCPMethodHandler { + public static let method = "resources/read" + public static let requiredScopes: Set = [.resourcesRead] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Resources") + + private let services: MCPToolServices + + public init(services: MCPToolServices) { + self.services = services + } + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + guard case .string(let uri)? = params?["uri"] else { + throw MCPProtocolError.invalidParams(detail: "Missing required parameter: uri") + } + + do { + let route = try Self.parseRoute(uri: uri) + let payload = try await Self.fetchPayload(for: route, services: services) + let text = Self.encodeJsonString(payload) + + let result: JsonValue = .object([ + "contents": .array([ + .object([ + "uri": .string(uri), + "mimeType": .string("application/json"), + "text": .string(text) + ]) + ]) + ]) + + Self.logger.debug("resources/read uri=\(uri, privacy: .public)") + MCPAuditLogger.logResourceRead( + tokenId: nil, + tokenName: context.principal.metadata.label, + uri: uri, + outcome: .success + ) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } catch { + MCPAuditLogger.logResourceRead( + tokenId: nil, + tokenName: context.principal.metadata.label, + uri: uri, + outcome: .error, + errorMessage: (error as? MCPProtocolError)?.message ?? error.localizedDescription + ) + throw error + } + } + + private enum ResourceRoute { + case connectionsList + case connectionSchema(connectionId: UUID) + case connectionHistory(connectionId: UUID, limit: Int, search: String?, dateFilter: String?) + } + + private static func parseRoute(uri: String) throws -> ResourceRoute { + guard let components = URLComponents(string: uri) else { + throw MCPProtocolError.invalidParams(detail: "Malformed URI: \(uri)") + } + guard components.scheme == "tablepro" else { + throw MCPProtocolError.invalidParams(detail: "Unsupported URI scheme: \(components.scheme ?? "nil")") + } + + let segments = pathSegments(from: uri) + + if segments == ["connections"] { + return .connectionsList + } + + guard segments.count == 3, segments[0] == "connections" else { + throw MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Unknown resource URI: \(uri)", + httpStatus: .notFound + ) + } + + guard let connectionId = UUID(uuidString: segments[1]) else { + throw MCPProtocolError.invalidParams(detail: "Invalid connection UUID in URI") + } + + switch segments[2] { + case "schema": + return .connectionSchema(connectionId: connectionId) + case "history": + let queryItems = components.queryItems ?? [] + let rawLimit = queryItems.first(where: { $0.name == "limit" })?.value.flatMap { Int($0) } ?? 50 + let limit = min(max(rawLimit, 1), 500) + let search = queryItems.first(where: { $0.name == "search" })?.value + let dateFilter = queryItems.first(where: { $0.name == "date_filter" })?.value + return .connectionHistory( + connectionId: connectionId, + limit: limit, + search: search, + dateFilter: dateFilter + ) + default: + throw MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Unknown resource URI: \(uri)", + httpStatus: .notFound + ) + } + } + + private static func fetchPayload(for route: ResourceRoute, services: MCPToolServices) async throws -> JsonValue { + switch route { + case .connectionsList: + return await services.connectionBridge.listConnections() + + case .connectionSchema(let connectionId): + do { + return try await services.connectionBridge.fetchSchemaResource(connectionId: connectionId) + } catch let error as MCPDataLayerError { + throw mapDomainError(error) + } + + case .connectionHistory(let connectionId, let limit, let search, let dateFilter): + do { + return try await services.connectionBridge.fetchHistoryResource( + connectionId: connectionId, + limit: limit, + search: search, + dateFilter: dateFilter + ) + } catch let error as MCPDataLayerError { + throw mapDomainError(error) + } + } + } + + private static func mapDomainError(_ error: MCPDataLayerError) -> MCPProtocolError { + switch error { + case .invalidArgument(let detail): + return MCPProtocolError.invalidParams(detail: detail) + case .notConnected(let id): + return MCPProtocolError.invalidParams(detail: "Connection not active: \(id.uuidString)") + case .forbidden(let reason, _): + return MCPProtocolError.forbidden(reason: reason) + case .notFound(let detail): + return MCPProtocolError( + code: JsonRpcErrorCode.resourceNotFound, + message: detail, + httpStatus: .notFound + ) + case .expired(let detail): + return MCPProtocolError( + code: JsonRpcErrorCode.expired, + message: detail, + httpStatus: .ok + ) + case .timeout(let detail, _): + return MCPProtocolError( + code: JsonRpcErrorCode.requestTimeout, + message: "Timeout: \(detail)", + httpStatus: .ok + ) + case .userCancelled: + return MCPProtocolError( + code: JsonRpcErrorCode.requestCancelled, + message: "User cancelled", + httpStatus: .ok + ) + case .dataSourceError(let detail): + return MCPProtocolError.internalError(detail: detail) + } + } + + private static func pathSegments(from uri: String) -> [String] { + guard let range = uri.range(of: "://") else { return [] } + let afterScheme = String(uri[range.upperBound...]) + let pathOnly: String + if let queryStart = afterScheme.firstIndex(of: "?") { + pathOnly = String(afterScheme[.. String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + guard let data = try? encoder.encode(value), + let string = String(data: data, encoding: .utf8) else { + return "{}" + } + return string + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/ResourcesTemplatesListHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/ResourcesTemplatesListHandler.swift new file mode 100644 index 000000000..9e32509b1 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/ResourcesTemplatesListHandler.swift @@ -0,0 +1,33 @@ +import Foundation +import os + +public struct ResourcesTemplatesListHandler: MCPMethodHandler { + public static let method = "resources/templates/list" + public static let requiredScopes: Set = [.resourcesRead] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Resources") + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + let templates: [JsonValue] = [ + .object([ + "uriTemplate": .string("tablepro://connections/{id}/schema"), + "name": .string(String(localized: "Database Schema")), + "description": .string(String(localized: "Tables, columns, indexes, and foreign keys for a connected database")), + "mimeType": .string("application/json") + ]), + .object([ + "uriTemplate": .string("tablepro://connections/{id}/history"), + "name": .string(String(localized: "Query History")), + "description": .string(String(localized: "Recent query history for a connection (supports ?limit=, ?search=, ?date_filter=)")), + "mimeType": .string("application/json") + ]) + ] + + let result: JsonValue = .object(["resourceTemplates": .array(templates)]) + Self.logger.debug("resources/templates/list returned \(templates.count, privacy: .public) templates") + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/ToolsCallHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/ToolsCallHandler.swift new file mode 100644 index 000000000..6674a4714 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/ToolsCallHandler.swift @@ -0,0 +1,73 @@ +import Foundation +import os + +public struct ToolsCallHandler: MCPMethodHandler { + public static let method = "tools/call" + public static let requiredScopes: Set = [.toolsRead] + public static let allowedSessionStates: Set = [.ready] + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + private let services: MCPToolServices + + public init(services: MCPToolServices) { + self.services = services + } + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + guard case .object(let object)? = params else { + throw MCPProtocolError.invalidParams(detail: "params must be object") + } + guard case .string(let toolName)? = object["name"] else { + throw MCPProtocolError.invalidParams(detail: "missing tool name") + } + let arguments = object["arguments"] ?? .object([:]) + + guard let tool = MCPToolRegistry.tool(named: toolName) else { + throw MCPProtocolError.methodNotFound(method: "tools/call:\(toolName)") + } + + let toolType = type(of: tool) + if !toolType.requiredScopes.isSubset(of: context.principal.scopes) { + MCPAuditLogger.logToolCalled( + tokenId: nil, + tokenName: context.principal.metadata.label, + toolName: toolName, + connectionId: Self.connectionId(in: arguments), + outcome: .denied, + errorMessage: "missing_scope" + ) + throw MCPProtocolError.forbidden(reason: "Tool '\(toolName)' requires additional scopes") + } + + Self.logger.info("tools/call name=\(toolName, privacy: .public)") + + do { + let result = try await tool.call(arguments: arguments, context: context, services: services) + MCPAuditLogger.logToolCalled( + tokenId: nil, + tokenName: context.principal.metadata.label, + toolName: toolName, + connectionId: Self.connectionId(in: arguments), + outcome: result.isError ? .error : .success + ) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result.asJsonValue()) + } catch { + MCPAuditLogger.logToolCalled( + tokenId: nil, + tokenName: context.principal.metadata.label, + toolName: toolName, + connectionId: Self.connectionId(in: arguments), + outcome: .error, + errorMessage: (error as? MCPProtocolError)?.message ?? error.localizedDescription + ) + throw error + } + } + + private static func connectionId(in arguments: JsonValue) -> UUID? { + guard case .object(let object) = arguments, + case .string(let value)? = object["connection_id"] else { return nil } + return UUID(uuidString: value) + } +} diff --git a/TablePro/Core/MCP/Protocol/Handlers/ToolsListHandler.swift b/TablePro/Core/MCP/Protocol/Handlers/ToolsListHandler.swift new file mode 100644 index 000000000..6c45d2d8c --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Handlers/ToolsListHandler.swift @@ -0,0 +1,29 @@ +import Foundation + +public struct ToolsListHandler: MCPMethodHandler { + public static let method = "tools/list" + public static let requiredScopes: Set = [.toolsRead] + public static let allowedSessionStates: Set = [.ready] + + public init() {} + + public func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + let tools: [JsonValue] = MCPToolRegistry.allTools.map { tool in + let toolType = type(of: tool) + var fields: [String: JsonValue] = [ + "name": .string(toolType.name), + "description": .string(toolType.description), + "inputSchema": toolType.inputSchema + ] + if let title = toolType.title { + fields["title"] = .string(title) + } + if let annotationsValue = toolType.annotations.asJsonValue { + fields["annotations"] = annotationsValue + } + return .object(fields) + } + let result: JsonValue = .object(["tools": .array(tools)]) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPCancellationToken.swift b/TablePro/Core/MCP/Protocol/MCPCancellationToken.swift new file mode 100644 index 000000000..7fd6431d3 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPCancellationToken.swift @@ -0,0 +1,36 @@ +import Foundation + +public actor MCPCancellationToken { + private var cancelled: Bool = false + private var handlers: [@Sendable () async -> Void] = [] + + public init() {} + + public func cancel() async { + guard !cancelled else { return } + cancelled = true + let toRun = handlers + handlers.removeAll() + for handler in toRun { + await handler() + } + } + + public func isCancelled() async -> Bool { + cancelled + } + + public func onCancel(_ handler: @Sendable @escaping () async -> Void) async { + if cancelled { + await handler() + return + } + handlers.append(handler) + } + + public func throwIfCancelled() async throws { + if cancelled { + throw CancellationError() + } + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPInflightRegistry.swift b/TablePro/Core/MCP/Protocol/MCPInflightRegistry.swift new file mode 100644 index 000000000..5d7bd0426 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPInflightRegistry.swift @@ -0,0 +1,50 @@ +import Foundation + +actor MCPInflightRegistry { + private struct Key: Hashable { + let sessionId: MCPSessionId + let requestId: JsonRpcId + } + + private struct Entry { + let token: MCPCancellationToken + let tokenId: UUID? + } + + private var entries: [Key: Entry] = [:] + + func register( + requestId: JsonRpcId, + sessionId: MCPSessionId, + token: MCPCancellationToken, + tokenId: UUID? = nil + ) { + entries[Key(sessionId: sessionId, requestId: requestId)] = Entry( + token: token, + tokenId: tokenId + ) + } + + func cancel(requestId: JsonRpcId, sessionId: MCPSessionId) async { + let key = Key(sessionId: sessionId, requestId: requestId) + guard let entry = entries.removeValue(forKey: key) else { return } + await entry.token.cancel() + } + + func cancelAll(matchingTokenId tokenId: UUID) async -> [MCPSessionId] { + let matching = entries.filter { $0.value.tokenId == tokenId } + for (key, entry) in matching { + await entry.token.cancel() + entries.removeValue(forKey: key) + } + return Array(Set(matching.map { $0.key.sessionId })) + } + + func remove(requestId: JsonRpcId, sessionId: MCPSessionId) { + entries.removeValue(forKey: Key(sessionId: sessionId, requestId: requestId)) + } + + func count() -> Int { + entries.count + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPMethodHandler.swift b/TablePro/Core/MCP/Protocol/MCPMethodHandler.swift new file mode 100644 index 000000000..fb803e420 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPMethodHandler.swift @@ -0,0 +1,35 @@ +import Foundation + +public enum MCPSessionAllowedState: Sendable, Equatable, Hashable { + case uninitialized + case ready +} + +public protocol MCPMethodHandler: Sendable { + static var method: String { get } + static var requiredScopes: Set { get } + static var allowedSessionStates: Set { get } + func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage +} + +public extension MCPMethodHandler { + var method: String { Self.method } + var requiredScopes: Set { Self.requiredScopes } + var allowedSessionStates: Set { Self.allowedSessionStates } +} + +public enum MCPMethodHandlerHelpers { + public static func successResponse(id: JsonRpcId?, result: JsonValue) -> JsonRpcMessage { + guard let id else { + return .errorResponse(JsonRpcErrorResponse( + id: nil, + error: JsonRpcError.invalidRequest(message: "Missing request id") + )) + } + return .successResponse(JsonRpcSuccessResponse(id: id, result: result)) + } + + public static func errorResponse(id: JsonRpcId?, error: MCPProtocolError) -> JsonRpcMessage { + .errorResponse(error.toJsonRpcErrorResponse(id: id)) + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPProgressEmitter.swift b/TablePro/Core/MCP/Protocol/MCPProgressEmitter.swift new file mode 100644 index 000000000..b7c762009 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPProgressEmitter.swift @@ -0,0 +1,52 @@ +import Foundation + +public protocol MCPProgressSink: Sendable { + func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async +} + +public actor MCPProgressEmitter { + private let progressToken: JsonValue? + private let target: any MCPProgressSink + private let sessionId: MCPSessionId + + public init(progressToken: JsonValue?, target: any MCPProgressSink, sessionId: MCPSessionId) { + self.progressToken = progressToken + self.target = target + self.sessionId = sessionId + } + + public func emit(progress: Double, total: Double? = nil, message: String? = nil) async { + guard let progressToken else { return } + + var params: [String: JsonValue] = [ + "progressToken": progressToken, + "progress": .double(progress) + ] + if let total { + params["total"] = .double(total) + } + if let message { + params["message"] = .string(message) + } + + let notification = JsonRpcNotification( + method: "notifications/progress", + params: .object(params) + ) + await target.sendNotification(notification, toSession: sessionId) + } + + public func emitNotification(method: String, params: JsonValue?) async { + let notification = JsonRpcNotification(method: method, params: params) + await target.sendNotification(notification, toSession: sessionId) + } + + public var hasProgressToken: Bool { + progressToken != nil + } + + public static func extractProgressToken(from params: JsonValue?) -> JsonValue? { + guard let meta = params?["_meta"] else { return nil } + return meta["progressToken"] + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPProtocolDispatcher.swift b/TablePro/Core/MCP/Protocol/MCPProtocolDispatcher.swift new file mode 100644 index 000000000..3091f3aae --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPProtocolDispatcher.swift @@ -0,0 +1,255 @@ +import Foundation +import os + +public actor MCPProtocolDispatcher { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Dispatcher") + + private let handlers: [String: any MCPMethodHandler] + private let sessionStore: MCPSessionStore + private let progressSink: any MCPProgressSink + private let clock: any MCPClock + private let inflight: MCPInflightRegistry + + public init( + handlers: [any MCPMethodHandler], + sessionStore: MCPSessionStore, + progressSink: any MCPProgressSink, + clock: any MCPClock = MCPSystemClock() + ) { + var map: [String: any MCPMethodHandler] = [:] + for handler in handlers { + map[type(of: handler).method] = handler + } + self.handlers = map + self.sessionStore = sessionStore + self.progressSink = progressSink + self.clock = clock + self.inflight = MCPInflightRegistry() + } + + public func dispatch(_ exchange: MCPInboundExchange) async { + switch exchange.message { + case .request(let request): + await handleRequest(request, exchange: exchange) + case .notification(let notification): + await handleNotification(notification, exchange: exchange) + case .successResponse, .errorResponse: + Self.logger.debug("Ignoring inbound response message") + await exchange.responder.acknowledgeAccepted() + } + } + + public func cancel(requestId: JsonRpcId, sessionId: MCPSessionId) async { + await inflight.cancel(requestId: requestId, sessionId: sessionId) + } + + public func cancelInflight(matchingTokenId tokenId: UUID) async -> [MCPSessionId] { + await inflight.cancelAll(matchingTokenId: tokenId) + } + + private func handleRequest(_ request: JsonRpcRequest, exchange: MCPInboundExchange) async { + guard let handler = handlers[request.method] else { + await respondError( + exchange: exchange, + requestId: request.id, + error: .methodNotFound(method: request.method) + ) + return + } + + let session = await resolveOrCreateSession(method: request.method, exchange: exchange) + guard let session else { + await respondError( + exchange: exchange, + requestId: request.id, + error: .sessionNotFound() + ) + return + } + + let allowed = type(of: handler).allowedSessionStates + let stateCheck = await checkSessionState(session: session, allowed: allowed) + if let stateError = stateCheck { + await respondError(exchange: exchange, requestId: request.id, error: stateError) + return + } + + guard let principal = exchange.context.principal else { + await respondError( + exchange: exchange, + requestId: request.id, + error: .unauthenticated() + ) + return + } + + let required = type(of: handler).requiredScopes + if !required.isEmpty, !required.isSubset(of: principal.scopes) { + await respondError( + exchange: exchange, + requestId: request.id, + error: .forbidden(reason: "missing required scopes") + ) + return + } + + await session.touch(now: await clock.now()) + await session.bindPrincipal(tokenId: principal.tokenId) + + let token = MCPCancellationToken() + await inflight.register( + requestId: request.id, + sessionId: session.id, + token: token, + tokenId: principal.tokenId + ) + + let progressToken = MCPProgressEmitter.extractProgressToken(from: request.params) + let emitter = MCPProgressEmitter( + progressToken: progressToken, + target: progressSink, + sessionId: session.id + ) + + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: self, + progress: emitter, + cancellation: token, + clock: clock + ) + + let response = await invokeHandler(handler, params: request.params, context: context, requestId: request.id) + await inflight.remove(requestId: request.id, sessionId: session.id) + await exchange.responder.respond(response, sessionId: session.id) + } + + private func invokeHandler( + _ handler: any MCPMethodHandler, + params: JsonValue?, + context: MCPRequestContext, + requestId: JsonRpcId + ) async -> JsonRpcMessage { + do { + return try await handler.handle(params: params, context: context) + } catch let error as MCPProtocolError { + return MCPMethodHandlerHelpers.errorResponse(id: requestId, error: error) + } catch is CancellationError { + return MCPMethodHandlerHelpers.errorResponse( + id: requestId, + error: MCPProtocolError( + code: JsonRpcErrorCode.requestCancelled, + message: "Request cancelled", + httpStatus: .ok + ) + ) + } catch { + Self.logger.error("Handler threw error: \(error.localizedDescription, privacy: .public)") + return MCPMethodHandlerHelpers.errorResponse( + id: requestId, + error: .internalError(detail: error.localizedDescription) + ) + } + } + + private func handleNotification(_ notification: JsonRpcNotification, exchange: MCPInboundExchange) async { + if notification.method == "notifications/cancelled" { + await handleCancellationNotification(notification, exchange: exchange) + await exchange.responder.acknowledgeAccepted() + return + } + + if notification.method == "notifications/initialized" { + if let sessionId = exchange.context.sessionId, + let session = await sessionStore.session(id: sessionId) { + let state = await session.state + if case .initializing = state { + do { + try await session.transitionToReady() + } catch { + Self.logger.warning( + "Failed to transition session to ready: \(error.localizedDescription, privacy: .public)" + ) + } + } + } + await exchange.responder.acknowledgeAccepted() + return + } + + await exchange.responder.acknowledgeAccepted() + } + + private func handleCancellationNotification( + _ notification: JsonRpcNotification, + exchange: MCPInboundExchange + ) async { + guard let params = notification.params, + let sessionId = exchange.context.sessionId + else { return } + + let requestIdValue = params["requestId"] + let cancelId: JsonRpcId? + switch requestIdValue { + case .string(let value): + cancelId = .string(value) + case .int(let value): + cancelId = .number(Int64(value)) + case .double(let value): + cancelId = .number(Int64(value)) + default: + cancelId = nil + } + + guard let cancelId else { return } + await inflight.cancel(requestId: cancelId, sessionId: sessionId) + } + + private func resolveOrCreateSession(method: String, exchange: MCPInboundExchange) async -> MCPSession? { + if method == "initialize" { + if let sessionId = exchange.context.sessionId, + let existing = await sessionStore.session(id: sessionId) { + return existing + } + do { + return try await sessionStore.create() + } catch { + Self.logger.warning( + "Failed to create session: \(error.localizedDescription, privacy: .public)" + ) + return nil + } + } + + guard let sessionId = exchange.context.sessionId else { return nil } + return await sessionStore.session(id: sessionId) + } + + private func checkSessionState( + session: MCPSession, + allowed: Set + ) async -> MCPProtocolError? { + let state = await session.state + switch state { + case .initializing: + if allowed.contains(.uninitialized) { return nil } + return .invalidRequest(detail: "Session not initialized") + case .ready: + if allowed.contains(.ready) { return nil } + return .invalidRequest(detail: "Session already initialized") + case .terminated: + return .sessionNotFound(message: "Session terminated") + } + } + + private func respondError( + exchange: MCPInboundExchange, + requestId: JsonRpcId, + error: MCPProtocolError + ) async { + let response = MCPMethodHandlerHelpers.errorResponse(id: requestId, error: error) + await exchange.responder.respond(response, sessionId: exchange.context.sessionId) + } +} diff --git a/TablePro/Core/MCP/Protocol/MCPRequestContext.swift b/TablePro/Core/MCP/Protocol/MCPRequestContext.swift new file mode 100644 index 000000000..707418c4a --- /dev/null +++ b/TablePro/Core/MCP/Protocol/MCPRequestContext.swift @@ -0,0 +1,50 @@ +import Foundation + +public struct MCPRequestContext: Sendable { + public let exchange: MCPInboundExchange + public let session: MCPSession + public let principal: MCPPrincipal + public let dispatcher: MCPProtocolDispatcher + public let progress: MCPProgressEmitter + public let cancellation: MCPCancellationToken + public let clock: any MCPClock + + public init( + exchange: MCPInboundExchange, + session: MCPSession, + principal: MCPPrincipal, + dispatcher: MCPProtocolDispatcher, + progress: MCPProgressEmitter, + cancellation: MCPCancellationToken, + clock: any MCPClock + ) { + self.exchange = exchange + self.session = session + self.principal = principal + self.dispatcher = dispatcher + self.progress = progress + self.cancellation = cancellation + self.clock = clock + } + + public var requestId: JsonRpcId? { + if case .request(let request) = exchange.message { + return request.id + } + return nil + } + + public var sessionId: MCPSessionId { + session.id + } + + public var requestParams: JsonValue? { + if case .request(let request) = exchange.message { + return request.params + } + if case .notification(let notification) = exchange.message { + return notification.params + } + return nil + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift b/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift new file mode 100644 index 000000000..930826b52 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationTool.swift @@ -0,0 +1,99 @@ +import Foundation +import os + +public struct ConfirmDestructiveOperationTool: MCPToolImplementation { + public static let name = "confirm_destructive_operation" + public static let description = String( + localized: "Execute a destructive DDL query (DROP, TRUNCATE, ALTER...DROP) after explicit confirmation." + ) + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the active connection")) + ]), + "query": .object([ + "type": .string("string"), + "description": .string(String(localized: "The destructive query to execute")) + ]), + "confirmation_phrase": .object([ + "type": .string("string"), + "description": .string(String(localized: "Must be exactly: I understand this is irreversible")) + ]) + ]), + "required": .array([ + .string("connection_id"), + .string("query"), + .string("confirmation_phrase") + ]) + ]) + public static let requiredScopes: Set = [.toolsWrite] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Confirm Destructive Operation"), + readOnlyHint: false, + destructiveHint: true, + idempotentHint: false, + openWorldHint: true + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + private static let requiredPhrase = "I understand this is irreversible" + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let query = try MCPArgumentDecoder.requireString(arguments, key: "query") + let confirmationPhrase = try MCPArgumentDecoder.requireString(arguments, key: "confirmation_phrase") + + guard confirmationPhrase == Self.requiredPhrase else { + throw MCPProtocolError.invalidParams( + detail: "confirmation_phrase must be exactly: \(Self.requiredPhrase)" + ) + } + + guard !QueryClassifier.isMultiStatement(query) else { + throw MCPProtocolError.invalidParams( + detail: "Multi-statement queries are not supported. Send one statement at a time." + ) + } + + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType) + guard tier == .destructive else { + throw MCPProtocolError.invalidParams( + detail: "This tool only accepts destructive queries (DROP, TRUNCATE, ALTER...DROP). Use execute_query for other queries." + ) + } + + try await services.authPolicy.checkSafeModeDialog( + sql: query, + connectionId: connectionId, + databaseType: meta.databaseType, + safeModeLevel: meta.safeModeLevel + ) + + let mcpSettings = await MainActor.run { AppSettingsManager.shared.mcp } + let timeoutSeconds = mcpSettings.queryTimeoutSeconds + + Self.logger.debug("confirm_destructive_operation invoked for connection \(connectionId.uuidString, privacy: .public)") + + let result = try await ToolQueryExecutor.executeAndLog( + services: services, + query: query, + connectionId: connectionId, + databaseName: meta.databaseName, + maxRows: 0, + timeoutSeconds: timeoutSeconds, + principalLabel: context.principal.metadata.label + ) + + return .structured(result) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ConnectTool.swift b/TablePro/Core/MCP/Protocol/Tools/ConnectTool.swift new file mode 100644 index 000000000..5ba8ca898 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ConnectTool.swift @@ -0,0 +1,40 @@ +import Foundation +import os + +public struct ConnectTool: MCPToolImplementation { + public static let name = "connect" + public static let description = String(localized: "Connect to a saved database") + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the saved connection")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Connect"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: true + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + Self.logger.debug("connect tool invoked for connection \(connectionId.uuidString, privacy: .public)") + let payload = try await services.connectionBridge.connect(connectionId: connectionId) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/DescribeTableTool.swift b/TablePro/Core/MCP/Protocol/Tools/DescribeTableTool.swift new file mode 100644 index 000000000..c034c6a57 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/DescribeTableTool.swift @@ -0,0 +1,54 @@ +import Foundation + +public struct DescribeTableTool: MCPToolImplementation { + public static let name = "describe_table" + public static let description = String( + localized: "Get detailed table structure: columns, indexes, foreign keys, and DDL" + ) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Describe Table"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "table": .object([ + "type": .string("string"), + "description": .string(String(localized: "Table name")) + ]), + "schema": .object([ + "type": .string("string"), + "description": .string(String(localized: "Schema name (uses current if omitted)")) + ]) + ]), + "required": .array([.string("connection_id"), .string("table")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let table = try MCPArgumentDecoder.requireString(arguments, key: "table") + let schema = MCPArgumentDecoder.optionalString(arguments, key: "schema") + + let payload = try await services.connectionBridge.describeTable( + connectionId: connectionId, + table: table, + schema: schema + ) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/DisconnectTool.swift b/TablePro/Core/MCP/Protocol/Tools/DisconnectTool.swift new file mode 100644 index 000000000..a41f3726d --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/DisconnectTool.swift @@ -0,0 +1,41 @@ +import Foundation +import os + +public struct DisconnectTool: MCPToolImplementation { + public static let name = "disconnect" + public static let description = String(localized: "Disconnect from a database") + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection to disconnect")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + public static let requiredScopes: Set = [.toolsWrite] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Disconnect"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: true + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + Self.logger.debug("disconnect tool invoked for connection \(connectionId.uuidString, privacy: .public)") + try await services.connectionBridge.disconnect(connectionId: connectionId) + let result: JsonValue = .object(["status": .string("disconnected")]) + return .structured(result) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift b/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift new file mode 100644 index 000000000..1b85e8391 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ExecuteQueryTool.swift @@ -0,0 +1,175 @@ +import Foundation +import os + +public struct ExecuteQueryTool: MCPToolImplementation { + public static let name = "execute_query" + public static let description = String( + localized: "Execute a SQL query. All queries are subject to the connection's safe mode policy. DROP/TRUNCATE/ALTER...DROP must use the confirm_destructive_operation tool." + ) + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "query": .object([ + "type": .string("string"), + "description": .string(String(localized: "SQL or NoSQL query text")) + ]), + "max_rows": .object([ + "type": .string("integer"), + "description": .string(String(localized: "Maximum rows to return (default 500, max 10000)")) + ]), + "timeout_seconds": .object([ + "type": .string("integer"), + "description": .string(String(localized: "Query timeout in seconds (default 30, max 300)")) + ]), + "database": .object([ + "type": .string("string"), + "description": .string(String(localized: "Switch to this database before executing")) + ]), + "schema": .object([ + "type": .string("string"), + "description": .string(String(localized: "Switch to this schema before executing")) + ]) + ]), + "required": .array([.string("connection_id"), .string("query")]) + ]) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Execute Query"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: false, + openWorldHint: true + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let query = try MCPArgumentDecoder.requireString(arguments, key: "query") + + let mcpSettings = await MainActor.run { AppSettingsManager.shared.mcp } + let maxRows = MCPArgumentDecoder.optionalInt( + arguments, + key: "max_rows", + default: mcpSettings.defaultRowLimit, + clamp: 1...mcpSettings.maxRowLimit + ) ?? mcpSettings.defaultRowLimit + let timeoutSeconds = MCPArgumentDecoder.optionalInt( + arguments, + key: "timeout_seconds", + default: mcpSettings.queryTimeoutSeconds, + clamp: 1...300 + ) ?? mcpSettings.queryTimeoutSeconds + let database = MCPArgumentDecoder.optionalString(arguments, key: "database") + let schema = MCPArgumentDecoder.optionalString(arguments, key: "schema") + + guard (query as NSString).length <= 102_400 else { + throw MCPProtocolError.invalidParams(detail: "Query exceeds 100KB limit") + } + + guard !QueryClassifier.isMultiStatement(query) else { + throw MCPProtocolError.invalidParams( + detail: "Multi-statement queries are not supported. Send one statement at a time." + ) + } + + try await throwIfCancelled(context) + await context.progress.emit(progress: 0.0, total: 1.0, message: "Connecting") + + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + + if let database { + _ = try await services.connectionBridge.switchDatabase( + connectionId: connectionId, + database: database + ) + } + if let schema { + _ = try await services.connectionBridge.switchSchema( + connectionId: connectionId, + schema: schema + ) + } + + try await throwIfCancelled(context) + await context.progress.emit(progress: 0.2, total: 1.0, message: "Executing") + + let tier = QueryClassifier.classifyTier(query, databaseType: meta.databaseType) + try classifyAndAuthorize( + tier: tier, + query: query, + connectionId: connectionId, + meta: meta, + services: services, + context: context + ) + + try await services.authPolicy.checkSafeModeDialog( + sql: query, + connectionId: connectionId, + databaseType: meta.databaseType, + safeModeLevel: meta.safeModeLevel + ) + + Self.logger.debug("execute_query invoked for connection \(connectionId.uuidString, privacy: .public)") + + let result = try await ToolQueryExecutor.executeAndLog( + services: services, + query: query, + connectionId: connectionId, + databaseName: meta.databaseName, + maxRows: maxRows, + timeoutSeconds: timeoutSeconds, + principalLabel: context.principal.metadata.label + ) + + try await throwIfCancelled(context) + await context.progress.emit(progress: 0.8, total: 1.0, message: "Formatting result") + + await context.progress.emit(progress: 1.0, total: 1.0, message: "Done") + return .structured(result) + } + + private func classifyAndAuthorize( + tier: QueryTier, + query: String, + connectionId: UUID, + meta: ToolConnectionMetadata, + services: MCPToolServices, + context: MCPRequestContext + ) throws { + switch tier { + case .destructive: + throw MCPProtocolError.forbidden( + reason: "Destructive queries (DROP, TRUNCATE, ALTER...DROP) cannot be executed via execute_query. Use the confirm_destructive_operation tool instead." + ) + case .write: + guard context.principal.scopes.contains(.toolsWrite) else { + throw MCPProtocolError.forbidden( + reason: "Principal lacks tools:write scope required for write queries" + ) + } + case .safe: + return + } + } + + private func throwIfCancelled(_ context: MCPRequestContext) async throws { + guard await context.cancellation.isCancelled() else { return } + throw MCPProtocolError( + code: JsonRpcErrorCode.requestCancelled, + message: "Cancelled", + httpStatus: .ok + ) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ExportDataTool.swift b/TablePro/Core/MCP/Protocol/Tools/ExportDataTool.swift new file mode 100644 index 000000000..5bcbffea4 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ExportDataTool.swift @@ -0,0 +1,324 @@ +import Foundation +import os + +public struct ExportDataTool: MCPToolImplementation { + public static let name = "export_data" + public static let description = String( + localized: "Export query results or table data to CSV, JSON, or SQL" + ) + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "format": .object([ + "type": .string("string"), + "description": .string(String(localized: "Export format: csv, json, or sql")), + "enum": .array([.string("csv"), .string("json"), .string("sql")]) + ]), + "query": .object([ + "type": .string("string"), + "description": .string(String(localized: "SQL query to export results from")) + ]), + "tables": .object([ + "type": .string("array"), + "description": .string(String(localized: "Table names to export (alternative to query)")), + "items": .object(["type": .string("string")]) + ]), + "output_path": .object([ + "type": .string("string"), + "description": .string(String(localized: "File path inside the user's Downloads directory (returns inline data if omitted). Paths outside Downloads are rejected.")) + ]), + "max_rows": .object([ + "type": .string("integer"), + "description": .string(String(localized: "Maximum rows to export (default 50000)")) + ]) + ]), + "required": .array([.string("connection_id"), .string("format")]) + ]) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Export Data"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: false, + openWorldHint: true + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + private static let allowedFormats: Set = ["csv", "json", "sql"] + private static let exportTableNamePattern = "^[A-Za-z0-9_]+(\\.[A-Za-z0-9_]+)*$" + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let format = try MCPArgumentDecoder.requireString(arguments, key: "format") + let query = MCPArgumentDecoder.optionalString(arguments, key: "query") + let tables = MCPArgumentDecoder.optionalStringArray(arguments, key: "tables") + let outputPath = MCPArgumentDecoder.optionalString(arguments, key: "output_path") + let maxRows = MCPArgumentDecoder.optionalInt( + arguments, + key: "max_rows", + default: 50_000, + clamp: 1...100_000 + ) ?? 50_000 + + guard Self.allowedFormats.contains(format) else { + throw MCPProtocolError.invalidParams( + detail: "Unsupported format: \(format). Must be csv, json, or sql" + ) + } + + guard query != nil || tables != nil else { + throw MCPProtocolError.invalidParams(detail: "Either 'query' or 'tables' must be provided") + } + + if let tables { + for table in tables { + try Self.validateExportTableName(table) + } + } + + if let outputPath { + _ = try Self.sandboxedDownloadsURL(for: outputPath) + } + + let meta = try await ToolConnectionMetadata.resolve(connectionId: connectionId) + var queries: [(label: String, sql: String)] = [] + + if let query { + try await services.authPolicy.checkSafeModeDialog( + sql: query, + connectionId: connectionId, + databaseType: meta.databaseType, + safeModeLevel: meta.safeModeLevel + ) + queries.append((label: "query", sql: query)) + } else if let tables { + let quoteIdentifier = Self.identifierQuoter(for: meta.databaseType) + for table in tables { + let quoted = try Self.quoteQualifiedIdentifier(table, quoter: quoteIdentifier) + let sql = "SELECT * FROM \(quoted) LIMIT \(maxRows)" + try await services.authPolicy.checkSafeModeDialog( + sql: sql, + connectionId: connectionId, + databaseType: meta.databaseType, + safeModeLevel: meta.safeModeLevel + ) + queries.append((label: table, sql: sql)) + } + } + + var exportResults: [JsonValue] = [] + var totalRowsExported = 0 + + for (label, sql) in queries { + let result = try await services.connectionBridge.executeQuery( + connectionId: connectionId, + query: sql, + maxRows: maxRows, + timeoutSeconds: 60 + ) + + guard let columns = result["columns"]?.arrayValue, + let rows = result["rows"]?.arrayValue + else { + throw MCPProtocolError.internalError(detail: "Unexpected query result structure") + } + + let columnNames = columns.compactMap(\.stringValue) + let formatted: String + + switch format { + case "csv": + formatted = Self.formatCSV(columns: columnNames, rows: rows) + case "json": + formatted = Self.formatJSON(columns: columnNames, rows: rows) + case "sql": + formatted = Self.formatSQL(table: label, columns: columnNames, rows: rows) + default: + formatted = Self.formatCSV(columns: columnNames, rows: rows) + } + + totalRowsExported += rows.count + + exportResults.append(.object([ + "label": .string(label), + "format": .string(format), + "row_count": result["row_count"] ?? .int(0), + "data": .string(formatted) + ])) + } + + if let outputPath { + let fileURL = try Self.sandboxedDownloadsURL(for: outputPath) + let fullContent: String + if exportResults.count == 1, + let data = exportResults.first?["data"]?.stringValue + { + fullContent = data + } else { + fullContent = exportResults + .compactMap { $0["data"]?.stringValue } + .joined(separator: "\n\n") + } + try fullContent.write(to: fileURL, atomically: true, encoding: .utf8) + + let response: JsonValue = .object([ + "path": .string(fileURL.path), + "rows_exported": .int(totalRowsExported) + ]) + return .structured(response) + } + + let response: JsonValue + if exportResults.count == 1, let single = exportResults.first { + response = single + } else { + response = .object(["exports": .array(exportResults)]) + } + return .structured(response) + } + + static func validateExportTableName(_ table: String) throws { + guard table.range(of: exportTableNamePattern, options: .regularExpression) != nil else { + throw MCPProtocolError.invalidParams( + detail: "Invalid table name: '\(table)'. Allowed characters: letters, digits, underscore, and '.' for schema-qualified names." + ) + } + } + + static func identifierQuoter(for databaseType: DatabaseType) -> (String) -> String { + if let dialect = try? resolveSQLDialect(for: databaseType) { + return quoteIdentifierFromDialect(dialect) + } + return { "\"\($0.replacingOccurrences(of: "\"", with: "\"\""))\"" } + } + + static func quoteQualifiedIdentifier(_ identifier: String, quoter: (String) -> String) throws -> String { + let segments = identifier.split(separator: ".", omittingEmptySubsequences: true) + let segmentsWithEmpty = identifier.split(separator: ".", omittingEmptySubsequences: false) + guard !segments.isEmpty, segments.count == segmentsWithEmpty.count else { + throw MCPProtocolError.invalidParams( + detail: "Invalid qualified identifier: '\(identifier)'. Empty components are not allowed." + ) + } + return segments.map { quoter(String($0)) }.joined(separator: ".") + } + + static func sandboxedDownloadsURL(for path: String) throws -> URL { + guard let downloads = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first else { + throw MCPProtocolError.invalidParams(detail: "Downloads directory is not available") + } + let downloadsRoot = downloads.standardizedFileURL.resolvingSymlinksInPath().path + let candidate = path.hasPrefix("/") ? URL(fileURLWithPath: path) : downloads.appendingPathComponent(path) + let resolvedPath = candidate.standardizedFileURL.resolvingSymlinksInPath().path + let prefix = downloadsRoot.hasSuffix("/") ? downloadsRoot : downloadsRoot + "/" + guard resolvedPath == downloadsRoot || resolvedPath.hasPrefix(prefix) else { + throw MCPProtocolError.invalidParams( + detail: "output_path must be inside the Downloads directory (\(downloadsRoot))" + ) + } + return URL(fileURLWithPath: resolvedPath) + } + + static func formatCSV(columns: [String], rows: [JsonValue]) -> String { + var lines: [String] = [] + lines.append(columns.map { escapeCSVField($0) }.joined(separator: ",")) + for row in rows { + guard let cells = row.arrayValue else { continue } + let line = cells.map { cell -> String in + switch cell { + case .string(let value): + return escapeCSVField(value) + case .null: + return "" + case .int(let value): + return String(value) + case .double(let value): + return String(value) + case .bool(let value): + return value ? "true" : "false" + default: + return escapeCSVField(encodeJSON(cell)) + } + } + lines.append(line.joined(separator: ",")) + } + return lines.joined(separator: "\n") + } + + static func escapeCSVField(_ field: String) -> String { + if field.contains(",") || field.contains("\"") || field.contains("\n") { + return "\"" + field.replacingOccurrences(of: "\"", with: "\"\"") + "\"" + } + return field + } + + static func formatJSON(columns: [String], rows: [JsonValue]) -> String { + var objects: [JsonValue] = [] + for row in rows { + guard let cells = row.arrayValue else { continue } + var dict: [String: JsonValue] = [:] + for (index, column) in columns.enumerated() where index < cells.count { + dict[column] = cells[index] + } + objects.append(.object(dict)) + } + return encodeJSON(.array(objects)) + } + + static func formatSQL(table: String, columns: [String], rows: [JsonValue]) -> String { + guard !columns.isEmpty else { return "" } + var statements: [String] = [] + let escapedTable = "`\(table.replacingOccurrences(of: "`", with: "``"))`" + let escapedColumns = columns.map { "`\($0.replacingOccurrences(of: "`", with: "``"))`" } + let columnList = escapedColumns.joined(separator: ", ") + + for row in rows { + guard let cells = row.arrayValue else { continue } + let values = cells.map { cell -> String in + switch cell { + case .null: + return "NULL" + case .string(let value): + let escaped = value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "'", with: "\\'") + return "'\(escaped)'" + case .int(let value): + return String(value) + case .double(let value): + return String(value) + case .bool(let value): + return value ? "1" : "0" + default: + let escaped = encodeJSON(cell) + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "'", with: "\\'") + return "'\(escaped)'" + } + } + statements.append("INSERT INTO \(escapedTable) (\(columnList)) VALUES (\(values.joined(separator: ", ")));") + } + return statements.joined(separator: "\n") + } + + static func encodeJSON(_ value: JsonValue) -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + guard let data = try? encoder.encode(value), + let string = String(data: data, encoding: .utf8) + else { + return "{}" + } + return string + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/FocusQueryTabTool.swift b/TablePro/Core/MCP/Protocol/Tools/FocusQueryTabTool.swift new file mode 100644 index 000000000..909891d01 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/FocusQueryTabTool.swift @@ -0,0 +1,66 @@ +import AppKit +import Foundation + +public struct FocusQueryTabTool: MCPToolImplementation { + public static let name = "focus_query_tab" + public static let description = String(localized: "Focus an already-open tab by id (returned from list_recent_tabs).") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Focus Query Tab"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "tab_id": .object([ + "type": .string("string"), + "description": .string("UUID of the tab to focus") + ]) + ]), + "required": .array([.string("tab_id")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let tabId = try MCPArgumentDecoder.requireUuid(arguments, key: "tab_id") + + let resolved: (windowId: UUID?, connectionId: UUID, raised: Bool)? = await MainActor.run { + for snapshot in MCPTabSnapshotProvider.collectTabSnapshots() where snapshot.tabId == tabId { + guard let window = snapshot.window else { + return (windowId: snapshot.windowId, connectionId: snapshot.connectionId, raised: false) + } + NSApp.activate(ignoringOtherApps: true) + window.makeKeyAndOrderFront(nil) + return (windowId: snapshot.windowId, connectionId: snapshot.connectionId, raised: true) + } + return nil + } + + guard let resolved else { + throw MCPProtocolError.invalidParams(detail: "tab not found") + } + guard resolved.raised else { + throw MCPProtocolError.invalidParams(detail: "tab not found") + } + + var dict: [String: JsonValue] = [ + "status": .string("focused"), + "tab_id": .string(tabId.uuidString), + "connection_id": .string(resolved.connectionId.uuidString) + ] + if let windowId = resolved.windowId { + dict["window_id"] = .string(windowId.uuidString) + } + + return .structured(.object(dict)) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/GetConnectionStatusTool.swift b/TablePro/Core/MCP/Protocol/Tools/GetConnectionStatusTool.swift new file mode 100644 index 000000000..e09b43ab1 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/GetConnectionStatusTool.swift @@ -0,0 +1,37 @@ +import Foundation + +public struct GetConnectionStatusTool: MCPToolImplementation { + public static let name = "get_connection_status" + public static let description = String(localized: "Get detailed status of a database connection") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Get Connection Status"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let payload = try await services.connectionBridge.getConnectionStatus(connectionId: connectionId) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/GetTableDdlTool.swift b/TablePro/Core/MCP/Protocol/Tools/GetTableDdlTool.swift new file mode 100644 index 000000000..6f396b7dd --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/GetTableDdlTool.swift @@ -0,0 +1,52 @@ +import Foundation + +public struct GetTableDdlTool: MCPToolImplementation { + public static let name = "get_table_ddl" + public static let description = String(localized: "Get the CREATE TABLE DDL statement for a table") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Get Table DDL"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "table": .object([ + "type": .string("string"), + "description": .string(String(localized: "Table name")) + ]), + "schema": .object([ + "type": .string("string"), + "description": .string(String(localized: "Schema name (uses current if omitted)")) + ]) + ]), + "required": .array([.string("connection_id"), .string("table")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let table = try MCPArgumentDecoder.requireString(arguments, key: "table") + let schema = MCPArgumentDecoder.optionalString(arguments, key: "schema") + + let payload = try await services.connectionBridge.getTableDDL( + connectionId: connectionId, + table: table, + schema: schema + ) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ListConnectionsTool.swift b/TablePro/Core/MCP/Protocol/Tools/ListConnectionsTool.swift new file mode 100644 index 000000000..7a687e5bf --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ListConnectionsTool.swift @@ -0,0 +1,31 @@ +import Foundation + +public struct ListConnectionsTool: MCPToolImplementation { + public static let name = "list_connections" + public static let description = String(localized: "List all saved database connections with their status") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "List Connections"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([:]), + "required": .array([]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let payload = await services.connectionBridge.listConnections() + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ListDatabasesTool.swift b/TablePro/Core/MCP/Protocol/Tools/ListDatabasesTool.swift new file mode 100644 index 000000000..e6ad4a4f3 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ListDatabasesTool.swift @@ -0,0 +1,37 @@ +import Foundation + +public struct ListDatabasesTool: MCPToolImplementation { + public static let name = "list_databases" + public static let description = String(localized: "List all databases on the server") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "List Databases"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let payload = try await services.connectionBridge.listDatabases(connectionId: connectionId) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ListRecentTabsTool.swift b/TablePro/Core/MCP/Protocol/Tools/ListRecentTabsTool.swift new file mode 100644 index 000000000..74dc2badc --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ListRecentTabsTool.swift @@ -0,0 +1,68 @@ +import Foundation + +public struct ListRecentTabsTool: MCPToolImplementation { + public static let name = "list_recent_tabs" + public static let description = String( + localized: "List currently open tabs across all TablePro windows. Returns connection, tab type, table name, and titles for each tab." + ) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "List Recent Tabs"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "limit": .object([ + "type": .string("integer"), + "description": .string("Maximum number of tabs to return (default 20, max 500)") + ]) + ]), + "required": .array([]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let limit = MCPArgumentDecoder.optionalInt(arguments, key: "limit", default: 20, clamp: 1...500) ?? 20 + + let snapshots = await MainActor.run { MCPTabSnapshotProvider.collectTabSnapshots() } + let blocked = await MainActor.run { MCPTabSnapshotProvider.blockedExternalConnectionIds() } + let filtered = snapshots.filter { !blocked.contains($0.connectionId) } + let trimmed = Array(filtered.prefix(limit)) + + let payload: [JsonValue] = trimmed.map { snapshot in + var dict: [String: JsonValue] = [ + "connection_id": .string(snapshot.connectionId.uuidString), + "connection_name": .string(snapshot.connectionName), + "tab_id": .string(snapshot.tabId.uuidString), + "tab_type": .string(snapshot.tabType), + "display_title": .string(snapshot.displayTitle), + "is_active": .bool(snapshot.isActive) + ] + if let table = snapshot.tableName { + dict["table_name"] = .string(table) + } + if let database = snapshot.databaseName { + dict["database_name"] = .string(database) + } + if let schema = snapshot.schemaName { + dict["schema_name"] = .string(schema) + } + if let windowId = snapshot.windowId { + dict["window_id"] = .string(windowId.uuidString) + } + return .object(dict) + } + + return .structured(.object(["tabs": .array(payload)])) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ListSchemasTool.swift b/TablePro/Core/MCP/Protocol/Tools/ListSchemasTool.swift new file mode 100644 index 000000000..bb0e98606 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ListSchemasTool.swift @@ -0,0 +1,47 @@ +import Foundation + +public struct ListSchemasTool: MCPToolImplementation { + public static let name = "list_schemas" + public static let description = String(localized: "List schemas in a database") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "List Schemas"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "database": .object([ + "type": .string("string"), + "description": .string(String(localized: "Database name (uses current if omitted)")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let database = MCPArgumentDecoder.optionalString(arguments, key: "database") + + if let database { + _ = try await services.connectionBridge.switchDatabase(connectionId: connectionId, database: database) + } + + let payload = try await services.connectionBridge.listSchemas(connectionId: connectionId) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ListTablesTool.swift b/TablePro/Core/MCP/Protocol/Tools/ListTablesTool.swift new file mode 100644 index 000000000..7dd404e34 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ListTablesTool.swift @@ -0,0 +1,63 @@ +import Foundation + +public struct ListTablesTool: MCPToolImplementation { + public static let name = "list_tables" + public static let description = String(localized: "List tables and views in a database") + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "List Tables"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "database": .object([ + "type": .string("string"), + "description": .string(String(localized: "Database name (uses current if omitted)")) + ]), + "schema": .object([ + "type": .string("string"), + "description": .string(String(localized: "Schema name (uses current if omitted)")) + ]), + "include_row_counts": .object([ + "type": .string("boolean"), + "description": .string(String(localized: "Include approximate row counts (default false)")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let database = MCPArgumentDecoder.optionalString(arguments, key: "database") + let schema = MCPArgumentDecoder.optionalString(arguments, key: "schema") + let includeRowCounts = MCPArgumentDecoder.optionalBool(arguments, key: "include_row_counts", default: false) + + if let database { + _ = try await services.connectionBridge.switchDatabase(connectionId: connectionId, database: database) + } + if let schema { + _ = try await services.connectionBridge.switchSchema(connectionId: connectionId, schema: schema) + } + + let payload = try await services.connectionBridge.listTables( + connectionId: connectionId, + includeRowCounts: includeRowCounts + ) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/MCPArgumentDecoder.swift b/TablePro/Core/MCP/Protocol/Tools/MCPArgumentDecoder.swift new file mode 100644 index 000000000..c5fa6631c --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/MCPArgumentDecoder.swift @@ -0,0 +1,64 @@ +import Foundation + +enum MCPArgumentDecoder { + static func requireString(_ args: JsonValue, key: String) throws -> String { + guard case .string(let value) = args[key] else { + throw MCPProtocolError.invalidParams(detail: "Missing required parameter: \(key)") + } + return value + } + + static func optionalString(_ args: JsonValue, key: String) -> String? { + guard case .string(let value) = args[key] else { return nil } + return value + } + + static func requireUuid(_ args: JsonValue, key: String) throws -> UUID { + let raw = try requireString(args, key: key) + guard let uuid = UUID(uuidString: raw) else { + throw MCPProtocolError.invalidParams(detail: "Invalid UUID for parameter: \(key)") + } + return uuid + } + + static func optionalUuid(_ args: JsonValue, key: String) throws -> UUID? { + guard let raw = optionalString(args, key: key) else { return nil } + guard let uuid = UUID(uuidString: raw) else { + throw MCPProtocolError.invalidParams(detail: "Invalid UUID for parameter: \(key)") + } + return uuid + } + + static func requireInt(_ args: JsonValue, key: String) throws -> Int { + guard let value = args[key]?.intValue else { + throw MCPProtocolError.invalidParams(detail: "Missing required parameter: \(key)") + } + return value + } + + static func optionalInt( + _ args: JsonValue, + key: String, + default defaultValue: Int? = nil, + clamp: ClosedRange? = nil + ) -> Int? { + let raw = args[key]?.intValue + guard let raw else { return defaultValue } + guard let clamp else { return raw } + return min(max(raw, clamp.lowerBound), clamp.upperBound) + } + + static func optionalBool(_ args: JsonValue, key: String, default defaultValue: Bool = false) -> Bool { + args[key]?.boolValue ?? defaultValue + } + + static func optionalDouble(_ args: JsonValue, key: String) -> Double? { + args[key]?.doubleValue + } + + static func optionalStringArray(_ args: JsonValue, key: String) -> [String]? { + guard let array = args[key]?.arrayValue else { return nil } + let strings = array.compactMap { $0.stringValue } + return strings.isEmpty ? nil : strings + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/MCPTabSnapshotProvider.swift b/TablePro/Core/MCP/Protocol/Tools/MCPTabSnapshotProvider.swift new file mode 100644 index 000000000..ff88d480c --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/MCPTabSnapshotProvider.swift @@ -0,0 +1,66 @@ +import AppKit +import Foundation + +struct MCPTabSnapshot { + let tabId: UUID + let connectionId: UUID + let connectionName: String + let tabType: String + let tableName: String? + let databaseName: String? + let schemaName: String? + let displayTitle: String + let windowId: UUID? + let isActive: Bool + weak var window: NSWindow? +} + +enum MCPTabSnapshotProvider { + @MainActor + static func collectTabSnapshots() -> [MCPTabSnapshot] { + let connections = ConnectionStorage.shared.loadConnections() + let connectionsById = Dictionary(uniqueKeysWithValues: connections.map { ($0.id, $0) }) + + var snapshots: [MCPTabSnapshot] = [] + for coordinator in MainContentCoordinator.allActiveCoordinators() { + let connectionName = connectionsById[coordinator.connectionId]?.name + ?? coordinator.connection.name + let selectedId = coordinator.tabManager.selectedTabId + for tab in coordinator.tabManager.tabs { + snapshots.append(MCPTabSnapshot( + tabId: tab.id, + connectionId: coordinator.connectionId, + connectionName: connectionName, + tabType: tab.tabType.snapshotName, + tableName: tab.tableContext.tableName, + databaseName: tab.tableContext.databaseName, + schemaName: tab.tableContext.schemaName, + displayTitle: tab.title, + windowId: coordinator.windowId, + isActive: tab.id == selectedId, + window: coordinator.contentWindow + )) + } + } + return snapshots + } + + @MainActor + static func blockedExternalConnectionIds() -> Set { + let connections = ConnectionStorage.shared.loadConnections() + return Set(connections.filter { $0.externalAccess == .blocked }.map(\.id)) + } +} + +private extension TabType { + var snapshotName: String { + switch self { + case .query: "query" + case .table: "table" + case .createTable: "createTable" + case .erDiagram: "erDiagram" + case .serverDashboard: "serverDashboard" + case .terminal: "terminal" + } + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/MCPToolImplementation.swift b/TablePro/Core/MCP/Protocol/Tools/MCPToolImplementation.swift new file mode 100644 index 000000000..0876b2b1a --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/MCPToolImplementation.swift @@ -0,0 +1,137 @@ +import Foundation + +public protocol MCPToolImplementation: Sendable { + static var name: String { get } + static var title: String? { get } + static var description: String { get } + static var inputSchema: JsonValue { get } + static var annotations: MCPToolAnnotations { get } + static var requiredScopes: Set { get } + func call(arguments: JsonValue, context: MCPRequestContext, services: MCPToolServices) async throws -> MCPToolCallResult +} + +public extension MCPToolImplementation { + static var title: String? { nil } + static var annotations: MCPToolAnnotations { MCPToolAnnotations() } + + var name: String { Self.name } + var description: String { Self.description } + var inputSchema: JsonValue { Self.inputSchema } + var requiredScopes: Set { Self.requiredScopes } +} + +public struct MCPToolAnnotations: Sendable, Equatable { + public let title: String? + public let readOnlyHint: Bool? + public let destructiveHint: Bool? + public let idempotentHint: Bool? + public let openWorldHint: Bool? + + public init( + title: String? = nil, + readOnlyHint: Bool? = nil, + destructiveHint: Bool? = nil, + idempotentHint: Bool? = nil, + openWorldHint: Bool? = nil + ) { + self.title = title + self.readOnlyHint = readOnlyHint + self.destructiveHint = destructiveHint + self.idempotentHint = idempotentHint + self.openWorldHint = openWorldHint + } + + public var asJsonValue: JsonValue? { + var fields: [String: JsonValue] = [:] + if let title { + fields["title"] = .string(title) + } + if let readOnlyHint { + fields["readOnlyHint"] = .bool(readOnlyHint) + } + if let destructiveHint { + fields["destructiveHint"] = .bool(destructiveHint) + } + if let idempotentHint { + fields["idempotentHint"] = .bool(idempotentHint) + } + if let openWorldHint { + fields["openWorldHint"] = .bool(openWorldHint) + } + guard !fields.isEmpty else { return nil } + return .object(fields) + } +} + +public struct MCPToolCallResult: Sendable { + public let content: [MCPToolContentItem] + public let structuredContent: JsonValue? + public let isError: Bool + + public init( + content: [MCPToolContentItem], + structuredContent: JsonValue? = nil, + isError: Bool = false + ) { + self.content = content + self.structuredContent = structuredContent + self.isError = isError + } + + public static func text(_ value: String, isError: Bool = false) -> MCPToolCallResult { + MCPToolCallResult(content: [.text(value)], isError: isError) + } + + public static func json(_ value: JsonValue, isError: Bool = false) -> MCPToolCallResult { + let encoded = encodeJsonString(value) + return MCPToolCallResult(content: [.text(encoded)], isError: isError) + } + + public static func structured(_ value: JsonValue, isError: Bool = false) -> MCPToolCallResult { + let encoded = encodeJsonString(value) + return MCPToolCallResult( + content: [.text(encoded)], + structuredContent: value, + isError: isError + ) + } + + private static func encodeJsonString(_ value: JsonValue) -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + guard let data = try? encoder.encode(value), + let string = String(data: data, encoding: .utf8) else { + return "{}" + } + return string + } +} + +public enum MCPToolContentItem: Sendable, Equatable { + case text(String) + + var asJsonValue: JsonValue { + switch self { + case .text(let value): + return .object([ + "type": .string("text"), + "text": .string(value) + ]) + } + } +} + +public extension MCPToolCallResult { + func asJsonValue() -> JsonValue { + var fields: [String: JsonValue] = [ + "content": .array(content.map { $0.asJsonValue }) + ] + if let structuredContent { + fields["structuredContent"] = structuredContent + } + if isError { + fields["isError"] = .bool(true) + } + return .object(fields) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/MCPToolRegistry.swift b/TablePro/Core/MCP/Protocol/Tools/MCPToolRegistry.swift new file mode 100644 index 000000000..4b2cde4cc --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/MCPToolRegistry.swift @@ -0,0 +1,37 @@ +import Foundation + +public enum MCPToolRegistry { + public static let allTools: [any MCPToolImplementation] = [ + ListConnectionsTool(), + GetConnectionStatusTool(), + ListDatabasesTool(), + ListSchemasTool(), + ListTablesTool(), + DescribeTableTool(), + GetTableDdlTool(), + ListRecentTabsTool(), + SearchQueryHistoryTool(), + FocusQueryTabTool(), + ConnectTool(), + DisconnectTool(), + SwitchDatabaseTool(), + SwitchSchemaTool(), + ExecuteQueryTool(), + ExportDataTool(), + ConfirmDestructiveOperationTool(), + OpenTableTabTool(), + OpenConnectionWindowTool() + ] + + private static let toolsByName: [String: any MCPToolImplementation] = { + var map: [String: any MCPToolImplementation] = [:] + for tool in allTools { + map[type(of: tool).name] = tool + } + return map + }() + + public static func tool(named name: String) -> (any MCPToolImplementation)? { + toolsByName[name] + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/MCPToolServices.swift b/TablePro/Core/MCP/Protocol/Tools/MCPToolServices.swift new file mode 100644 index 000000000..725f9f440 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/MCPToolServices.swift @@ -0,0 +1,11 @@ +import Foundation + +public struct MCPToolServices: Sendable { + public let connectionBridge: MCPConnectionBridge + public let authPolicy: MCPAuthPolicy + + public init(connectionBridge: MCPConnectionBridge, authPolicy: MCPAuthPolicy) { + self.connectionBridge = connectionBridge + self.authPolicy = authPolicy + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/OpenConnectionWindowTool.swift b/TablePro/Core/MCP/Protocol/Tools/OpenConnectionWindowTool.swift new file mode 100644 index 000000000..6d65e7f5a --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/OpenConnectionWindowTool.swift @@ -0,0 +1,70 @@ +import AppKit +import Foundation +import os + +public struct OpenConnectionWindowTool: MCPToolImplementation { + public static let name = "open_connection_window" + public static let description = String( + localized: "Open a TablePro window for a saved connection (focuses if already open)." + ) + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the saved connection")) + ]) + ]), + "required": .array([.string("connection_id")]) + ]) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Open Connection Window"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + try await ensureConnectionExists(connectionId) + + Self.logger.debug("open_connection_window invoked for connection \(connectionId.uuidString, privacy: .public)") + + let windowId = await MainActor.run { () -> UUID in + let payload = EditorTabPayload( + connectionId: connectionId, + tabType: .query, + intent: .restoreOrDefault + ) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + return payload.id + } + + let result: JsonValue = .object([ + "status": .string("opened"), + "connection_id": .string(connectionId.uuidString), + "window_id": .string(windowId.uuidString) + ]) + return .structured(result) + } + + private func ensureConnectionExists(_ connectionId: UUID) async throws { + let exists = await MainActor.run { + ConnectionStorage.shared.loadConnections().contains { $0.id == connectionId } + } + guard exists else { + throw MCPProtocolError.invalidParams(detail: "Connection not found: \(connectionId.uuidString)") + } + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/OpenTableTabTool.swift b/TablePro/Core/MCP/Protocol/Tools/OpenTableTabTool.swift new file mode 100644 index 000000000..9ab75b6ef --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/OpenTableTabTool.swift @@ -0,0 +1,90 @@ +import AppKit +import Foundation +import os + +public struct OpenTableTabTool: MCPToolImplementation { + public static let name = "open_table_tab" + public static let description = String( + localized: "Open a table tab in TablePro for the given connection." + ) + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "table_name": .object([ + "type": .string("string"), + "description": .string(String(localized: "Table name to open")) + ]), + "database_name": .object([ + "type": .string("string"), + "description": .string(String(localized: "Database name (uses connection's current database if omitted)")) + ]), + "schema_name": .object([ + "type": .string("string"), + "description": .string(String(localized: "Schema name (for multi-schema databases)")) + ]) + ]), + "required": .array([.string("connection_id"), .string("table_name")]) + ]) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Open Table Tab"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let tableName = try MCPArgumentDecoder.requireString(arguments, key: "table_name") + let databaseName = MCPArgumentDecoder.optionalString(arguments, key: "database_name") + let schemaName = MCPArgumentDecoder.optionalString(arguments, key: "schema_name") + + try await ensureConnectionExists(connectionId) + + Self.logger.debug("open_table_tab invoked for connection \(connectionId.uuidString, privacy: .public)") + + let windowId = await MainActor.run { () -> UUID in + let payload = EditorTabPayload( + connectionId: connectionId, + tabType: .table, + tableName: tableName, + databaseName: databaseName, + schemaName: schemaName, + intent: .openContent + ) + WindowManager.shared.openTab(payload: payload) + NSApp.activate(ignoringOtherApps: true) + return payload.id + } + + let result: JsonValue = .object([ + "status": .string("opened"), + "connection_id": .string(connectionId.uuidString), + "table_name": .string(tableName), + "window_id": .string(windowId.uuidString) + ]) + return .structured(result) + } + + private func ensureConnectionExists(_ connectionId: UUID) async throws { + let exists = await MainActor.run { + ConnectionStorage.shared.loadConnections().contains { $0.id == connectionId } + } + guard exists else { + throw MCPProtocolError.invalidParams(detail: "Connection not found: \(connectionId.uuidString)") + } + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/SearchQueryHistoryTool.swift b/TablePro/Core/MCP/Protocol/Tools/SearchQueryHistoryTool.swift new file mode 100644 index 000000000..c13dad637 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/SearchQueryHistoryTool.swift @@ -0,0 +1,109 @@ +import Foundation + +public struct SearchQueryHistoryTool: MCPToolImplementation { + public static let name = "search_query_history" + public static let description = String( + localized: "Search saved query history. Returns matching entries with execution time, row count, and outcome." + ) + public static let requiredScopes: Set = [.toolsRead] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Search Query History"), + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "query": .object([ + "type": .string("string"), + "description": .string(String(localized: "Search text (full-text matched against the query column)")) + ]), + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "Restrict to a specific connection (UUID, optional)")) + ]), + "limit": .object([ + "type": .string("integer"), + "description": .string(String(localized: "Maximum number of entries to return (default 50, max 500)")) + ]), + "since": .object([ + "type": .string("number"), + "description": .string(String(localized: "Earliest executed_at to include, Unix epoch seconds (inclusive, optional)")) + ]), + "until": .object([ + "type": .string("number"), + "description": .string(String(localized: "Latest executed_at to include, Unix epoch seconds (inclusive, optional)")) + ]) + ]), + "required": .array([.string("query")]) + ]) + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let query = try MCPArgumentDecoder.requireString(arguments, key: "query") + let connectionId = try MCPArgumentDecoder.optionalUuid(arguments, key: "connection_id") + let limit = MCPArgumentDecoder.optionalInt(arguments, key: "limit", default: 50, clamp: 1...500) ?? 50 + let since = MCPArgumentDecoder.optionalDouble(arguments, key: "since").map { Date(timeIntervalSince1970: $0) } + let until = MCPArgumentDecoder.optionalDouble(arguments, key: "until").map { Date(timeIntervalSince1970: $0) } + + if let since, let until, since > until { + throw MCPProtocolError.invalidParams(detail: "'since' must be less than or equal to 'until'") + } + + let blocked = await MainActor.run { MCPTabSnapshotProvider.blockedExternalConnectionIds() } + + if let connectionId, blocked.contains(connectionId) { + throw MCPProtocolError.forbidden(reason: "External access is disabled for this connection") + } + + let allowlist: Set? + if connectionId != nil { + allowlist = nil + } else if blocked.isEmpty { + allowlist = nil + } else { + let allConnectionIds = await MainActor.run { + Set(ConnectionStorage.shared.loadConnections().map(\.id)) + } + allowlist = allConnectionIds.subtracting(blocked) + } + + let entries = await QueryHistoryStorage.shared.fetchHistory( + limit: limit, + offset: 0, + connectionId: connectionId, + searchText: query.isEmpty ? nil : query, + dateFilter: .all, + since: since, + until: until, + allowedConnectionIds: allowlist + ) + + let payload: [JsonValue] = entries.map { entry in + var dict: [String: JsonValue] = [ + "id": .string(entry.id.uuidString), + "query": .string(entry.query), + "connection_id": .string(entry.connectionId.uuidString), + "database_name": .string(entry.databaseName), + "executed_at": .double(entry.executedAt.timeIntervalSince1970), + "execution_time_ms": .double(entry.executionTime * 1_000), + "row_count": .int(entry.rowCount), + "was_successful": .bool(entry.wasSuccessful) + ] + if let error = entry.errorMessage { + dict["error_message"] = .string(error) + } + return .object(dict) + } + + return .structured(.object(["entries": .array(payload)])) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/SwitchDatabaseTool.swift b/TablePro/Core/MCP/Protocol/Tools/SwitchDatabaseTool.swift new file mode 100644 index 000000000..26c8c273b --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/SwitchDatabaseTool.swift @@ -0,0 +1,48 @@ +import Foundation +import os + +public struct SwitchDatabaseTool: MCPToolImplementation { + public static let name = "switch_database" + public static let description = String(localized: "Switch the active database on a connection") + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "database": .object([ + "type": .string("string"), + "description": .string(String(localized: "Database name to switch to")) + ]) + ]), + "required": .array([.string("connection_id"), .string("database")]) + ]) + public static let requiredScopes: Set = [.toolsWrite] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Switch Database"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let database = try MCPArgumentDecoder.requireString(arguments, key: "database") + Self.logger.debug("switch_database tool invoked for connection \(connectionId.uuidString, privacy: .public)") + let payload = try await services.connectionBridge.switchDatabase( + connectionId: connectionId, + database: database + ) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/SwitchSchemaTool.swift b/TablePro/Core/MCP/Protocol/Tools/SwitchSchemaTool.swift new file mode 100644 index 000000000..e26bdea75 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/SwitchSchemaTool.swift @@ -0,0 +1,48 @@ +import Foundation +import os + +public struct SwitchSchemaTool: MCPToolImplementation { + public static let name = "switch_schema" + public static let description = String(localized: "Switch the active schema on a connection") + public static let inputSchema: JsonValue = .object([ + "type": .string("object"), + "properties": .object([ + "connection_id": .object([ + "type": .string("string"), + "description": .string(String(localized: "UUID of the connection")) + ]), + "schema": .object([ + "type": .string("string"), + "description": .string(String(localized: "Schema name to switch to")) + ]) + ]), + "required": .array([.string("connection_id"), .string("schema")]) + ]) + public static let requiredScopes: Set = [.toolsWrite] + public static let annotations = MCPToolAnnotations( + title: String(localized: "Switch Schema"), + readOnlyHint: false, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false + ) + + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Tools") + + public init() {} + + public func call( + arguments: JsonValue, + context: MCPRequestContext, + services: MCPToolServices + ) async throws -> MCPToolCallResult { + let connectionId = try MCPArgumentDecoder.requireUuid(arguments, key: "connection_id") + let schema = try MCPArgumentDecoder.requireString(arguments, key: "schema") + Self.logger.debug("switch_schema tool invoked for connection \(connectionId.uuidString, privacy: .public)") + let payload = try await services.connectionBridge.switchSchema( + connectionId: connectionId, + schema: schema + ) + return .structured(payload) + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ToolConnectionMetadata.swift b/TablePro/Core/MCP/Protocol/Tools/ToolConnectionMetadata.swift new file mode 100644 index 000000000..e14ad78ca --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ToolConnectionMetadata.swift @@ -0,0 +1,28 @@ +import Foundation + +struct ToolConnectionMetadata { + let databaseType: DatabaseType + let safeModeLevel: SafeModeLevel + let databaseName: String + + static func resolve(connectionId: UUID) async throws -> ToolConnectionMetadata { + try await MainActor.run { + switch DatabaseManager.shared.connectionState(connectionId) { + case .live(_, let session): + return ToolConnectionMetadata( + databaseType: session.connection.type, + safeModeLevel: session.connection.safeModeLevel, + databaseName: session.activeDatabase + ) + case .stored(let conn): + return ToolConnectionMetadata( + databaseType: conn.type, + safeModeLevel: conn.safeModeLevel, + databaseName: conn.database + ) + case .unknown: + throw MCPProtocolError.invalidParams(detail: "Connection not found: \(connectionId.uuidString)") + } + } + } +} diff --git a/TablePro/Core/MCP/Protocol/Tools/ToolQueryExecutor.swift b/TablePro/Core/MCP/Protocol/Tools/ToolQueryExecutor.swift new file mode 100644 index 000000000..55d399e08 --- /dev/null +++ b/TablePro/Core/MCP/Protocol/Tools/ToolQueryExecutor.swift @@ -0,0 +1,66 @@ +import Foundation + +enum ToolQueryExecutor { + static func executeAndLog( + services: MCPToolServices, + query: String, + connectionId: UUID, + databaseName: String, + maxRows: Int, + timeoutSeconds: Int, + principalLabel: String? + ) async throws -> JsonValue { + let startTime = Date() + do { + let result = try await services.connectionBridge.executeQuery( + connectionId: connectionId, + query: query, + maxRows: maxRows, + timeoutSeconds: timeoutSeconds + ) + let elapsed = Date().timeIntervalSince(startTime) + let rowCount = result["row_count"]?.intValue ?? 0 + await services.authPolicy.logQuery( + sql: query, + connectionId: connectionId, + databaseName: databaseName, + executionTime: elapsed, + rowCount: rowCount, + wasSuccessful: true, + errorMessage: nil + ) + MCPAuditLogger.logQueryExecuted( + tokenId: nil, + tokenName: principalLabel, + connectionId: connectionId, + sql: query, + durationMs: Int(elapsed * 1_000), + rowCount: rowCount, + outcome: .success + ) + return result + } catch { + let elapsed = Date().timeIntervalSince(startTime) + await services.authPolicy.logQuery( + sql: query, + connectionId: connectionId, + databaseName: databaseName, + executionTime: elapsed, + rowCount: 0, + wasSuccessful: false, + errorMessage: error.localizedDescription + ) + MCPAuditLogger.logQueryExecuted( + tokenId: nil, + tokenName: principalLabel, + connectionId: connectionId, + sql: query, + durationMs: Int(elapsed * 1_000), + rowCount: 0, + outcome: .error, + errorMessage: error.localizedDescription + ) + throw error + } + } +} diff --git a/TablePro/Core/MCP/RateLimit/MCPRateLimiter.swift b/TablePro/Core/MCP/RateLimit/MCPRateLimiter.swift new file mode 100644 index 000000000..d3c2d53f3 --- /dev/null +++ b/TablePro/Core/MCP/RateLimit/MCPRateLimiter.swift @@ -0,0 +1,116 @@ +import Foundation +import os + +public struct MCPRateLimitKey: Sendable, Equatable, Hashable { + public let clientAddress: MCPClientAddress + public let principalFingerprint: String? + + public init(clientAddress: MCPClientAddress, principalFingerprint: String?) { + self.clientAddress = clientAddress + self.principalFingerprint = principalFingerprint + } +} + +public struct MCPRateLimitPolicy: Sendable, Equatable { + public let maxFailedAttempts: Int + public let windowDuration: Duration + public let lockoutDuration: Duration + + public init(maxFailedAttempts: Int, windowDuration: Duration, lockoutDuration: Duration) { + self.maxFailedAttempts = maxFailedAttempts + self.windowDuration = windowDuration + self.lockoutDuration = lockoutDuration + } + + public static let standard = MCPRateLimitPolicy( + maxFailedAttempts: 5, + windowDuration: .seconds(60), + lockoutDuration: .seconds(300) + ) +} + +public enum MCPRateLimitVerdict: Sendable, Equatable { + case allowed + case lockedUntil(Date) +} + +public actor MCPRateLimiter { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.RateLimit") + + private struct Bucket { + var failureTimestamps: [Date] + var lockedUntil: Date? + } + + private let policy: MCPRateLimitPolicy + private let clock: any MCPClock + private var buckets: [MCPRateLimitKey: Bucket] = [:] + + public init(policy: MCPRateLimitPolicy = .standard, clock: any MCPClock = MCPSystemClock()) { + self.policy = policy + self.clock = clock + } + + public func recordAttempt(key: MCPRateLimitKey, success: Bool) async -> MCPRateLimitVerdict { + let now = await clock.now() + + if let lockedUntil = buckets[key]?.lockedUntil, lockedUntil > now { + return .lockedUntil(lockedUntil) + } + + if success { + buckets.removeValue(forKey: key) + return .allowed + } + + var bucket = buckets[key] ?? Bucket(failureTimestamps: [], lockedUntil: nil) + let windowStart = now.addingTimeInterval(-Self.seconds(of: policy.windowDuration)) + bucket.failureTimestamps.removeAll { $0 < windowStart } + bucket.failureTimestamps.append(now) + + if bucket.failureTimestamps.count >= policy.maxFailedAttempts { + let lockUntil = now.addingTimeInterval(Self.seconds(of: policy.lockoutDuration)) + bucket.lockedUntil = lockUntil + buckets[key] = bucket + Self.logger.warning( + "Rate limit lockout \(Self.describe(key), privacy: .public) until \(lockUntil, privacy: .public)" + ) + return .lockedUntil(lockUntil) + } + + bucket.lockedUntil = nil + buckets[key] = bucket + return .allowed + } + + public func isLocked(key: MCPRateLimitKey) async -> Bool { + guard let lockedUntil = buckets[key]?.lockedUntil else { return false } + return lockedUntil > (await clock.now()) + } + + public func lockedUntil(key: MCPRateLimitKey) async -> Date? { + guard let lockedUntil = buckets[key]?.lockedUntil else { return nil } + guard lockedUntil > (await clock.now()) else { return nil } + return lockedUntil + } + + public func reset(key: MCPRateLimitKey) async { + buckets.removeValue(forKey: key) + } + + private static func describe(_ key: MCPRateLimitKey) -> String { + let address: String + switch key.clientAddress { + case .loopback: + address = "loopback" + case .remote(let value): + address = value + } + return "\(address)/\(key.principalFingerprint ?? "anon")" + } + + private static func seconds(of duration: Duration) -> TimeInterval { + let components = duration.components + return TimeInterval(components.seconds) + TimeInterval(components.attoseconds) / 1.0e18 + } +} diff --git a/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift b/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift deleted file mode 100644 index 794494a02..000000000 --- a/TablePro/Core/MCP/Routes/IntegrationsExchangeHandler.swift +++ /dev/null @@ -1,94 +0,0 @@ -import Foundation -import os - -struct IntegrationsExchangeHandler: MCPRouteHandler { - private static let logger = Logger(subsystem: "com.TablePro", category: "IntegrationsExchangeHandler") - - private let exchange: @Sendable (PairingExchange) async throws -> String - - private let encoder: JSONEncoder - private let decoder: JSONDecoder - - var methods: [HTTPRequest.Method] { [.post] } - var path: String { "/v1/integrations/exchange" } - - init(exchange: @escaping @Sendable (PairingExchange) async throws -> String) { - self.exchange = exchange - let enc = JSONEncoder() - enc.outputFormatting = [.sortedKeys] - self.encoder = enc - self.decoder = JSONDecoder() - } - - static func live() -> IntegrationsExchangeHandler { - IntegrationsExchangeHandler { request in - try await MainActor.run { - try MCPPairingService.shared.exchange(request) - } - } - } - - func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { - guard let body = request.body else { - return .httpError(status: 400, message: "Missing request body") - } - - let parsed: ExchangeRequestBody - do { - parsed = try decoder.decode(ExchangeRequestBody.self, from: body) - } catch { - return .httpError(status: 400, message: "Invalid JSON body") - } - - guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { - return .httpError(status: 400, message: "Missing code or code_verifier") - } - - let token: String - do { - token = try await exchange( - PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) - ) - } catch let mcpError as MCPError { - return Self.mapExchangeError(mcpError) - } catch { - Self.logger.error("Pairing exchange failed: \(error.localizedDescription)") - return .httpError(status: 500, message: "Internal error") - } - - do { - let data = try encoder.encode(ExchangeResponseBody(token: token)) - return .json(data, sessionId: nil) - } catch { - Self.logger.error("Failed to encode exchange response: \(error.localizedDescription)") - return .httpError(status: 500, message: "Internal error") - } - } - - private static func mapExchangeError(_ error: MCPError) -> MCPRouter.RouteResult { - switch error { - case .notFound: - return .httpError(status: 404, message: "Pairing code not found") - case .expired: - return .httpError(status: 410, message: "Pairing code expired") - case .forbidden: - return .httpError(status: 403, message: "Challenge mismatch") - default: - return .httpError(status: 500, message: "Internal error") - } - } - - private struct ExchangeRequestBody: Decodable { - let code: String - let codeVerifier: String - - enum CodingKeys: String, CodingKey { - case code - case codeVerifier = "code_verifier" - } - } - - private struct ExchangeResponseBody: Encodable { - let token: String - } -} diff --git a/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift b/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift deleted file mode 100644 index dac72ede6..000000000 --- a/TablePro/Core/MCP/Routes/MCPProtocolHandler.swift +++ /dev/null @@ -1,533 +0,0 @@ -import Foundation -import os - -final class MCPProtocolHandler: MCPRouteHandler, @unchecked Sendable { - private static let logger = Logger(subsystem: "com.TablePro", category: "MCPProtocolHandler") - - private weak var server: MCPServer? - private let tokenStore: MCPTokenStore? - private let rateLimiter: MCPRateLimiter? - - private let encoder: JSONEncoder - private let decoder: JSONDecoder - - var methods: [HTTPRequest.Method] { [.get, .post, .delete] } - var path: String { "/mcp" } - - init(server: MCPServer, tokenStore: MCPTokenStore?, rateLimiter: MCPRateLimiter?) { - self.server = server - self.tokenStore = tokenStore - self.rateLimiter = rateLimiter - let enc = JSONEncoder() - enc.outputFormatting = [.sortedKeys] - self.encoder = enc - self.decoder = JSONDecoder() - } - - func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { - guard let server else { - return .httpError(status: 503, message: "Server unavailable") - } - - if let rateLimiter, let ip = request.remoteIP { - let lockoutCheck = await rateLimiter.isLockedOut(ip: ip) - if case .rateLimited(let retryAfter) = lockoutCheck { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: ip, retryAfterSeconds: seconds) - return .httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - ) - } - } - - let authResult = await authenticateRequest(request) - - switch authResult { - case .failure(let result): - return result - case .success(let token): - if token == nil { - if let origin = request.headers["origin"], !isAllowedOrigin(origin) { - return .httpError(status: 403, message: "Forbidden origin") - } - } - - switch request.method { - case .post: - return await handlePost(request, server: server, authenticatedToken: token) - case .get: - return await handleGet(request, server: server) - case .delete: - return await handleDelete(request, server: server) - case .options: - return .noContent - } - } - } - - private enum AuthResult { - case success(MCPAuthToken?) - case failure(MCPRouter.RouteResult) - } - - private func authenticateRequest(_ request: HTTPRequest) async -> AuthResult { - let remoteIP = request.remoteIP - let authRequired = await MainActor.run { AppSettingsManager.shared.mcp.requireAuthentication } - - guard let authHeader = request.headers["authorization"] else { - guard !authRequired else { - MCPAuditLogger.logAuthFailure(reason: "Missing authorization header", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Authentication required", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - return .success(nil) - } - - guard authHeader.lowercased().hasPrefix("bearer "), let tokenStore else { - let rateLimitResult = await recordAuthFailure(ip: remoteIP) - if case .rateLimited(let retryAfter) = rateLimitResult { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) - return .failure(.httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - )) - } - MCPAuditLogger.logAuthFailure(reason: "Invalid authorization header format", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Invalid authorization header", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - - let bearerToken = String(authHeader.dropFirst(7)) - - guard let token = await tokenStore.validate(bearerToken: bearerToken) else { - let rateLimitResult = await recordAuthFailure(ip: remoteIP) - if case .rateLimited(let retryAfter) = rateLimitResult { - let seconds = Int(retryAfter.components.seconds) - MCPAuditLogger.logRateLimited(ip: remoteIP ?? "localhost", retryAfterSeconds: seconds) - return .failure(.httpErrorWithHeaders( - status: 429, - message: "Too many failed attempts", - extraHeaders: [("Retry-After", "\(seconds)")] - )) - } - MCPAuditLogger.logAuthFailure(reason: "Invalid token", ip: remoteIP ?? "localhost") - return .failure(.httpErrorWithHeaders( - status: 401, - message: "Invalid or expired token", - extraHeaders: [("WWW-Authenticate", "Bearer realm=\"TablePro MCP\"")] - )) - } - - if let rateLimiter, let ip = remoteIP { - _ = await rateLimiter.checkAndRecord(ip: ip, success: true) - } - MCPAuditLogger.logAuthSuccess(tokenName: token.name, ip: remoteIP ?? "localhost") - return .success(token) - } - - @discardableResult - private func recordAuthFailure(ip: String?) async -> MCPRateLimiter.AuthRateResult? { - guard let rateLimiter, let ip else { return nil } - return await rateLimiter.checkAndRecord(ip: ip, success: false) - } - - private func isAllowedOrigin(_ origin: String) -> Bool { - guard let components = URLComponents(string: origin), - let host = components.host - else { - return false - } - let allowedHosts: Set = ["localhost", "127.0.0.1", "::1"] - return allowedHosts.contains(host) - } - - private func handleGet(_ request: HTTPRequest, server: MCPServer) async -> MCPRouter.RouteResult { - guard let sessionId = request.headers["mcp-session-id"] else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - - guard let session = await server.session(for: sessionId) else { - return .httpError(status: 404, message: "Session not found") - } - - await session.markActive() - return .sseStream(sessionId: session.id) - } - - private func handleDelete(_ request: HTTPRequest, server: MCPServer) async -> MCPRouter.RouteResult { - guard let sessionId = request.headers["mcp-session-id"] else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - - guard await server.session(for: sessionId) != nil else { - return .httpError(status: 404, message: "Session not found") - } - - await server.removeSession(sessionId) - Self.logger.info("Session terminated via DELETE: \(sessionId)") - return .noContent - } - - private func handlePost( - _ request: HTTPRequest, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> MCPRouter.RouteResult { - if let accept = request.headers["accept"], !accept.contains("application/json") && !accept.contains("*/*") { - return .httpError(status: 406, message: "Accept header must include application/json") - } - - guard let body = request.body else { - return encodeError(MCPError.parseError, id: nil) - } - - let rpcRequest: JSONRPCRequest - do { - rpcRequest = try decoder.decode(JSONRPCRequest.self, from: body) - } catch { - return encodeError(MCPError.parseError, id: nil) - } - - guard rpcRequest.jsonrpc == "2.0" else { - return encodeError(MCPError.invalidRequest("jsonrpc must be \"2.0\""), id: rpcRequest.id) - } - - if let protocolVersion = request.headers["mcp-protocol-version"], - protocolVersion != "2025-03-26" - { - Self.logger.warning("Client mcp-protocol-version mismatch: \(protocolVersion)") - } - - let headerSessionId = request.headers["mcp-session-id"] - return await dispatchMethod( - rpcRequest, - headerSessionId: headerSessionId, - server: server, - authenticatedToken: authenticatedToken - ) - } - - private func dispatchMethod( - _ request: JSONRPCRequest, - headerSessionId: String?, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> MCPRouter.RouteResult { - if request.method == "initialize" { - return await handleInitialize(request, server: server) - } - - if request.method == "ping" { - return handlePing(request) - } - - guard let sessionId = headerSessionId else { - return .httpError(status: 400, message: "Missing Mcp-Session-Id header") - } - guard let session = await server.session(for: sessionId) else { - return .httpError(status: 404, message: "Session not found") - } - - await session.markActive() - - if request.method == "notifications/initialized" { - do { - try await session.transition(to: .active( - tokenId: authenticatedToken?.id, - tokenName: authenticatedToken?.name - )) - } catch { - return encodeError(MCPError.invalidRequest("Cannot initialize session in current phase"), id: request.id) - } - return .accepted - } - - if request.method == "notifications/cancelled" { - return await handleCancellation(request, session: session) - } - - guard await session.phase.isActive else { - return encodeError( - MCPError.invalidRequest("Session not initialized. Send notifications/initialized first."), - id: request.id - ) - } - - switch request.method { - case "tools/list": - return handleToolsList(request, sessionId: sessionId) - - case "tools/call": - return await handleToolsCall( - request, - sessionId: sessionId, - server: server, - authenticatedToken: authenticatedToken - ) - - case "resources/list": - return handleResourcesList(request, sessionId: sessionId) - - case "resources/read": - return await handleResourcesRead(request, sessionId: sessionId, server: server) - - default: - return encodeError(MCPError.methodNotFound(request.method), id: request.id) - } - } - - private func handleInitialize( - _ request: JSONRPCRequest, - server: MCPServer - ) async -> MCPRouter.RouteResult { - guard let session = await server.createSession() else { - return encodeError(MCPError.internalError("Maximum sessions reached"), id: request.id) - } - - if let params = request.params, - let clientInfo = params["clientInfo"], - let name = clientInfo["name"]?.stringValue - { - let version = clientInfo["version"]?.stringValue - await session.setClientInfo(MCPClientInfo(name: name, version: version)) - } - - do { - try await session.transition(to: .initializing) - } catch { - await server.removeSession(session.id) - return encodeError(MCPError.invalidRequest("Cannot initialize session"), id: request.id) - } - - let result = MCPInitializeResult( - protocolVersion: "2025-03-26", - capabilities: MCPServerCapabilities( - tools: .init(listChanged: false), - resources: .init(subscribe: false, listChanged: false) - ), - serverInfo: MCPServerInfo(name: "tablepro", version: "1.0.0") - ) - - return encodeResult(result, id: request.id, sessionId: session.id) - } - - private func handlePing(_ request: JSONRPCRequest) -> MCPRouter.RouteResult { - guard let id = request.id else { - return .accepted - } - return encodeRawResult(.object([:]), id: id, sessionId: nil) - } - - private func handleCancellation( - _ request: JSONRPCRequest, - session: MCPSession - ) async -> MCPRouter.RouteResult { - guard let params = request.params, - let requestIdValue = params["requestId"] - else { - return .accepted - } - - let cancelId: JSONRPCId? - switch requestIdValue { - case .string(let s): - cancelId = .string(s) - case .int(let i): - cancelId = .int(i) - default: - cancelId = nil - } - - if let cancelId, let task = await session.removeRunningTask(cancelId) { - task.cancel() - Self.logger.info("Cancelled request \(String(describing: cancelId)) in session \(session.id)") - } - - return .accepted - } - - private func handleToolsList(_ request: JSONRPCRequest, sessionId: String) -> MCPRouter.RouteResult { - guard let id = request.id else { - return .accepted - } - - let tools = MCPRouter.toolDefinitions() - let result: JSONValue = .object(["tools": encodeToolDefinitions(tools)]) - return encodeRawResult(result, id: id, sessionId: sessionId) - } - - private func handleToolsCall( - _ request: JSONRPCRequest, - sessionId: String, - server: MCPServer, - authenticatedToken: MCPAuthToken? - ) async -> MCPRouter.RouteResult { - guard let id = request.id else { - return encodeError(MCPError.invalidRequest("tools/call requires an id"), id: nil) - } - - guard let params = request.params, - let name = params["name"]?.stringValue - else { - return encodeError(MCPError.invalidParams("Missing tool name"), id: id) - } - - let arguments = params["arguments"] - - guard let handler = await server.toolCallHandler else { - return encodeError(MCPError.internalError("Server not fully initialized"), id: id) - } - - let session = await server.session(for: sessionId) - let toolTask = Task { - try await handler(name, arguments, sessionId, authenticatedToken) - } - if let session { - let cancelForwardingTask = Task { - await withTaskCancellationHandler { - _ = try? await toolTask.value - } onCancel: { - toolTask.cancel() - } - } - await session.addRunningTask(id, task: cancelForwardingTask) - } - - do { - let toolResult = try await toolTask.value - if let session { _ = await session.removeRunningTask(id) } - let resultData = try encoder.encode(toolResult) - guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { - return encodeError(MCPError.internalError("Failed to encode tool result"), id: id) - } - return encodeRawResult(resultValue, id: id, sessionId: sessionId) - } catch is CancellationError { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(MCPError.timeout("Request was cancelled"), id: id) - } catch let mcpError as MCPError { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(mcpError, id: id) - } catch { - if let session { _ = await session.removeRunningTask(id) } - return encodeError(MCPError.internalError(error.localizedDescription), id: id) - } - } - - private func handleResourcesList(_ request: JSONRPCRequest, sessionId: String) -> MCPRouter.RouteResult { - guard let id = request.id else { - return .accepted - } - - let resources = MCPRouter.resourceDefinitions() - let result: JSONValue = .object(["resources": encodeResourceDefinitions(resources)]) - return encodeRawResult(result, id: id, sessionId: sessionId) - } - - private func handleResourcesRead( - _ request: JSONRPCRequest, - sessionId: String, - server: MCPServer - ) async -> MCPRouter.RouteResult { - guard let id = request.id else { - return encodeError(MCPError.invalidRequest("resources/read requires an id"), id: nil) - } - - guard let params = request.params, - let uri = params["uri"]?.stringValue - else { - return encodeError(MCPError.invalidParams("Missing resource uri"), id: id) - } - - guard let handler = await server.resourceReadHandler else { - return encodeError(MCPError.internalError("Server not fully initialized"), id: id) - } - - do { - let readResult = try await handler(uri, sessionId) - let resultData = try encoder.encode(readResult) - guard let resultValue = try? decoder.decode(JSONValue.self, from: resultData) else { - return encodeError(MCPError.internalError("Failed to encode resource result"), id: id) - } - return encodeRawResult(resultValue, id: id, sessionId: sessionId) - } catch let mcpError as MCPError { - return encodeError(mcpError, id: id) - } catch { - return encodeError(MCPError.internalError(error.localizedDescription), id: id) - } - } - - private func encodeResult(_ result: T, id: JSONRPCId?, sessionId: String?) -> MCPRouter.RouteResult { - guard let id else { - return .accepted - } - - do { - let resultData = try encoder.encode(result) - let resultValue = try decoder.decode(JSONValue.self, from: resultData) - let response = JSONRPCResponse(id: id, result: resultValue) - let data = try encoder.encode(response) - return .json(data, sessionId: sessionId) - } catch { - Self.logger.error("Failed to encode response: \(error.localizedDescription)") - return encodeError(MCPError.internalError("Encoding failed"), id: id) - } - } - - private func encodeRawResult(_ result: JSONValue, id: JSONRPCId, sessionId: String?) -> MCPRouter.RouteResult { - do { - let response = JSONRPCResponse(id: id, result: result) - let data = try encoder.encode(response) - return .json(data, sessionId: sessionId) - } catch { - Self.logger.error("Failed to encode response: \(error.localizedDescription)") - return encodeError(MCPError.internalError("Encoding failed"), id: id) - } - } - - private func encodeError(_ error: MCPError, id: JSONRPCId?) -> MCPRouter.RouteResult { - let errorResponse = error.toJsonRpcError(id: id) - do { - let data = try encoder.encode(errorResponse) - return .json(data, sessionId: nil) - } catch { - Self.logger.error("Failed to encode error response") - return .httpError(status: 500, message: "Internal encoding error") - } - } - - private func encodeToolDefinitions(_ tools: [MCPToolDefinition]) -> JSONValue { - .array(tools.map { tool in - .object([ - "name": .string(tool.name), - "description": .string(tool.description), - "inputSchema": tool.inputSchema - ]) - }) - } - - private func encodeResourceDefinitions(_ resources: [MCPResourceDefinition]) -> JSONValue { - .array(resources.map { resource in - var dict: [String: JSONValue] = [ - "uri": .string(resource.uri), - "name": .string(resource.name) - ] - if let description = resource.description { - dict["description"] = .string(description) - } - if let mimeType = resource.mimeType { - dict["mimeType"] = .string(mimeType) - } - return .object(dict) - }) - } -} diff --git a/TablePro/Core/MCP/Session/MCPClock.swift b/TablePro/Core/MCP/Session/MCPClock.swift new file mode 100644 index 000000000..32b9259dd --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPClock.swift @@ -0,0 +1,18 @@ +import Foundation + +public protocol MCPClock: Sendable { + func now() async -> Date + func sleep(for duration: Duration) async throws +} + +public struct MCPSystemClock: MCPClock { + public init() {} + + public func now() async -> Date { + Date() + } + + public func sleep(for duration: Duration) async throws { + try await Task.sleep(for: duration) + } +} diff --git a/TablePro/Core/MCP/Session/MCPSession.swift b/TablePro/Core/MCP/Session/MCPSession.swift new file mode 100644 index 000000000..dc4c01cdd --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSession.swift @@ -0,0 +1,106 @@ +import Foundation + +public struct MCPClientInfo: Sendable, Equatable { + public let name: String + public let version: String? + + public init(name: String, version: String? = nil) { + self.name = name + self.version = version + } +} + +public struct MCPSessionSnapshot: Sendable { + public let id: MCPSessionId + public let createdAt: Date + public let lastActivityAt: Date + public let state: MCPSessionState + public let clientInfo: MCPClientInfo? + + public init( + id: MCPSessionId, + createdAt: Date, + lastActivityAt: Date, + state: MCPSessionState, + clientInfo: MCPClientInfo? + ) { + self.id = id + self.createdAt = createdAt + self.lastActivityAt = lastActivityAt + self.state = state + self.clientInfo = clientInfo + } +} + +public enum MCPSessionTransitionError: Error, Sendable, Equatable { + case illegalTransition(from: MCPSessionState, to: MCPSessionState) +} + +public actor MCPSession { + nonisolated public let id: MCPSessionId + nonisolated public let createdAt: Date + public private(set) var lastActivityAt: Date + public private(set) var state: MCPSessionState + public private(set) var clientInfo: MCPClientInfo? + public private(set) var negotiatedProtocolVersion: String? + public private(set) var clientCapabilities: JsonValue? + public private(set) var principalTokenId: UUID? + + public init(id: MCPSessionId = .generate(), now: Date = Date()) { + self.id = id + self.createdAt = now + self.lastActivityAt = now + self.state = .initializing + self.clientInfo = nil + self.negotiatedProtocolVersion = nil + self.clientCapabilities = nil + self.principalTokenId = nil + } + + public func touch(now: Date = Date()) { + guard !isTerminated else { return } + lastActivityAt = now + } + + public func bindPrincipal(tokenId: UUID?) { + guard !isTerminated else { return } + principalTokenId = tokenId + } + + public func recordInitialize( + clientInfo: MCPClientInfo, + protocolVersion: String, + capabilities: JsonValue? + ) { + self.clientInfo = clientInfo + self.negotiatedProtocolVersion = protocolVersion + self.clientCapabilities = capabilities + } + + public func transitionToReady() throws { + guard case .initializing = state else { + throw MCPSessionTransitionError.illegalTransition(from: state, to: .ready) + } + state = .ready + } + + public func terminate(reason: MCPSessionTerminationReason) { + if case .terminated = state { return } + state = .terminated(reason: reason) + } + + public func snapshot() -> MCPSessionSnapshot { + MCPSessionSnapshot( + id: id, + createdAt: createdAt, + lastActivityAt: lastActivityAt, + state: state, + clientInfo: clientInfo + ) + } + + private var isTerminated: Bool { + if case .terminated = state { return true } + return false + } +} diff --git a/TablePro/Core/MCP/Session/MCPSessionEvent.swift b/TablePro/Core/MCP/Session/MCPSessionEvent.swift new file mode 100644 index 000000000..4f219c443 --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSessionEvent.swift @@ -0,0 +1,6 @@ +import Foundation + +public enum MCPSessionEvent: Sendable { + case created(MCPSessionId) + case terminated(MCPSessionId, reason: MCPSessionTerminationReason) +} diff --git a/TablePro/Core/MCP/Session/MCPSessionId.swift b/TablePro/Core/MCP/Session/MCPSessionId.swift new file mode 100644 index 000000000..4064eb0b0 --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSessionId.swift @@ -0,0 +1,17 @@ +import Foundation + +public struct MCPSessionId: Sendable, Hashable, Equatable, CustomStringConvertible { + public let rawValue: String + + public init(_ rawValue: String) { + self.rawValue = rawValue + } + + public static func generate() -> MCPSessionId { + MCPSessionId(UUID().uuidString) + } + + public var description: String { + rawValue + } +} diff --git a/TablePro/Core/MCP/Session/MCPSessionPolicy.swift b/TablePro/Core/MCP/Session/MCPSessionPolicy.swift new file mode 100644 index 000000000..f7c04f666 --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSessionPolicy.swift @@ -0,0 +1,19 @@ +import Foundation + +public struct MCPSessionPolicy: Sendable, Equatable { + public let idleTimeout: Duration + public let maxSessions: Int + public let cleanupInterval: Duration + + public init(idleTimeout: Duration, maxSessions: Int, cleanupInterval: Duration) { + self.idleTimeout = idleTimeout + self.maxSessions = maxSessions + self.cleanupInterval = cleanupInterval + } + + public static let standard = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) +} diff --git a/TablePro/Core/MCP/Session/MCPSessionState.swift b/TablePro/Core/MCP/Session/MCPSessionState.swift new file mode 100644 index 000000000..419a5f391 --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSessionState.swift @@ -0,0 +1,30 @@ +import Foundation + +public enum MCPSessionState: Sendable, Equatable { + case initializing + case ready + case terminated(reason: MCPSessionTerminationReason) +} + +public enum MCPSessionTerminationReason: Sendable, Equatable, CustomStringConvertible { + case clientRequested + case idleTimeout + case capacityEvicted + case serverShutdown + case tokenRevoked + + public var description: String { + switch self { + case .clientRequested: + return "client_requested" + case .idleTimeout: + return "idle_timeout" + case .capacityEvicted: + return "capacity_evicted" + case .serverShutdown: + return "server_shutdown" + case .tokenRevoked: + return "token_revoked" + } + } +} diff --git a/TablePro/Core/MCP/Session/MCPSessionStore.swift b/TablePro/Core/MCP/Session/MCPSessionStore.swift new file mode 100644 index 000000000..7adf2e6b4 --- /dev/null +++ b/TablePro/Core/MCP/Session/MCPSessionStore.swift @@ -0,0 +1,159 @@ +import Foundation +import os + +public enum MCPSessionStoreError: Error, Sendable, Equatable { + case capacityExceeded(limit: Int) + case sessionNotFound(MCPSessionId) +} + +public actor MCPSessionStore { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.Session") + + private let policy: MCPSessionPolicy + private let clock: any MCPClock + + private var sessions: [MCPSessionId: MCPSession] = [:] + private var eventSubscribers: [UUID: AsyncStream.Continuation] = [:] + private var cleanupTask: Task? + + public init(policy: MCPSessionPolicy = .standard, clock: any MCPClock = MCPSystemClock()) { + self.policy = policy + self.clock = clock + } + + public func create() async throws -> MCPSession { + guard sessions.count < policy.maxSessions else { + Self.logger.warning("Session capacity exceeded (limit \(self.policy.maxSessions))") + throw MCPSessionStoreError.capacityExceeded(limit: policy.maxSessions) + } + + let now = await clock.now() + let session = MCPSession(now: now) + sessions[session.id] = session + Self.logger.info("Session created: \(session.id.rawValue, privacy: .public)") + broadcast(.created(session.id)) + return session + } + + public func session(id: MCPSessionId) async -> MCPSession? { + sessions[id] + } + + public func touch(id: MCPSessionId) async { + guard let session = sessions[id] else { return } + let now = await clock.now() + await session.touch(now: now) + } + + public func terminate(id: MCPSessionId, reason: MCPSessionTerminationReason) async { + guard let session = sessions.removeValue(forKey: id) else { return } + await session.terminate(reason: reason) + Self.logger.info( + "Session terminated: \(id.rawValue, privacy: .public) reason=\(reason.description, privacy: .public)" + ) + broadcast(.terminated(id, reason: reason)) + } + + public func count() async -> Int { + sessions.count + } + + public func allSessions() async -> [MCPSession] { + Array(sessions.values) + } + + public func sessionIds(forPrincipalTokenId tokenId: UUID) async -> [MCPSessionId] { + var matching: [MCPSessionId] = [] + for (sessionId, session) in sessions { + let bound = await session.principalTokenId + if bound == tokenId { + matching.append(sessionId) + } + } + return matching + } + + public var events: AsyncStream { + let (stream, continuation) = AsyncStream.makeStream( + bufferingPolicy: .bufferingNewest(64) + ) + let subscriberId = UUID() + eventSubscribers[subscriberId] = continuation + continuation.onTermination = { [weak self] _ in + guard let self else { return } + Task { await self.removeSubscriber(subscriberId) } + } + return stream + } + + public func startCleanup() async { + guard cleanupTask == nil else { return } + let interval = policy.cleanupInterval + let clockRef = clock + cleanupTask = Task { [weak self] in + while !Task.isCancelled { + do { + try await clockRef.sleep(for: interval) + } catch { + return + } + guard let self else { return } + await self.runCleanupPass() + } + } + } + + public func stopCleanup() async { + cleanupTask?.cancel() + cleanupTask = nil + } + + public func runCleanupPass() async { + let now = await clock.now() + let idleSeconds = Self.seconds(of: policy.idleTimeout) + let cutoff = now.addingTimeInterval(-idleSeconds) + + var expired: [MCPSessionId] = [] + for (sessionId, session) in sessions { + let lastActivity = await session.lastActivityAt + if lastActivity < cutoff { + expired.append(sessionId) + } + } + + for sessionId in expired { + await terminate(id: sessionId, reason: .idleTimeout) + } + + if !expired.isEmpty { + Self.logger.info("Idle cleanup terminated \(expired.count) session(s)") + } + } + + public func shutdown(reason: MCPSessionTerminationReason = .serverShutdown) async { + await stopCleanup() + let activeIds = Array(sessions.keys) + for sessionId in activeIds { + await terminate(id: sessionId, reason: reason) + } + for (_, continuation) in eventSubscribers { + continuation.finish() + } + eventSubscribers.removeAll() + } + + private func broadcast(_ event: MCPSessionEvent) { + for (_, continuation) in eventSubscribers { + continuation.yield(event) + } + } + + private func removeSubscriber(_ id: UUID) { + eventSubscribers.removeValue(forKey: id) + } + + private static func seconds(of duration: Duration) -> TimeInterval { + let components = duration.components + return TimeInterval(components.seconds) + TimeInterval(components.attoseconds) / 1.0e18 + } +} diff --git a/TablePro/Core/MCP/Transport/MCPBridgeLogger.swift b/TablePro/Core/MCP/Transport/MCPBridgeLogger.swift new file mode 100644 index 000000000..ad9ee8e83 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPBridgeLogger.swift @@ -0,0 +1,69 @@ +import Foundation +import os + +public enum MCPBridgeLogLevel: String, Sendable { + case debug + case info + case warning + case error +} + +public protocol MCPBridgeLogger: Sendable { + func log(_ level: MCPBridgeLogLevel, _ message: String) +} + +public struct MCPOSBridgeLogger: MCPBridgeLogger { + private let logger: Logger + + public init(subsystem: String = "com.TablePro", category: String = "MCP.Bridge") { + logger = Logger(subsystem: subsystem, category: category) + } + + public func log(_ level: MCPBridgeLogLevel, _ message: String) { + switch level { + case .debug: + logger.debug("\(message, privacy: .public)") + case .info: + logger.info("\(message, privacy: .public)") + case .warning: + logger.warning("\(message, privacy: .public)") + case .error: + logger.error("\(message, privacy: .public)") + } + } +} + +public struct MCPStderrBridgeLogger: MCPBridgeLogger { + private static let lock = NSLock() + + public init() {} + + public func log(_ level: MCPBridgeLogLevel, _ message: String) { + let prefix: String + switch level { + case .debug: prefix = "[debug] " + case .info: prefix = "[info] " + case .warning: prefix = "[warn] " + case .error: prefix = "[error] " + } + let payload = prefix + message + "\n" + guard let data = payload.data(using: .utf8) else { return } + Self.lock.lock() + defer { Self.lock.unlock() } + FileHandle.standardError.write(data) + } +} + +public struct MCPCompositeBridgeLogger: MCPBridgeLogger { + private let loggers: [any MCPBridgeLogger] + + public init(_ loggers: [any MCPBridgeLogger]) { + self.loggers = loggers + } + + public func log(_ level: MCPBridgeLogLevel, _ message: String) { + for logger in loggers { + logger.log(level, message) + } + } +} diff --git a/TablePro/Core/MCP/Transport/MCPCorsHeaders.swift b/TablePro/Core/MCP/Transport/MCPCorsHeaders.swift new file mode 100644 index 000000000..2ca17cf1d --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPCorsHeaders.swift @@ -0,0 +1,39 @@ +import Foundation + +public enum MCPCorsHeaders { + private static let allowedHosts: Set = [ + "localhost", + "127.0.0.1", + "claude.ai", + "app.cursor.com" + ] + + private static let baseHeaders: [(String, String)] = [ + ("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS"), + ( + "Access-Control-Allow-Headers", + "Content-Type, Mcp-Session-Id, mcp-protocol-version, Authorization, Last-Event-ID" + ), + ("Access-Control-Expose-Headers", "Mcp-Session-Id"), + ("Access-Control-Max-Age", "86400") + ] + + public static func headers(forOrigin origin: String?) -> [(String, String)] { + guard let origin, !origin.isEmpty else { return [] } + guard isAllowed(origin: origin) else { return [] } + var headers: [(String, String)] = [("Access-Control-Allow-Origin", origin)] + headers.append(("Vary", "Origin")) + headers.append(contentsOf: baseHeaders) + return headers + } + + public static func isAllowed(origin: String) -> Bool { + guard let url = URL(string: origin), + let scheme = url.scheme?.lowercased(), + let host = url.host?.lowercased() else { + return false + } + guard scheme == "http" || scheme == "https" else { return false } + return allowedHosts.contains(host) + } +} diff --git a/TablePro/Core/MCP/Transport/MCPHttpServerConfiguration.swift b/TablePro/Core/MCP/Transport/MCPHttpServerConfiguration.swift new file mode 100644 index 000000000..0247301d0 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPHttpServerConfiguration.swift @@ -0,0 +1,95 @@ +import Foundation +@preconcurrency import Security + +public enum MCPBindAddress: Sendable, Equatable { + case loopback + case anyInterface +} + +public enum TLSProtocolVersion: Sendable, Equatable { + case tls12 + case tls13 +} + +public struct MCPTLSConfiguration: Sendable { + public let identity: SecIdentity + public let minimumProtocol: TLSProtocolVersion + + public init(identity: SecIdentity, minimumProtocol: TLSProtocolVersion = .tls12) { + self.identity = identity + self.minimumProtocol = minimumProtocol + } +} + +public struct MCPHttpServerLimits: Sendable, Equatable { + public let maxRequestBodyBytes: Int + public let maxHeaderBytes: Int + public let connectionTimeout: Duration + + public init( + maxRequestBodyBytes: Int, + maxHeaderBytes: Int, + connectionTimeout: Duration + ) { + self.maxRequestBodyBytes = maxRequestBodyBytes + self.maxHeaderBytes = maxHeaderBytes + self.connectionTimeout = connectionTimeout + } + + public static let standard = MCPHttpServerLimits( + maxRequestBodyBytes: 10 * 1_024 * 1_024, + maxHeaderBytes: 16 * 1_024, + connectionTimeout: .seconds(30) + ) +} + +public struct MCPHttpServerConfiguration: Sendable { + public let bindAddress: MCPBindAddress + public let port: UInt16 + public let tls: MCPTLSConfiguration? + public let limits: MCPHttpServerLimits + + private init( + bindAddress: MCPBindAddress, + port: UInt16, + tls: MCPTLSConfiguration?, + limits: MCPHttpServerLimits + ) { + self.bindAddress = bindAddress + self.port = port + self.tls = tls + self.limits = limits + } + + public static func loopback( + port: UInt16, + limits: MCPHttpServerLimits = .standard + ) -> Self { + Self(bindAddress: .loopback, port: port, tls: nil, limits: limits) + } + + public static func loopback( + port: UInt16, + tls: MCPTLSConfiguration, + limits: MCPHttpServerLimits = .standard + ) -> Self { + Self(bindAddress: .loopback, port: port, tls: tls, limits: limits) + } + + public static func remote( + port: UInt16, + tls: MCPTLSConfiguration, + limits: MCPHttpServerLimits = .standard + ) -> Self { + Self(bindAddress: .anyInterface, port: port, tls: tls, limits: limits) + } + + internal static func unsafeMake( + bindAddress: MCPBindAddress, + port: UInt16, + tls: MCPTLSConfiguration?, + limits: MCPHttpServerLimits + ) -> Self { + Self(bindAddress: bindAddress, port: port, tls: tls, limits: limits) + } +} diff --git a/TablePro/Core/MCP/Transport/MCPHttpServerError.swift b/TablePro/Core/MCP/Transport/MCPHttpServerError.swift new file mode 100644 index 000000000..960139db9 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPHttpServerError.swift @@ -0,0 +1,24 @@ +import Foundation + +public enum MCPHttpServerError: Error, Sendable, Equatable, LocalizedError { + case tlsRequiredForRemoteAccess + case alreadyStarted + case notStarted + case bindFailed(reason: String) + case acceptCancelled + + public var errorDescription: String? { + switch self { + case .tlsRequiredForRemoteAccess: + return "Remote access requires TLS to be enabled" + case .alreadyStarted: + return "MCP server is already running" + case .notStarted: + return "MCP server is not running" + case .bindFailed(let reason): + return "Failed to bind MCP server: \(reason)" + case .acceptCancelled: + return "MCP server accept loop was cancelled" + } + } +} diff --git a/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift b/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift new file mode 100644 index 000000000..f2d5ac764 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPHttpServerTransport.swift @@ -0,0 +1,1041 @@ +import Foundation +import Network +import os +import Security + +public enum MCPHttpServerState: Sendable, Equatable { + case idle + case starting + case running(port: UInt16) + case stopped + case failed(reason: String) +} + +public actor MCPHttpServerTransport { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpServer") + + private let configuration: MCPHttpServerConfiguration + private let sessionStore: MCPSessionStore + private let authenticator: any MCPAuthenticator + private let clock: any MCPClock + + private var listener: NWListener? + private var connections: [UUID: HttpConnectionContext] = [:] + private var sseConnectionsBySession: [MCPSessionId: UUID] = [:] + private var sessionEventsTask: Task? + + nonisolated public let exchanges: AsyncStream + nonisolated private let exchangesContinuation: AsyncStream.Continuation + + nonisolated public let listenerState: AsyncStream + nonisolated private let stateContinuation: AsyncStream.Continuation + + private var currentState: MCPHttpServerState = .idle + + public init( + configuration: MCPHttpServerConfiguration, + sessionStore: MCPSessionStore, + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock() + ) { + self.configuration = configuration + self.sessionStore = sessionStore + self.authenticator = authenticator + self.clock = clock + + let (exchanges, exchangesContinuation) = AsyncStream.makeStream( + bufferingPolicy: .bufferingOldest(1024) + ) + self.exchanges = exchanges + self.exchangesContinuation = exchangesContinuation + + let (listenerState, stateContinuation) = AsyncStream.makeStream() + self.listenerState = listenerState + self.stateContinuation = stateContinuation + } + + public func start() async throws { + guard listener == nil else { + Self.logger.warning("start() called while listener already exists") + throw MCPHttpServerError.alreadyStarted + } + + Self.logger.info("Starting MCP HTTP server: bind=\(String(describing: self.configuration.bindAddress)) port=\(self.configuration.port) tls=\(self.configuration.tls != nil)") + + if configuration.bindAddress == .anyInterface, configuration.tls == nil { + Self.logger.error("Remote access requested without TLS — refusing to start") + throw MCPHttpServerError.tlsRequiredForRemoteAccess + } + + emitState(.starting) + + let parameters: NWParameters = makeParameters() + + do { + let newListener = try NWListener(using: parameters) + listener = newListener + + newListener.stateUpdateHandler = { [weak self] state in + guard let self else { return } + Task { await self.handleListenerState(state) } + } + + newListener.newConnectionHandler = { [weak self] connection in + guard let self else { return } + Task { await self.handleNewConnection(connection) } + } + + newListener.start(queue: .global(qos: .userInitiated)) + startSessionEventListener() + } catch { + emitState(.failed(reason: error.localizedDescription)) + listener = nil + throw MCPHttpServerError.bindFailed(reason: error.localizedDescription) + } + } + + public func stop() async { + Self.logger.info("Stopping MCP HTTP server") + + sessionEventsTask?.cancel() + sessionEventsTask = nil + + for (_, context) in connections { + await context.cancel() + } + connections.removeAll() + sseConnectionsBySession.removeAll() + + if let listener { + self.listener = nil + await withCheckedContinuation { (continuation: CheckedContinuation) in + listener.stateUpdateHandler = { state in + if case .cancelled = state { + continuation.resume() + } + } + listener.cancel() + } + } + + emitState(.stopped) + exchangesContinuation.finish() + stateContinuation.finish() + } + + public func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async { + guard let connectionId = sseConnectionsBySession[sessionId], + let context = connections[connectionId] else { + return + } + + let message = JsonRpcMessage.notification(notification) + guard let data = try? JsonRpcCodec.encode(message), + let text = String(data: data, encoding: .utf8) else { return } + await context.writeSseFrame(SseFrame(data: text)) + } + + public func broadcastNotification(_ notification: JsonRpcNotification) async { + let sessionIds = Array(sseConnectionsBySession.keys) + for sessionId in sessionIds { + await sendNotification(notification, toSession: sessionId) + } + } + + private func makeParameters() -> NWParameters { + let tcpOptions = NWProtocolTCP.Options() + + let parameters: NWParameters + if let tls = configuration.tls { + let tlsOptions = NWProtocolTLS.Options() + if let secIdentity = sec_identity_create(tls.identity) { + sec_protocol_options_set_local_identity(tlsOptions.securityProtocolOptions, secIdentity) + } + switch tls.minimumProtocol { + case .tls12: + sec_protocol_options_set_min_tls_protocol_version(tlsOptions.securityProtocolOptions, .TLSv12) + case .tls13: + sec_protocol_options_set_min_tls_protocol_version(tlsOptions.securityProtocolOptions, .TLSv13) + } + parameters = NWParameters(tls: tlsOptions, tcp: tcpOptions) + } else { + parameters = NWParameters(tls: nil, tcp: tcpOptions) + } + + let host: NWEndpoint.Host = configuration.bindAddress == .loopback ? .ipv4(.loopback) : .ipv4(.any) + let port = NWEndpoint.Port(rawValue: configuration.port) ?? .any + parameters.requiredLocalEndpoint = NWEndpoint.hostPort(host: host, port: port) + parameters.allowLocalEndpointReuse = true + return parameters + } + + private func handleListenerState(_ state: NWListener.State) { + switch state { + case .ready: + let port = listener?.port?.rawValue ?? configuration.port + Self.logger.info("MCP HTTP server listening on port \(port, privacy: .public)") + emitState(.running(port: port)) + + case .failed(let error): + Self.logger.error("MCP HTTP listener failed: \(error.localizedDescription, privacy: .public)") + emitState(.failed(reason: error.localizedDescription)) + listener?.cancel() + listener = nil + + case .cancelled: + Self.logger.debug("MCP HTTP listener cancelled") + + default: + break + } + } + + private func emitState(_ state: MCPHttpServerState) { + currentState = state + stateContinuation.yield(state) + } + + private func startSessionEventListener() { + sessionEventsTask?.cancel() + let store = sessionStore + sessionEventsTask = Task { [weak self] in + let eventsStream = await store.events + for await event in eventsStream { + guard let self else { return } + if case .terminated(let sessionId, let reason) = event { + await self.handleSessionTerminated(sessionId, reason: reason) + } + } + } + } + + private func handleSessionTerminated(_ sessionId: MCPSessionId, reason: MCPSessionTerminationReason) async { + guard let connectionId = sseConnectionsBySession.removeValue(forKey: sessionId), + let context = connections[connectionId] else { + return + } + + let comment: String + switch reason { + case .idleTimeout: + comment = "idle-timeout" + case .tokenRevoked: + comment = "token-revoked" + case .serverShutdown: + comment = "server-shutdown" + case .clientRequested: + comment = "client-disconnect" + case .capacityEvicted: + comment = "capacity-evicted" + } + await context.writeRaw(Data("\u{003A} \(comment)\n\n".utf8)) + await context.cancel() + connections.removeValue(forKey: connectionId) + } + + private func handleNewConnection(_ connection: NWConnection) async { + let connectionId = UUID() + Self.logger.debug("Accepted connection \(connectionId, privacy: .public)") + let context = HttpConnectionContext(id: connectionId, connection: connection) + connections[connectionId] = context + await context.start { [weak self] data in + guard let self else { return } + await self.handleReceivedData(connectionId: connectionId, data: data) + } onClosed: { [weak self] in + guard let self else { return } + await self.removeConnection(connectionId: connectionId) + } + } + + private func removeConnection(connectionId: UUID) async { + connections.removeValue(forKey: connectionId) + let pairs = sseConnectionsBySession.filter { $0.value == connectionId } + for (sessionId, _) in pairs { + sseConnectionsBySession.removeValue(forKey: sessionId) + } + } + + private func handleReceivedData(connectionId: UUID, data: Data) async { + guard let context = connections[connectionId] else { return } + + let parseResult: HttpRequestParseResult + do { + parseResult = try HttpRequestParser.parse(data) + } catch HttpRequestParseError.bodyTooLarge { + await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + return + } catch HttpRequestParseError.headerTooLarge { + await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + return + } catch { + await respondTopLevel( + context: context, + error: .invalidRequest(detail: "Malformed HTTP"), + requestId: nil + ) + return + } + + switch parseResult { + case .incomplete: + return + case .complete(let head, let body, _): + await context.markRequestComplete() + await dispatch(head: head, body: body, context: context) + } + } + + private func dispatch(head: HttpRequestHead, body: Data, context: HttpConnectionContext) async { + let clientAddress: MCPClientAddress = await context.clientAddress() + let now = await clock.now() + + await context.setOrigin(head.headers.value(for: "Origin")) + + if head.method == .post, stripQueryString(head.path) == "/v1/integrations/exchange" { + await handleIntegrationsExchange(body: body, context: context) + return + } + + switch head.method { + case .options: + await context.writeOptions204() + await context.cancel() + return + case .get: + await handleGetMcp(head: head, context: context, clientAddress: clientAddress) + case .post: + await handlePostMcp(head: head, body: body, context: context, clientAddress: clientAddress, now: now) + case .delete: + await handleDeleteMcp(head: head, context: context, clientAddress: clientAddress) + default: + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not allowed", + httpStatus: .methodNotAllowed + ), + requestId: nil + ) + } + } + + private func handleIntegrationsExchange(body: Data, context: HttpConnectionContext) async { + struct ExchangeBody: Decodable { + let code: String + let codeVerifier: String + enum CodingKeys: String, CodingKey { + case code + case codeVerifier = "code_verifier" + } + } + struct ExchangeResponse: Encodable { + let token: String + } + + Self.logger.info("Integrations exchange request received (\(body.count, privacy: .public) bytes)") + let ip = Self.ipString(for: await context.clientAddress()) + + let parsed: ExchangeBody + do { + parsed = try JSONDecoder().decode(ExchangeBody.self, from: body) + } catch { + Self.logger.warning("Integrations exchange decode failed: \(error.localizedDescription, privacy: .public)") + MCPAuditLogger.logPairingExchange(outcome: .denied, ip: ip, details: "invalid JSON body") + await context.writePlainJsonError(status: .badRequest, message: "Invalid JSON body") + await context.cancel() + return + } + + guard !parsed.code.isEmpty, !parsed.codeVerifier.isEmpty else { + Self.logger.warning("Integrations exchange missing code or verifier") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: "missing code or code_verifier" + ) + await context.writePlainJsonError(status: .badRequest, message: "Missing code or code_verifier") + await context.cancel() + return + } + + guard parsed.code.utf8.count <= 1024, parsed.codeVerifier.utf8.count <= 1024 else { + Self.logger.warning("Integrations exchange field exceeds size cap") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: "field exceeds 1024 bytes" + ) + await context.writePlainJsonError(status: .badRequest, message: "Field exceeds size limit") + await context.cancel() + return + } + + let exchange = PairingExchange(code: parsed.code, verifier: parsed.codeVerifier) + let outcome: Result = await MainActor.run { + do { + return .success(try MCPPairingService.shared.exchange(exchange)) + } catch { + return .failure(error) + } + } + + switch outcome { + case .success(let token): + Self.logger.info("Integrations exchange succeeded (token len=\(token.count, privacy: .public))") + let label = await Self.resolveTokenLabel(for: token) + MCPAuditLogger.logPairingExchange(outcome: .success, tokenName: label, ip: ip) + let payload = (try? JSONEncoder().encode(ExchangeResponse(token: token))) ?? Data() + await context.writePlainJsonResponse(status: .ok, body: payload) + await context.cancel() + case .failure(let error): + let mapped = Self.mapExchangeError(error) + Self.logger.warning("Integrations exchange failed: status=\(mapped.status.code, privacy: .public) reason=\(mapped.message, privacy: .public)") + MCPAuditLogger.logPairingExchange( + outcome: .denied, + ip: ip, + details: mapped.message + ) + await context.writePlainJsonError(status: mapped.status, message: mapped.message) + await context.cancel() + } + } + + private static func ipString(for address: MCPClientAddress) -> String { + switch address { + case .loopback: + return "127.0.0.1" + case .remote(let host): + return host + } + } + + private static func resolveTokenLabel(for plaintext: String) async -> String? { + let store: MCPTokenStore? = await MainActor.run { MCPServerManager.shared.tokenStore } + guard let store else { return nil } + return await store.validate(bearerToken: plaintext)?.name + } + + private static func mapExchangeError(_ error: Error) -> (status: HttpStatus, message: String) { + guard let domainError = error as? MCPDataLayerError else { + return (.internalServerError, "Internal error") + } + switch domainError { + case .notFound: + return (.notFound, "Pairing code not found") + case .expired: + return (HttpStatus(code: 410, reasonPhrase: "Gone"), "Pairing code expired") + case .forbidden: + return (.forbidden, "Challenge mismatch") + default: + return (.internalServerError, "Internal error") + } + } + + private func handleGetMcp( + head: HttpRequestHead, + context: HttpConnectionContext, + clientAddress: MCPClientAddress + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + guard let sessionIdRaw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) + return + } + + if head.headers.value(for: "Last-Event-ID") != nil { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.serverError, + message: "SSE event replay is not supported", + httpStatus: .notImplemented + ), + requestId: nil + ) + return + } + + if let accept = head.headers.value(for: "Accept"), + !accept.lowercased().contains("text/event-stream"), + !accept.contains("*/*") { + await respondTopLevel(context: context, error: .notAcceptable(), requestId: nil) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + let sessionId = MCPSessionId(sessionIdRaw) + guard await sessionStore.session(id: sessionId) != nil else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) + return + } + + await sessionStore.touch(id: sessionId) + + registerSseConnection(connectionId: context.id, sessionId: sessionId) + await context.writeSseStreamHeaders(sessionId: sessionId) + Self.logger.info("Registered SSE notification stream for session \(sessionId.rawValue, privacy: .public)") + } + + private func handlePostMcp( + head: HttpRequestHead, + body: Data, + context: HttpConnectionContext, + clientAddress: MCPClientAddress, + now: Date + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + if body.count > configuration.limits.maxRequestBodyBytes { + await respondTopLevel(context: context, error: .payloadTooLarge(), requestId: nil) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow(let principal) = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + let message: JsonRpcMessage + do { + message = try JsonRpcCodec.decode(body) + } catch { + await respondTopLevel( + context: context, + error: .parseError(detail: String(describing: error)), + requestId: nil + ) + return + } + + let requestId = extractRequestId(from: message) + let methodName = extractMethod(from: message) + let mcpProtocolVersion = head.headers.value(for: "mcp-protocol-version") + + let sessionId: MCPSessionId? + if methodName == "initialize" { + do { + let session = try await sessionStore.create() + sessionId = session.id + } catch { + await respondTopLevel( + context: context, + error: .serviceUnavailable(), + requestId: requestId + ) + return + } + } else { + guard let raw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: requestId) + return + } + let candidate = MCPSessionId(raw) + guard let session = await sessionStore.session(id: candidate) else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: requestId) + return + } + if let mismatch = await Self.protocolVersionMismatch( + session: session, + headerValue: mcpProtocolVersion + ) { + await respondTopLevel(context: context, error: mismatch, requestId: requestId) + return + } + sessionId = candidate + await sessionStore.touch(id: candidate) + } + + let sink = TransportResponderSink(transport: self, context: context) + let responder = MCPExchangeResponder(sink: sink, requestId: requestId) + + let exchangeContext = MCPInboundContext( + sessionId: sessionId, + principal: principal, + clientAddress: clientAddress, + receivedAt: now, + mcpProtocolVersion: mcpProtocolVersion + ) + let exchange = MCPInboundExchange( + message: message, + context: exchangeContext, + responder: responder + ) + let yieldResult = exchangesContinuation.yield(exchange) + if case .dropped = yieldResult { + Self.logger.warning("exchanges buffer full, dropped inbound message — dispatcher is falling behind") + } + } + + private func handleDeleteMcp( + head: HttpRequestHead, + context: HttpConnectionContext, + clientAddress: MCPClientAddress + ) async { + guard pathMatchesMcp(head.path) else { + await respondTopLevel( + context: context, + error: MCPProtocolError( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found", + httpStatus: .notFound + ), + requestId: nil + ) + return + } + + let authResult = await authenticate(headers: head.headers, clientAddress: clientAddress) + guard case .allow = authResult else { + if case .deny(let error) = authResult { + await respondTopLevel(context: context, error: error, requestId: nil) + } + return + } + + guard let raw = head.headers.value(for: "Mcp-Session-Id") else { + await respondTopLevel(context: context, error: .missingSessionId(), requestId: nil) + return + } + + let sessionId = MCPSessionId(raw) + guard await sessionStore.session(id: sessionId) != nil else { + await respondTopLevel(context: context, error: .sessionNotFound(), requestId: nil) + return + } + + await sessionStore.terminate(id: sessionId, reason: .clientRequested) + await context.writeNoContent() + await context.cancel() + } + + private func authenticate( + headers: HttpHeaders, + clientAddress: MCPClientAddress + ) async -> AuthResult { + let authHeader = headers.value(for: "Authorization") + let decision = await authenticator.authenticate( + authorizationHeader: authHeader, + clientAddress: clientAddress + ) + switch decision { + case .allow(let principal): + return .allow(principal) + case .deny(let reason): + let mcpError = mapDenialToProtocolError(reason) + return .deny(mcpError) + } + } + + private func mapDenialToProtocolError(_ reason: MCPAuthDenialReason) -> MCPProtocolError { + switch reason.httpStatus { + case 401: + if let challenge = reason.challenge { + if challenge.contains("invalid_token") { + if challenge.contains("token_expired") || challenge.contains("token expired") { + return .tokenExpired() + } + return .tokenInvalid() + } + return .unauthenticated(challenge: challenge) + } + return .unauthenticated() + case 403: + return .forbidden(reason: reason.logMessage) + case 429: + return .rateLimited(retryAfterSeconds: reason.retryAfterSeconds) + default: + return MCPProtocolError( + code: JsonRpcErrorCode.serverError, + message: reason.logMessage, + httpStatus: HttpStatus(code: reason.httpStatus, reasonPhrase: "Error"), + extraHeaders: reason.challenge.map { [("WWW-Authenticate", $0)] } ?? [] + ) + } + } + + private func respondTopLevel( + context: HttpConnectionContext, + error: MCPProtocolError, + requestId: JsonRpcId? + ) async { + let envelope = error.toJsonRpcErrorResponse(id: requestId) + let data = (try? JSONEncoder().encode(envelope)) ?? Data() + await context.writeJsonResponse( + data: data, + status: error.httpStatus, + sessionId: nil, + extraHeaders: error.extraHeaders + ) + await context.cancel() + } + + private func pathMatchesMcp(_ path: String) -> Bool { + let trimmed = stripQueryString(path) + return trimmed == "/mcp" || trimmed == "/mcp/" + } + + private static func protocolVersionMismatch( + session: MCPSession, + headerValue: String? + ) async -> MCPProtocolError? { + let state = await session.state + guard case .ready = state else { return nil } + guard let negotiated = await session.negotiatedProtocolVersion else { return nil } + guard let headerValue, !headerValue.isEmpty else { return nil } + if headerValue == negotiated { return nil } + return .invalidRequest( + detail: "MCP-Protocol-Version mismatch: client sent \(headerValue), session negotiated \(negotiated)" + ) + } + + private func stripQueryString(_ path: String) -> String { + if let questionIndex = path.firstIndex(of: "?") { + return String(path[path.startIndex.. JsonRpcId? { + switch message { + case .request(let request): + return request.id + case .successResponse(let response): + return response.id + case .errorResponse(let response): + return response.id + case .notification: + return nil + } + } + + private func extractMethod(from message: JsonRpcMessage) -> String? { + switch message { + case .request(let request): + return request.method + case .notification(let notification): + return notification.method + case .successResponse, .errorResponse: + return nil + } + } + + fileprivate func registerSseConnection(connectionId: UUID, sessionId: MCPSessionId) { + if let previous = sseConnectionsBySession[sessionId], previous != connectionId, + let oldContext = connections[previous] { + Task { await oldContext.cancel() } + connections.removeValue(forKey: previous) + } + sseConnectionsBySession[sessionId] = connectionId + } + + private enum AuthResult { + case allow(MCPPrincipal) + case deny(MCPProtocolError) + } +} + +actor HttpConnectionContext { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpServer") + + nonisolated let id: UUID + private let connection: NWConnection + private var receiveBuffer = Data() + private var requestComplete = false + private var cancelled = false + private var sseActive = false + private var origin: String? + + init(id: UUID, connection: NWConnection) { + self.id = id + self.connection = connection + } + + func setOrigin(_ value: String?) { + origin = value + } + + private func corsHeaders() -> [(String, String)] { + MCPCorsHeaders.headers(forOrigin: origin) + } + + func start( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + let nwConnection = connection + nwConnection.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .ready: + Task { await self.beginReading(onData: onData, onClosed: onClosed) } + case .failed: + Task { await self.handleClosed(onClosed: onClosed) } + case .cancelled: + Task { await self.handleClosed(onClosed: onClosed) } + default: + break + } + } + nwConnection.start(queue: .global(qos: .userInitiated)) + } + + private func beginReading( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + scheduleReceive(onData: onData, onClosed: onClosed) + } + + private func scheduleReceive( + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) { + if cancelled || requestComplete { return } + connection.receive(minimumIncompleteLength: 1, maximumLength: 65_536) { [weak self] content, _, isComplete, error in + guard let self else { return } + Task { + await self.handleReceive( + content: content, + isComplete: isComplete, + error: error, + onData: onData, + onClosed: onClosed + ) + } + } + } + + private func handleReceive( + content: Data?, + isComplete: Bool, + error: NWError?, + onData: @escaping @Sendable (Data) async -> Void, + onClosed: @escaping @Sendable () async -> Void + ) async { + if let error { + Self.logger.debug("Receive error: \(error.localizedDescription, privacy: .public)") + cancel() + await onClosed() + return + } + + if let content { + receiveBuffer.append(content) + await onData(receiveBuffer) + } + + if isComplete { + cancel() + await onClosed() + return + } + + if !requestComplete, !cancelled { + scheduleReceive(onData: onData, onClosed: onClosed) + } + } + + private func handleClosed(onClosed: @escaping @Sendable () async -> Void) async { + if !cancelled { + cancelled = true + } + await onClosed() + } + + func markRequestComplete() { + requestComplete = true + } + + func clientAddress() -> MCPClientAddress { + guard let endpoint = connection.currentPath?.remoteEndpoint, + case .hostPort(let host, _) = endpoint else { + return .loopback + } + let hostString = "\(host)" + if hostString == "127.0.0.1" || hostString == "::1" || hostString.lowercased() == "localhost" { + return .loopback + } + return .remote(hostString) + } + + func writeJsonResponse( + data: Data, + status: HttpStatus, + sessionId: MCPSessionId?, + extraHeaders: [(String, String)] + ) async { + if cancelled { return } + var headers: [(String, String)] = [ + ("Content-Type", "application/json"), + ("Connection", "close") + ] + if let sessionId { + headers.append(("Mcp-Session-Id", sessionId.rawValue)) + } + headers.append(contentsOf: extraHeaders) + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: data) + await send(payload) + } + + func writePlainJsonResponse(status: HttpStatus, body: Data) async { + if cancelled { return } + var headers: [(String, String)] = [ + ("Content-Type", "application/json"), + ("Connection", "close") + ] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: status, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: body) + await send(payload) + } + + func writePlainJsonError(status: HttpStatus, message: String) async { + struct ErrorBody: Encodable { let error: String } + let payload = (try? JSONEncoder().encode(ErrorBody(error: message))) ?? Data() + await writePlainJsonResponse(status: status, body: payload) + } + + func writeOptions204() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeNoContent() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .noContent, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeAccepted() async { + if cancelled { return } + var headers: [(String, String)] = [("Connection", "close")] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .accepted, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeSseStreamHeaders(sessionId: MCPSessionId) async { + if cancelled { return } + sseActive = true + var headers: [(String, String)] = [ + ("Content-Type", "text/event-stream"), + ("Cache-Control", "no-cache"), + ("Connection", "keep-alive"), + ("Mcp-Session-Id", sessionId.rawValue) + ] + headers.append(contentsOf: self.corsHeaders()) + let head = HttpResponseHead(status: .ok, headers: HttpHeaders(headers)) + let payload = HttpResponseEncoder.encode(head, body: nil) + await send(payload) + } + + func writeSseFrame(_ frame: SseFrame) async { + if cancelled { return } + let data = SseEncoder.encode(frame) + await send(data) + } + + func writeRaw(_ data: Data) async { + if cancelled { return } + await send(data) + } + + func cancel() { + if cancelled { return } + cancelled = true + connection.cancel() + } + + func isSseActive() -> Bool { + sseActive + } + + private func send(_ data: Data) async { + await withCheckedContinuation { (continuation: CheckedContinuation) in + connection.send(content: data, completion: .contentProcessed { error in + if let error { + Self.logger.debug("Send error: \(error.localizedDescription, privacy: .public)") + } + continuation.resume() + }) + } + } +} + +struct TransportResponderSink: MCPResponderSink { + let transport: MCPHttpServerTransport + let context: HttpConnectionContext + + func writeJson(_ data: Data, status: HttpStatus, sessionId: MCPSessionId?, extraHeaders: [(String, String)]) async { + await context.writeJsonResponse( + data: data, + status: status, + sessionId: sessionId, + extraHeaders: extraHeaders + ) + } + + func writeAccepted() async { + await context.writeAccepted() + } + + func writeSseStreamHeaders(sessionId: MCPSessionId) async { + await context.writeSseStreamHeaders(sessionId: sessionId) + } + + func writeSseFrame(_ frame: SseFrame) async { + await context.writeSseFrame(frame) + } + + func closeConnection() async { + await context.cancel() + } + + func registerSseConnection(sessionId: MCPSessionId) async { + await transport.registerSseConnection(connectionId: context.id, sessionId: sessionId) + } +} diff --git a/TablePro/Core/MCP/Transport/MCPInboundExchange.swift b/TablePro/Core/MCP/Transport/MCPInboundExchange.swift new file mode 100644 index 000000000..8cb3a1dd5 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPInboundExchange.swift @@ -0,0 +1,145 @@ +import Foundation +import os + +public struct MCPInboundContext: Sendable { + public let sessionId: MCPSessionId? + public let principal: MCPPrincipal? + public let clientAddress: MCPClientAddress + public let receivedAt: Date + public let mcpProtocolVersion: String? + + public init( + sessionId: MCPSessionId?, + principal: MCPPrincipal?, + clientAddress: MCPClientAddress, + receivedAt: Date, + mcpProtocolVersion: String? + ) { + self.sessionId = sessionId + self.principal = principal + self.clientAddress = clientAddress + self.receivedAt = receivedAt + self.mcpProtocolVersion = mcpProtocolVersion + } +} + +public struct MCPInboundExchange: Sendable { + public let message: JsonRpcMessage + public let context: MCPInboundContext + public let responder: MCPExchangeResponder + + public init( + message: JsonRpcMessage, + context: MCPInboundContext, + responder: MCPExchangeResponder + ) { + self.message = message + self.context = context + self.responder = responder + } +} + +public protocol MCPResponderSink: Sendable { + func writeJson(_ data: Data, status: HttpStatus, sessionId: MCPSessionId?, extraHeaders: [(String, String)]) async + func writeAccepted() async + func writeSseStreamHeaders(sessionId: MCPSessionId) async + func writeSseFrame(_ frame: SseFrame) async + func closeConnection() async + func registerSseConnection(sessionId: MCPSessionId) async +} + +public actor MCPExchangeResponder { + private static let logger = Logger(subsystem: "com.TablePro", category: "MCP.HttpServer") + + private let sink: MCPResponderSink + private var completed: Bool = false + private let requestId: JsonRpcId? + + public init(sink: MCPResponderSink, requestId: JsonRpcId?) { + self.sink = sink + self.requestId = requestId + } + + public func respond(_ message: JsonRpcMessage, sessionId: MCPSessionId?) async { + guard !completed else { + Self.logger.warning("Responder.respond called after completion; ignoring") + return + } + completed = true + + let body: Data + do { + body = try JsonRpcCodec.encode(message) + } catch { + let fallback = MCPProtocolError.internalError(detail: "encode failed").toJsonRpcErrorResponse(id: requestId) + body = (try? JSONEncoder().encode(fallback)) ?? Data() + } + + await sink.writeJson(body, status: .ok, sessionId: sessionId, extraHeaders: []) + await sink.closeConnection() + } + + public func respondError(_ error: MCPProtocolError, requestId responseId: JsonRpcId?) async { + guard !completed else { + Self.logger.warning("Responder.respondError called after completion; ignoring") + return + } + completed = true + + let envelope = error.toJsonRpcErrorResponse(id: responseId ?? requestId) + let data = (try? JSONEncoder().encode(envelope)) ?? Data() + await sink.writeJson(data, status: error.httpStatus, sessionId: nil, extraHeaders: error.extraHeaders) + await sink.closeConnection() + } + + public func respondSseStream( + initialMessage: JsonRpcMessage?, + sessionId: MCPSessionId, + additional: AsyncStream + ) async { + guard !completed else { + Self.logger.warning("Responder.respondSseStream called after completion; ignoring") + return + } + completed = true + + await sink.writeSseStreamHeaders(sessionId: sessionId) + await sink.registerSseConnection(sessionId: sessionId) + + if let initialMessage { + if let payload = try? JsonRpcCodec.encode(initialMessage), + let text = String(data: payload, encoding: .utf8) { + await sink.writeSseFrame(SseFrame(data: text)) + } + } + + for await message in additional { + guard let payload = try? JsonRpcCodec.encode(message), + let text = String(data: payload, encoding: .utf8) else { continue } + await sink.writeSseFrame(SseFrame(data: text)) + } + } + + public func acknowledgeAccepted() async { + guard !completed else { + Self.logger.warning("Responder.acknowledgeAccepted called after completion; ignoring") + return + } + completed = true + await sink.writeAccepted() + await sink.closeConnection() + } + + public func reject(_ error: MCPProtocolError) async { + guard !completed else { + Self.logger.warning("Responder.reject called after completion; ignoring") + return + } + completed = true + + let envelope = error.toJsonRpcErrorResponse(id: requestId) + let data = (try? JSONEncoder().encode(envelope)) ?? Data() + await sink.writeJson(data, status: error.httpStatus, sessionId: nil, extraHeaders: error.extraHeaders) + await sink.closeConnection() + } +} diff --git a/TablePro/Core/MCP/Transport/MCPMessageTransport.swift b/TablePro/Core/MCP/Transport/MCPMessageTransport.swift new file mode 100644 index 000000000..64d4fc870 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPMessageTransport.swift @@ -0,0 +1,18 @@ +import Foundation + +public protocol MCPMessageTransport: AnyObject, Sendable { + var inbound: AsyncThrowingStream { get } + func send(_ message: JsonRpcMessage) async throws + func close() async +} + +public enum MCPTransportError: Error, Sendable, Equatable { + case closed + case malformedFrame(detail: String) + case writeFailed(detail: String) + case readFailed(detail: String) + case invalidEndpoint + case authentication(httpStatus: Int, message: String) + case sessionExpired + case timeout +} diff --git a/TablePro/Core/MCP/Transport/MCPProtocolError.swift b/TablePro/Core/MCP/Transport/MCPProtocolError.swift new file mode 100644 index 000000000..d28091126 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPProtocolError.swift @@ -0,0 +1,163 @@ +import Foundation + +public struct MCPProtocolError: Error, Sendable, Equatable { + public let code: Int + public let message: String + public let httpStatus: HttpStatus + public let extraHeaders: [(String, String)] + public let data: JsonValue? + + public init( + code: Int, + message: String, + httpStatus: HttpStatus, + extraHeaders: [(String, String)] = [], + data: JsonValue? = nil + ) { + self.code = code + self.message = message + self.httpStatus = httpStatus + self.extraHeaders = extraHeaders + self.data = data + } + + public static func == (lhs: Self, rhs: Self) -> Bool { + lhs.code == rhs.code && lhs.message == rhs.message + } +} + +public extension MCPProtocolError { + static func sessionNotFound(message: String = "Session not found") -> Self { + Self(code: JsonRpcErrorCode.sessionNotFound, message: message, httpStatus: .notFound) + } + + static func missingSessionId(message: String = "Missing Mcp-Session-Id header") -> Self { + Self(code: JsonRpcErrorCode.invalidRequest, message: message, httpStatus: .badRequest) + } + + static func parseError(detail: String) -> Self { + Self( + code: JsonRpcErrorCode.parseError, + message: "Parse error: \(detail)", + httpStatus: .badRequest + ) + } + + static func invalidRequest(detail: String) -> Self { + Self( + code: JsonRpcErrorCode.invalidRequest, + message: "Invalid request: \(detail)", + httpStatus: .badRequest + ) + } + + static func methodNotFound(method: String) -> Self { + Self( + code: JsonRpcErrorCode.methodNotFound, + message: "Method not found: \(method)", + httpStatus: .ok + ) + } + + static func invalidParams(detail: String) -> Self { + Self( + code: JsonRpcErrorCode.invalidParams, + message: "Invalid params: \(detail)", + httpStatus: .ok + ) + } + + static func internalError(detail: String) -> Self { + Self( + code: JsonRpcErrorCode.internalError, + message: "Internal error: \(detail)", + httpStatus: .internalServerError + ) + } + + static func unauthenticated(challenge: String = "Bearer realm=\"TablePro\"") -> Self { + Self( + code: JsonRpcErrorCode.unauthenticated, + message: "Unauthenticated", + httpStatus: .unauthorized, + extraHeaders: [("WWW-Authenticate", challenge)] + ) + } + + static func tokenInvalid() -> Self { + Self( + code: JsonRpcErrorCode.forbidden, + message: "Token invalid", + httpStatus: .unauthorized, + extraHeaders: [("WWW-Authenticate", "Bearer error=\"invalid_token\"")] + ) + } + + static func tokenExpired() -> Self { + Self( + code: JsonRpcErrorCode.expired, + message: "Token expired", + httpStatus: .unauthorized, + extraHeaders: [("WWW-Authenticate", "Bearer error=\"invalid_token\", error_description=\"token expired\"")] + ) + } + + static func forbidden(reason: String) -> Self { + Self( + code: JsonRpcErrorCode.forbidden, + message: "Forbidden: \(reason)", + httpStatus: .forbidden + ) + } + + static func rateLimited(retryAfterSeconds: Int? = nil) -> Self { + var headers: [(String, String)] = [] + if let retryAfterSeconds, retryAfterSeconds > 0 { + headers.append(("Retry-After", String(retryAfterSeconds))) + } + return Self( + code: JsonRpcErrorCode.serverError, + message: "Rate limited", + httpStatus: .tooManyRequests, + extraHeaders: headers + ) + } + + static func payloadTooLarge() -> Self { + Self( + code: JsonRpcErrorCode.tooLarge, + message: "Payload too large", + httpStatus: .payloadTooLarge + ) + } + + static func notAcceptable() -> Self { + Self( + code: JsonRpcErrorCode.invalidRequest, + message: "Not acceptable", + httpStatus: .notAcceptable + ) + } + + static func unsupportedMediaType() -> Self { + Self( + code: JsonRpcErrorCode.invalidRequest, + message: "Unsupported media type", + httpStatus: .unsupportedMediaType + ) + } + + static func serviceUnavailable() -> Self { + Self( + code: JsonRpcErrorCode.serverError, + message: "Service unavailable", + httpStatus: .serviceUnavailable + ) + } +} + +public extension MCPProtocolError { + func toJsonRpcErrorResponse(id: JsonRpcId?) -> JsonRpcErrorResponse { + JsonRpcErrorResponse(id: id, error: JsonRpcError(code: code, message: message, data: data)) + } +} diff --git a/TablePro/Core/MCP/Transport/MCPStdioMessageTransport.swift b/TablePro/Core/MCP/Transport/MCPStdioMessageTransport.swift new file mode 100644 index 000000000..5580cd0ab --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPStdioMessageTransport.swift @@ -0,0 +1,140 @@ +import Foundation + +public actor MCPStdioMessageTransport: MCPMessageTransport { + nonisolated public let inbound: AsyncThrowingStream + nonisolated private let continuation: AsyncThrowingStream.Continuation + + private let writer: StdioWriter + private let errorLogger: (any MCPBridgeLogger)? + private var readerTask: Task? + private var isClosed = false + + public init( + stdin: FileHandle = .standardInput, + stdout: FileHandle = .standardOutput, + errorLogger: (any MCPBridgeLogger)? = nil + ) { + let (stream, continuation) = AsyncThrowingStream.makeStream() + self.inbound = stream + self.continuation = continuation + self.writer = StdioWriter(handle: stdout) + self.errorLogger = errorLogger + + Task { await self.startReader(stdin: stdin) } + } + + public func send(_ message: JsonRpcMessage) async throws { + if isClosed { + throw MCPTransportError.closed + } + + let line: Data + do { + line = try JsonRpcCodec.encodeLine(message) + } catch { + throw MCPTransportError.writeFailed(detail: String(describing: error)) + } + + do { + try await writer.write(line) + } catch { + throw MCPTransportError.writeFailed(detail: String(describing: error)) + } + } + + public func close() async { + if isClosed { + return + } + isClosed = true + let task = readerTask + readerTask = nil + task?.cancel() + continuation.finish() + } + + private func startReader(stdin: FileHandle) { + if isClosed { + return + } + let continuation = self.continuation + let logger = errorLogger + let task = Task.detached(priority: .userInitiated) { [weak self] in + await Self.readLoop(stdin: stdin, continuation: continuation, logger: logger) + await self?.finishStream() + } + readerTask = task + } + + private func finishStream() { + if isClosed { + return + } + isClosed = true + readerTask = nil + continuation.finish() + } + + private static func readLoop( + stdin: FileHandle, + continuation: AsyncThrowingStream.Continuation, + logger: (any MCPBridgeLogger)? + ) async { + var buffer = Data() + do { + for try await byte in stdin.bytes { + if Task.isCancelled { + return + } + if byte == 0x0A { + processLine(buffer, continuation: continuation, logger: logger) + buffer.removeAll(keepingCapacity: true) + continue + } + buffer.append(byte) + } + } catch { + logger?.log(.error, "stdio read failed: \(error)") + continuation.finish(throwing: MCPTransportError.readFailed(detail: String(describing: error))) + return + } + + if !buffer.isEmpty { + processLine(buffer, continuation: continuation, logger: logger) + } + } + + private static func processLine( + _ raw: Data, + continuation: AsyncThrowingStream.Continuation, + logger: (any MCPBridgeLogger)? + ) { + var trimmed = raw + if trimmed.last == 0x0D { + trimmed.removeLast() + } + if trimmed.isEmpty { + return + } + + do { + let message = try JsonRpcCodec.decode(trimmed) + continuation.yield(message) + } catch { + logger?.log(.warning, "stdio: skipping malformed JSON-RPC line: \(error)") + } + } +} + +private actor StdioWriter { + private let handle: FileHandle + + init(handle: FileHandle) { + self.handle = handle + } + + func write(_ data: Data) throws { + try handle.write(contentsOf: data) + try? handle.synchronize() + } +} diff --git a/TablePro/Core/MCP/Transport/MCPStreamableHttpClientTransport.swift b/TablePro/Core/MCP/Transport/MCPStreamableHttpClientTransport.swift new file mode 100644 index 000000000..31ee6b1b3 --- /dev/null +++ b/TablePro/Core/MCP/Transport/MCPStreamableHttpClientTransport.swift @@ -0,0 +1,420 @@ +import CryptoKit +import Foundation +import Security + +public struct MCPStreamableHttpClientConfiguration: Sendable { + public let endpoint: URL + public let bearerToken: String + public let tlsCertFingerprint: String? + public let requestTimeout: Duration + public let serverInitiatedStream: Bool + + public init( + endpoint: URL, + bearerToken: String, + tlsCertFingerprint: String? = nil, + requestTimeout: Duration = .seconds(60), + serverInitiatedStream: Bool = false + ) { + self.endpoint = endpoint + self.bearerToken = bearerToken + self.tlsCertFingerprint = tlsCertFingerprint + self.requestTimeout = requestTimeout + self.serverInitiatedStream = serverInitiatedStream + } +} + +public actor MCPStreamableHttpClientTransport: MCPMessageTransport { + nonisolated public let inbound: AsyncThrowingStream + nonisolated private let continuation: AsyncThrowingStream.Continuation + + private let configuration: MCPStreamableHttpClientConfiguration + private let urlSession: URLSession + private let errorLogger: (any MCPBridgeLogger)? + private var sessionId: String? + private var isClosed = false + private var serverInitiatedStreamOpen = false + private var tasks: [Task] = [] + + public init( + configuration: MCPStreamableHttpClientConfiguration, + urlSession: URLSession? = nil, + errorLogger: (any MCPBridgeLogger)? = nil + ) { + self.configuration = configuration + self.errorLogger = errorLogger + + let (stream, continuation) = AsyncThrowingStream.makeStream() + self.inbound = stream + self.continuation = continuation + + if let urlSession { + self.urlSession = urlSession + } else { + let config = URLSessionConfiguration.ephemeral + config.timeoutIntervalForRequest = TimeInterval(configuration.requestTimeout.components.seconds) + config.timeoutIntervalForResource = TimeInterval(configuration.requestTimeout.components.seconds) + if let fingerprint = configuration.tlsCertFingerprint { + let delegate = CertificatePinningDelegate(expectedFingerprint: fingerprint, errorLogger: errorLogger) + self.urlSession = URLSession(configuration: config, delegate: delegate, delegateQueue: nil) + } else { + self.urlSession = URLSession(configuration: config) + } + } + } + + public func send(_ message: JsonRpcMessage) async throws { + if isClosed { + throw MCPTransportError.closed + } + + let requestId = Self.requestId(of: message) + let body: Data + do { + body = try JsonRpcCodec.encode(message) + } catch { + throw MCPTransportError.writeFailed(detail: String(describing: error)) + } + + let task: Task = Task { [weak self] in + guard let self else { return } + await self.dispatch(body: body, requestId: requestId) + } + trackTask(task) + } + + public func openSseStream() async throws { + if isClosed { + throw MCPTransportError.closed + } + if serverInitiatedStreamOpen { + return + } + serverInitiatedStreamOpen = true + + let task: Task = Task { [weak self] in + guard let self else { return } + await self.runServerInitiatedStream() + } + trackTask(task) + } + + public func close() async { + if isClosed { + return + } + isClosed = true + let pending = tasks + tasks.removeAll() + for task in pending { + task.cancel() + } + urlSession.invalidateAndCancel() + continuation.finish() + } + + private func trackTask(_ task: Task) { + tasks.removeAll { $0.isCancelled } + tasks.append(task) + } + + private func setSessionId(_ value: String) { + sessionId = value + } + + private func currentSessionId() -> String? { + sessionId + } + + private func dispatch(body: Data, requestId: JsonRpcId?) async { + do { + try await performRequest(body: body, requestId: requestId) + } catch { + await handleSendError(error: error, requestId: requestId) + } + } + + private func performRequest(body: Data, requestId: JsonRpcId?) async throws { + var request = URLRequest(url: configuration.endpoint) + request.httpMethod = "POST" + request.httpBody = body + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue("application/json, text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("Bearer \(configuration.bearerToken)", forHTTPHeaderField: "Authorization") + if let sessionId = currentSessionId() { + request.setValue(sessionId, forHTTPHeaderField: "Mcp-Session-Id") + } + + let (bytes, response) = try await urlSession.bytes(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw MCPTransportError.readFailed(detail: "non-HTTP response") + } + + captureSessionIdIfPresent(from: httpResponse) + + let status = httpResponse.statusCode + let contentType = headerValue(httpResponse, name: "Content-Type")?.lowercased() ?? "" + + if (200..<300).contains(status) { + if contentType.contains("text/event-stream") { + try await consumeSseBytes(bytes) + return + } + if contentType.contains("application/json") { + let data = try await collectBytes(bytes) + if data.isEmpty { + return + } + pushJsonBody(data, fallbackId: requestId) + return + } + let data = try await collectBytes(bytes) + if data.isEmpty { + return + } + pushJsonBody(data, fallbackId: requestId) + return + } + + let data = try await collectBytes(bytes) + handleNonSuccessResponse( + status: status, + headers: httpResponse, + body: data, + requestId: requestId + ) + } + + private func runServerInitiatedStream() async { + do { + var request = URLRequest(url: configuration.endpoint) + request.httpMethod = "GET" + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue("Bearer \(configuration.bearerToken)", forHTTPHeaderField: "Authorization") + if let sessionId = currentSessionId() { + request.setValue(sessionId, forHTTPHeaderField: "Mcp-Session-Id") + } + + let (bytes, response) = try await urlSession.bytes(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + errorLogger?.log(.warning, "server-initiated stream: non-HTTP response") + return + } + captureSessionIdIfPresent(from: httpResponse) + let status = httpResponse.statusCode + guard (200..<300).contains(status) else { + let body = try await collectBytes(bytes) + handleNonSuccessResponse( + status: status, + headers: httpResponse, + body: body, + requestId: nil + ) + return + } + try await consumeSseBytes(bytes) + } catch { + if Task.isCancelled { + return + } + errorLogger?.log(.warning, "server-initiated stream ended: \(error)") + } + } + + private func consumeSseBytes(_ bytes: URLSession.AsyncBytes) async throws { + let decoder = SseDecoder() + var chunk = Data() + for try await byte in bytes { + if Task.isCancelled { + return + } + chunk.append(byte) + if byte == 0x0A { + let frames = await decoder.feed(chunk) + chunk.removeAll(keepingCapacity: true) + for frame in frames { + pushSseFrame(frame) + } + } + } + if !chunk.isEmpty { + let frames = await decoder.feed(chunk) + for frame in frames { + pushSseFrame(frame) + } + } + } + + private func collectBytes(_ bytes: URLSession.AsyncBytes) async throws -> Data { + var data = Data() + for try await byte in bytes { + if Task.isCancelled { + return data + } + data.append(byte) + } + return data + } + + private func pushSseFrame(_ frame: SseFrame) { + guard let payload = frame.data.data(using: .utf8) else { return } + if payload.isEmpty { + return + } + do { + let message = try JsonRpcCodec.decode(payload) + continuation.yield(message) + } catch { + errorLogger?.log(.warning, "SSE: skipping malformed JSON-RPC frame: \(error)") + } + } + + private func pushJsonBody(_ data: Data, fallbackId: JsonRpcId?) { + do { + let message = try JsonRpcCodec.decode(data) + continuation.yield(message) + } catch { + errorLogger?.log(.warning, "HTTP: malformed JSON-RPC body: \(error)") + let synthetic = MCPProtocolError.parseError(detail: String(describing: error)) + .toJsonRpcErrorResponse(id: fallbackId) + continuation.yield(.errorResponse(synthetic)) + } + } + + private func handleNonSuccessResponse( + status: Int, + headers: HTTPURLResponse, + body: Data, + requestId: JsonRpcId? + ) { + if requestId == nil { + errorLogger?.log(.warning, "HTTP \(status) for notification (no response will be emitted)") + return + } + + if !body.isEmpty, let parsed = try? JsonRpcCodec.decode(body) { + if case .errorResponse = parsed { + continuation.yield(parsed) + return + } + if case .successResponse = parsed { + continuation.yield(parsed) + return + } + } + + let challenge = headerValue(headers, name: "WWW-Authenticate") ?? "Bearer realm=\"TablePro\"" + let protocolError = Self.protocolError(forStatus: status, body: body, challenge: challenge) + let response = protocolError.toJsonRpcErrorResponse(id: requestId) + continuation.yield(.errorResponse(response)) + } + + private func handleSendError(error: Error, requestId: JsonRpcId?) async { + if Task.isCancelled { + return + } + errorLogger?.log(.error, "HTTP send failed: \(error)") + guard let requestId else { + return + } + let protocolError = MCPProtocolError.internalError(detail: String(describing: error)) + let response = protocolError.toJsonRpcErrorResponse(id: requestId) + continuation.yield(.errorResponse(response)) + } + + private func captureSessionIdIfPresent(from response: HTTPURLResponse) { + guard let value = headerValue(response, name: "Mcp-Session-Id") else { return } + setSessionId(value) + } + + private func headerValue(_ response: HTTPURLResponse, name: String) -> String? { + let target = name.lowercased() + for (rawKey, rawValue) in response.allHeaderFields { + guard let key = rawKey as? String, + key.lowercased() == target, + let value = rawValue as? String else { continue } + return value + } + return nil + } + + private static func requestId(of message: JsonRpcMessage) -> JsonRpcId? { + switch message { + case .request(let request): + return request.id + case .notification: + return nil + case .successResponse(let response): + return response.id + case .errorResponse(let response): + return response.id + } + } + + private static func protocolError(forStatus status: Int, body: Data, challenge: String) -> MCPProtocolError { + let detail = String(data: body, encoding: .utf8) ?? "HTTP \(status)" + switch status { + case 400: + return .invalidRequest(detail: detail) + case 401: + return .unauthenticated(challenge: challenge) + case 403: + return .forbidden(reason: detail) + case 404: + return .sessionNotFound(message: detail.isEmpty ? "Session not found" : detail) + case 406: + return .notAcceptable() + case 413: + return .payloadTooLarge() + case 415: + return .unsupportedMediaType() + case 429: + return .rateLimited() + case 503: + return .serviceUnavailable() + default: + return .internalError(detail: detail) + } + } +} + +private final class CertificatePinningDelegate: NSObject, URLSessionDelegate { + private let expectedFingerprint: String + private let errorLogger: (any MCPBridgeLogger)? + + init(expectedFingerprint: String, errorLogger: (any MCPBridgeLogger)?) { + self.expectedFingerprint = expectedFingerprint + self.errorLogger = errorLogger + } + + func urlSession( + _ session: URLSession, + didReceive challenge: URLAuthenticationChallenge + ) async -> (URLSession.AuthChallengeDisposition, URLCredential?) { + guard challenge.protectionSpace.authenticationMethod == NSURLAuthenticationMethodServerTrust, + let trust = challenge.protectionSpace.serverTrust else { + return (.performDefaultHandling, nil) + } + + guard let chain = SecTrustCopyCertificateChain(trust) as? [SecCertificate], + let leaf = chain.first else { + errorLogger?.log(.error, "TLS pinning: empty cert chain") + return (.cancelAuthenticationChallenge, nil) + } + + let fingerprint = Self.sha256Fingerprint(of: leaf) + if fingerprint.caseInsensitiveCompare(expectedFingerprint) != .orderedSame { + let prefix = String(fingerprint.prefix(8)) + errorLogger?.log(.error, "TLS pinning: cert mismatch (got \(prefix)...)") + return (.cancelAuthenticationChallenge, nil) + } + return (.useCredential, URLCredential(trust: trust)) + } + + private static func sha256Fingerprint(of certificate: SecCertificate) -> String { + let data = SecCertificateCopyData(certificate) as Data + return SHA256.hash(data: data) + .map { String(format: "%02X", $0) } + .joined(separator: ":") + } +} diff --git a/TablePro/Core/MCP/Wire/HttpRequestHead.swift b/TablePro/Core/MCP/Wire/HttpRequestHead.swift new file mode 100644 index 000000000..cd7949f44 --- /dev/null +++ b/TablePro/Core/MCP/Wire/HttpRequestHead.swift @@ -0,0 +1,96 @@ +import Foundation + +public enum HttpMethod: Sendable, Equatable { + case get + case post + case delete + case options + case put + case patch + case head + case other(String) + + public var rawValue: String { + switch self { + case .get: return "GET" + case .post: return "POST" + case .delete: return "DELETE" + case .options: return "OPTIONS" + case .put: return "PUT" + case .patch: return "PATCH" + case .head: return "HEAD" + case .other(let value): return value + } + } + + public init(rawValue: String) { + switch rawValue { + case "GET": self = .get + case "POST": self = .post + case "DELETE": self = .delete + case "OPTIONS": self = .options + case "PUT": self = .put + case "PATCH": self = .patch + case "HEAD": self = .head + default: self = .other(rawValue) + } + } +} + +public struct HttpHeaders: Sendable, Equatable { + private let storage: [(String, String)] + + public init(_ pairs: [(String, String)] = []) { + storage = pairs + } + + public var all: [(String, String)] { + storage + } + + public func value(for name: String) -> String? { + let lowered = name.lowercased() + for (key, value) in storage where key.lowercased() == lowered { + return value + } + return nil + } + + public func values(for name: String) -> [String] { + let lowered = name.lowercased() + return storage.compactMap { key, value in + key.lowercased() == lowered ? value : nil + } + } + + public func contains(_ name: String) -> Bool { + let lowered = name.lowercased() + return storage.contains { key, _ in key.lowercased() == lowered } + } + + public static func == (lhs: HttpHeaders, rhs: HttpHeaders) -> Bool { + guard lhs.storage.count == rhs.storage.count else { return false } + for index in lhs.storage.indices { + let leftPair = lhs.storage[index] + let rightPair = rhs.storage[index] + if leftPair.0 != rightPair.0 || leftPair.1 != rightPair.1 { + return false + } + } + return true + } +} + +public struct HttpRequestHead: Sendable, Equatable { + public let method: HttpMethod + public let path: String + public let httpVersion: String + public let headers: HttpHeaders + + public init(method: HttpMethod, path: String, httpVersion: String, headers: HttpHeaders) { + self.method = method + self.path = path + self.httpVersion = httpVersion + self.headers = headers + } +} diff --git a/TablePro/Core/MCP/Wire/HttpRequestParser.swift b/TablePro/Core/MCP/Wire/HttpRequestParser.swift new file mode 100644 index 000000000..b9faf1c61 --- /dev/null +++ b/TablePro/Core/MCP/Wire/HttpRequestParser.swift @@ -0,0 +1,196 @@ +import Foundation + +public enum HttpRequestParseResult: Sendable, Equatable { + case incomplete + case complete(HttpRequestHead, body: Data, consumedBytes: Int) +} + +public enum HttpRequestParseError: Error, Equatable, Sendable { + case malformedRequestLine + case malformedHeader + case unsupportedHttpVersion(String) + case missingHostHeader + case bodyTooLarge(limit: Int, actual: Int) + case nonStrictLineEndings + case headerTooLarge +} + +public enum HttpRequestParser { + public static let maxHeaderSize = 16 * 1_024 + public static let maxBodySize = 10 * 1_024 * 1_024 + + private static let crlfcrlf: [UInt8] = [0x0D, 0x0A, 0x0D, 0x0A] + private static let lflf: [UInt8] = [0x0A, 0x0A] + + public static func parse(_ buffer: Data) throws -> HttpRequestParseResult { + let bytes = [UInt8](buffer) + + let crlfTerminator = firstIndex(of: crlfcrlf, in: bytes) + let lflfTerminator = firstIndex(of: lflf, in: bytes) + + if let lflfIndex = lflfTerminator { + if let crlfIndex = crlfTerminator { + if lflfIndex < crlfIndex { + throw HttpRequestParseError.nonStrictLineEndings + } + } else { + if lflfIndex <= maxHeaderSize { + throw HttpRequestParseError.nonStrictLineEndings + } + } + } + + guard let headerEndIndex = crlfTerminator else { + if bytes.count > maxHeaderSize { + throw HttpRequestParseError.headerTooLarge + } + return .incomplete + } + + if headerEndIndex > maxHeaderSize { + throw HttpRequestParseError.headerTooLarge + } + + let headerBytes = Array(bytes[0.. maxBodySize { + throw HttpRequestParseError.bodyTooLarge(limit: maxBodySize, actual: contentLength) + } + + let availableBodyBytes = bytes.count - bodyStartIndex + if availableBodyBytes < contentLength { + return .incomplete + } + + let body = Data(bytes[bodyStartIndex..<(bodyStartIndex + contentLength)]) + let consumed = bodyStartIndex + contentLength + return .complete(head, body: body, consumedBytes: consumed) + } + + return .complete(head, body: Data(), consumedBytes: bodyStartIndex) + } + + private static func splitStrictCrlf(_ bytes: [UInt8]) throws -> [[UInt8]] { + var lines: [[UInt8]] = [] + var current: [UInt8] = [] + var index = 0 + while index < bytes.count { + let byte = bytes[index] + if byte == 0x0D { + let nextIndex = index + 1 + if nextIndex >= bytes.count { + throw HttpRequestParseError.malformedHeader + } + if bytes[nextIndex] != 0x0A { + throw HttpRequestParseError.malformedHeader + } + lines.append(current) + current = [] + index = nextIndex + 1 + continue + } + if byte == 0x0A { + throw HttpRequestParseError.nonStrictLineEndings + } + current.append(byte) + index += 1 + } + lines.append(current) + return lines + } + + private static func parseRequestLine(_ bytes: [UInt8]) throws -> (HttpMethod, String, String) { + guard let line = String(bytes: bytes, encoding: .utf8) else { + throw HttpRequestParseError.malformedRequestLine + } + + let parts = line.split(separator: " ", maxSplits: 2, omittingEmptySubsequences: false) + guard parts.count == 3 else { + throw HttpRequestParseError.malformedRequestLine + } + + let methodString = String(parts[0]) + let path = String(parts[1]) + let version = String(parts[2]) + + guard !methodString.isEmpty, !path.isEmpty, !version.isEmpty else { + throw HttpRequestParseError.malformedRequestLine + } + + guard version.hasPrefix("HTTP/") else { + throw HttpRequestParseError.unsupportedHttpVersion(version) + } + + let method = HttpMethod(rawValue: methodString) + return (method, path, version) + } + + private static func parseHeaderLine(_ bytes: [UInt8]) throws -> (String, String) { + guard let line = String(bytes: bytes, encoding: .utf8) else { + throw HttpRequestParseError.malformedHeader + } + + guard let colonIndex = line.firstIndex(of: ":") else { + throw HttpRequestParseError.malformedHeader + } + + let nameSlice = line[line.startIndex.. Int? { + guard !needle.isEmpty, haystack.count >= needle.count else { return nil } + let lastStart = haystack.count - needle.count + var index = 0 + while index <= lastStart { + var matched = true + for offset in 0.. Data { + var output = "HTTP/1.1 \(head.status.code) \(head.status.reasonPhrase)\r\n" + + let hasContentLength = head.headers.contains("Content-Length") + + for (name, value) in head.headers.all { + output += "\(name): \(value)\r\n" + } + + if let body, !hasContentLength { + output += "Content-Length: \(body.count)\r\n" + } + + output += "\r\n" + + var data = Data(output.utf8) + if let body { + data.append(body) + } + return data + } +} diff --git a/TablePro/Core/MCP/Wire/HttpResponseHead.swift b/TablePro/Core/MCP/Wire/HttpResponseHead.swift new file mode 100644 index 000000000..3c48968ce --- /dev/null +++ b/TablePro/Core/MCP/Wire/HttpResponseHead.swift @@ -0,0 +1,37 @@ +import Foundation + +public struct HttpStatus: Sendable, Equatable { + public let code: Int + public let reasonPhrase: String + + public init(code: Int, reasonPhrase: String) { + self.code = code + self.reasonPhrase = reasonPhrase + } + + public static let ok = HttpStatus(code: 200, reasonPhrase: "OK") + public static let accepted = HttpStatus(code: 202, reasonPhrase: "Accepted") + public static let noContent = HttpStatus(code: 204, reasonPhrase: "No Content") + public static let badRequest = HttpStatus(code: 400, reasonPhrase: "Bad Request") + public static let unauthorized = HttpStatus(code: 401, reasonPhrase: "Unauthorized") + public static let forbidden = HttpStatus(code: 403, reasonPhrase: "Forbidden") + public static let notFound = HttpStatus(code: 404, reasonPhrase: "Not Found") + public static let methodNotAllowed = HttpStatus(code: 405, reasonPhrase: "Method Not Allowed") + public static let notAcceptable = HttpStatus(code: 406, reasonPhrase: "Not Acceptable") + public static let payloadTooLarge = HttpStatus(code: 413, reasonPhrase: "Payload Too Large") + public static let unsupportedMediaType = HttpStatus(code: 415, reasonPhrase: "Unsupported Media Type") + public static let tooManyRequests = HttpStatus(code: 429, reasonPhrase: "Too Many Requests") + public static let internalServerError = HttpStatus(code: 500, reasonPhrase: "Internal Server Error") + public static let notImplemented = HttpStatus(code: 501, reasonPhrase: "Not Implemented") + public static let serviceUnavailable = HttpStatus(code: 503, reasonPhrase: "Service Unavailable") +} + +public struct HttpResponseHead: Sendable, Equatable { + public let status: HttpStatus + public let headers: HttpHeaders + + public init(status: HttpStatus, headers: HttpHeaders) { + self.status = status + self.headers = headers + } +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcCodec.swift b/TablePro/Core/MCP/Wire/JsonRpcCodec.swift new file mode 100644 index 000000000..39f6d09e1 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcCodec.swift @@ -0,0 +1,17 @@ +import Foundation + +public enum JsonRpcCodec { + public static func encode(_ message: JsonRpcMessage) throws -> Data { + try message.encode() + } + + public static func decode(_ data: Data) throws -> JsonRpcMessage { + try JsonRpcMessage.decode(from: data) + } + + public static func encodeLine(_ message: JsonRpcMessage) throws -> Data { + var data = try encode(message) + data.append(0x0A) + return data + } +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcError.swift b/TablePro/Core/MCP/Wire/JsonRpcError.swift new file mode 100644 index 000000000..dede20922 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcError.swift @@ -0,0 +1,91 @@ +import Foundation + +public struct JsonRpcError: Codable, Equatable, Sendable { + public let code: Int + public let message: String + public let data: JsonValue? + + public init(code: Int, message: String, data: JsonValue? = nil) { + self.code = code + self.message = message + self.data = data + } + + enum CodingKeys: String, CodingKey { + case code + case message + case data + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + code = try container.decode(Int.self, forKey: .code) + message = try container.decode(String.self, forKey: .message) + data = try container.decodeIfPresent(JsonValue.self, forKey: .data) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(code, forKey: .code) + try container.encode(message, forKey: .message) + try container.encodeIfPresent(data, forKey: .data) + } +} + +public extension JsonRpcError { + static func parseError(message: String = "Parse error", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.parseError, message: message, data: data) + } + + static func invalidRequest(message: String = "Invalid request", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.invalidRequest, message: message, data: data) + } + + static func methodNotFound(message: String = "Method not found", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.methodNotFound, message: message, data: data) + } + + static func invalidParams(message: String = "Invalid params", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.invalidParams, message: message, data: data) + } + + static func internalError(message: String = "Internal error", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.internalError, message: message, data: data) + } + + static func serverError(message: String = "Server error", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.serverError, message: message, data: data) + } + + static func sessionNotFound(message: String = "Session not found", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.sessionNotFound, message: message, data: data) + } + + static func requestCancelled(message: String = "Request cancelled", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.requestCancelled, message: message, data: data) + } + + static func requestTimeout(message: String = "Request timeout", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.requestTimeout, message: message, data: data) + } + + static func resourceNotFound(message: String = "Resource not found", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.resourceNotFound, message: message, data: data) + } + + static func tooLarge(message: String = "Payload too large", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.tooLarge, message: message, data: data) + } + + static func serverDisabled(message: String = "Server disabled", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.serverDisabled, message: message, data: data) + } + + static func forbidden(message: String = "Forbidden", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.forbidden, message: message, data: data) + } + + static func expired(message: String = "Expired", data: JsonValue? = nil) -> Self { + Self(code: JsonRpcErrorCode.expired, message: message, data: data) + } +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcErrorCode.swift b/TablePro/Core/MCP/Wire/JsonRpcErrorCode.swift new file mode 100644 index 000000000..bb6059355 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcErrorCode.swift @@ -0,0 +1,22 @@ +import Foundation + +public enum JsonRpcErrorCode { + public static let parseError = -32_700 + public static let invalidRequest = -32_600 + public static let methodNotFound = -32_601 + public static let invalidParams = -32_602 + public static let internalError = -32_603 + + public static let serverError = -32_000 + public static let sessionNotFound = -32_001 + public static let requestCancelled = -32_002 + public static let requestTimeout = -32_003 + public static let resourceNotFound = -32_004 + public static let tooLarge = -32_005 + public static let serverDisabled = -32_006 + public static let forbidden = -32_007 + public static let expired = -32_008 + public static let unauthenticated = -32_009 + + public static let serverErrorRange: ClosedRange = -32_099 ... -32_000 +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcId.swift b/TablePro/Core/MCP/Wire/JsonRpcId.swift new file mode 100644 index 000000000..9608ae585 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcId.swift @@ -0,0 +1,43 @@ +import Foundation + +public enum JsonRpcId: Codable, Equatable, Hashable, Sendable { + case string(String) + case number(Int64) + case null + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + + if container.decodeNil() { + self = .null + return + } + + if let intValue = try? container.decode(Int64.self) { + self = .number(intValue) + return + } + + if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + return + } + + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "JsonRpcId must be a string, integer, or null" + ) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let value): + try container.encode(value) + case .number(let value): + try container.encode(value) + case .null: + try container.encodeNil() + } + } +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcMessage.swift b/TablePro/Core/MCP/Wire/JsonRpcMessage.swift new file mode 100644 index 000000000..b22653d23 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcMessage.swift @@ -0,0 +1,269 @@ +import Foundation + +public enum JsonRpcDecodingError: Error, Equatable, Sendable { + case missingJsonRpcVersion + case invalidJsonRpcVersion(String) + case ambiguousMessage + case missingMethod + case missingResultOrError + case batchUnsupported +} + +public struct JsonRpcRequest: Codable, Equatable, Sendable { + public let id: JsonRpcId + public let method: String + public let params: JsonValue? + + public init(id: JsonRpcId, method: String, params: JsonValue? = nil) { + self.id = id + self.method = method + self.params = params + } + + enum CodingKeys: String, CodingKey { + case jsonrpc + case id + case method + case params + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + guard let version = try container.decodeIfPresent(String.self, forKey: .jsonrpc) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + guard version == JsonRpcVersion.current else { + throw JsonRpcDecodingError.invalidJsonRpcVersion(version) + } + id = try container.decode(JsonRpcId.self, forKey: .id) + method = try container.decode(String.self, forKey: .method) + params = try container.decodeIfPresent(JsonValue.self, forKey: .params) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(JsonRpcVersion.current, forKey: .jsonrpc) + try container.encode(id, forKey: .id) + try container.encode(method, forKey: .method) + try container.encodeIfPresent(params, forKey: .params) + } +} + +public struct JsonRpcNotification: Codable, Equatable, Sendable { + public let method: String + public let params: JsonValue? + + public init(method: String, params: JsonValue? = nil) { + self.method = method + self.params = params + } + + enum CodingKeys: String, CodingKey { + case jsonrpc + case method + case params + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + guard let version = try container.decodeIfPresent(String.self, forKey: .jsonrpc) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + guard version == JsonRpcVersion.current else { + throw JsonRpcDecodingError.invalidJsonRpcVersion(version) + } + method = try container.decode(String.self, forKey: .method) + params = try container.decodeIfPresent(JsonValue.self, forKey: .params) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(JsonRpcVersion.current, forKey: .jsonrpc) + try container.encode(method, forKey: .method) + try container.encodeIfPresent(params, forKey: .params) + } +} + +public struct JsonRpcSuccessResponse: Codable, Equatable, Sendable { + public let id: JsonRpcId + public let result: JsonValue + + public init(id: JsonRpcId, result: JsonValue) { + self.id = id + self.result = result + } + + enum CodingKeys: String, CodingKey { + case jsonrpc + case id + case result + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + guard let version = try container.decodeIfPresent(String.self, forKey: .jsonrpc) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + guard version == JsonRpcVersion.current else { + throw JsonRpcDecodingError.invalidJsonRpcVersion(version) + } + id = try container.decode(JsonRpcId.self, forKey: .id) + result = try container.decode(JsonValue.self, forKey: .result) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(JsonRpcVersion.current, forKey: .jsonrpc) + try container.encode(id, forKey: .id) + try container.encode(result, forKey: .result) + } +} + +public struct JsonRpcErrorResponse: Codable, Equatable, Sendable { + public let id: JsonRpcId? + public let error: JsonRpcError + + public init(id: JsonRpcId?, error: JsonRpcError) { + self.id = id + self.error = error + } + + enum CodingKeys: String, CodingKey { + case jsonrpc + case id + case error + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + guard let version = try container.decodeIfPresent(String.self, forKey: .jsonrpc) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + guard version == JsonRpcVersion.current else { + throw JsonRpcDecodingError.invalidJsonRpcVersion(version) + } + if container.contains(.id) { + id = try container.decode(JsonRpcId.self, forKey: .id) + } else { + id = nil + } + error = try container.decode(JsonRpcError.self, forKey: .error) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(JsonRpcVersion.current, forKey: .jsonrpc) + if let id { + try container.encode(id, forKey: .id) + } else { + try container.encode(JsonRpcId.null, forKey: .id) + } + try container.encode(error, forKey: .error) + } +} + +public enum JsonRpcMessage: Equatable, Sendable { + case request(JsonRpcRequest) + case notification(JsonRpcNotification) + case successResponse(JsonRpcSuccessResponse) + case errorResponse(JsonRpcErrorResponse) +} + +extension JsonRpcMessage: Codable { + enum DiscriminatorKeys: String, CodingKey { + case jsonrpc + case id + case method + case params + case result + case error + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: DiscriminatorKeys.self) + + guard let version = try container.decodeIfPresent(String.self, forKey: .jsonrpc) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + guard version == JsonRpcVersion.current else { + throw JsonRpcDecodingError.invalidJsonRpcVersion(version) + } + + let hasId = container.contains(.id) + let hasMethod = container.contains(.method) + let hasResult = container.contains(.result) + let hasError = container.contains(.error) + + if hasMethod, hasResult || hasError { + throw JsonRpcDecodingError.ambiguousMessage + } + + if hasResult, hasError { + throw JsonRpcDecodingError.ambiguousMessage + } + + if hasMethod { + if hasId { + self = .request(try JsonRpcRequest(from: decoder)) + return + } + self = .notification(try JsonRpcNotification(from: decoder)) + return + } + + if hasResult { + self = .successResponse(try JsonRpcSuccessResponse(from: decoder)) + return + } + + if hasError { + self = .errorResponse(try JsonRpcErrorResponse(from: decoder)) + return + } + + if hasId { + throw JsonRpcDecodingError.missingResultOrError + } + + throw JsonRpcDecodingError.missingMethod + } + + public func encode(to encoder: Encoder) throws { + switch self { + case .request(let request): + try request.encode(to: encoder) + case .notification(let notification): + try notification.encode(to: encoder) + case .successResponse(let response): + try response.encode(to: encoder) + case .errorResponse(let response): + try response.encode(to: encoder) + } + } +} + +public extension JsonRpcMessage { + static func decode(from data: Data) throws -> JsonRpcMessage { + guard let firstNonWhitespace = data.first(where: { !$0.isAsciiWhitespace }) else { + throw JsonRpcDecodingError.missingJsonRpcVersion + } + if firstNonWhitespace == 0x5B { + throw JsonRpcDecodingError.batchUnsupported + } + + let decoder = JSONDecoder() + return try decoder.decode(JsonRpcMessage.self, from: data) + } + + func encode() throws -> Data { + let encoder = JSONEncoder() + encoder.outputFormatting = [] + return try encoder.encode(self) + } +} + +private extension UInt8 { + var isAsciiWhitespace: Bool { + self == 0x20 || self == 0x09 || self == 0x0A || self == 0x0D + } +} diff --git a/TablePro/Core/MCP/Wire/JsonRpcVersion.swift b/TablePro/Core/MCP/Wire/JsonRpcVersion.swift new file mode 100644 index 000000000..ed52ccd3a --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonRpcVersion.swift @@ -0,0 +1,15 @@ +import Foundation + +public enum JsonRpcVersionError: Error, Equatable, Sendable { + case unsupported(String) +} + +public enum JsonRpcVersion { + public static let current = "2.0" + + public static func validate(_ value: String) throws { + guard value == current else { + throw JsonRpcVersionError.unsupported(value) + } + } +} diff --git a/TablePro/Core/MCP/Wire/JsonValue.swift b/TablePro/Core/MCP/Wire/JsonValue.swift new file mode 100644 index 000000000..9ff2d6502 --- /dev/null +++ b/TablePro/Core/MCP/Wire/JsonValue.swift @@ -0,0 +1,162 @@ +import Foundation + +public enum JsonValue: Codable, Equatable, Sendable { + case null + case bool(Bool) + case int(Int) + case double(Double) + case string(String) + case array([JsonValue]) + case object([String: JsonValue]) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + + if container.decodeNil() { + self = .null + return + } + + if let boolValue = try? container.decode(Bool.self) { + self = .bool(boolValue) + return + } + + if let intValue = try? container.decode(Int.self) { + self = .int(intValue) + return + } + + if let doubleValue = try? container.decode(Double.self) { + self = .double(doubleValue) + return + } + + if let stringValue = try? container.decode(String.self) { + self = .string(stringValue) + return + } + + if let arrayValue = try? container.decode([JsonValue].self) { + self = .array(arrayValue) + return + } + + if let objectValue = try? container.decode([String: JsonValue].self) { + self = .object(objectValue) + return + } + + throw DecodingError.dataCorruptedError(in: container, debugDescription: "Cannot decode JsonValue") + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .null: + try container.encodeNil() + case .bool(let value): + try container.encode(value) + case .int(let value): + try container.encode(value) + case .double(let value): + try container.encode(value) + case .string(let value): + try container.encode(value) + case .array(let value): + try container.encode(value) + case .object(let value): + try container.encode(value) + } + } +} + +extension JsonValue: ExpressibleByStringLiteral { + public init(stringLiteral value: String) { + self = .string(value) + } +} + +extension JsonValue: ExpressibleByIntegerLiteral { + public init(integerLiteral value: Int) { + self = .int(value) + } +} + +extension JsonValue: ExpressibleByFloatLiteral { + public init(floatLiteral value: Double) { + self = .double(value) + } +} + +extension JsonValue: ExpressibleByBooleanLiteral { + public init(booleanLiteral value: Bool) { + self = .bool(value) + } +} + +extension JsonValue: ExpressibleByNilLiteral { + public init(nilLiteral: ()) { + self = .null + } +} + +extension JsonValue: ExpressibleByArrayLiteral { + public init(arrayLiteral elements: JsonValue...) { + self = .array(elements) + } +} + +extension JsonValue: ExpressibleByDictionaryLiteral { + public init(dictionaryLiteral elements: (String, JsonValue)...) { + self = .object(Dictionary(uniqueKeysWithValues: elements)) + } +} + +public extension JsonValue { + subscript(key: String) -> JsonValue? { + guard case .object(let dict) = self else { return nil } + return dict[key] + } + + var isNull: Bool { + if case .null = self { return true } + return false + } + + var stringValue: String? { + guard case .string(let value) = self else { return nil } + return value + } + + var intValue: Int? { + guard case .int(let value) = self else { return nil } + return value + } + + var boolValue: Bool? { + guard case .bool(let value) = self else { return nil } + return value + } + + var doubleValue: Double? { + switch self { + case .double(let value): + return value + case .int(let value): + return Double(value) + default: + return nil + } + } + + var arrayValue: [JsonValue]? { + guard case .array(let value) = self else { return nil } + return value + } + + var objectValue: [String: JsonValue]? { + guard case .object(let value) = self else { return nil } + return value + } +} diff --git a/TablePro/Core/MCP/Wire/SseDecoder.swift b/TablePro/Core/MCP/Wire/SseDecoder.swift new file mode 100644 index 000000000..d8a9d61f8 --- /dev/null +++ b/TablePro/Core/MCP/Wire/SseDecoder.swift @@ -0,0 +1,129 @@ +import Foundation + +public actor SseDecoder { + private var buffer: Data + private var pendingEvent: String? + private var pendingId: String? + private var pendingRetry: Int? + private var pendingDataLines: [String] + private var hasPendingFields: Bool + + public init() { + buffer = Data() + pendingEvent = nil + pendingId = nil + pendingRetry = nil + pendingDataLines = [] + hasPendingFields = false + } + + public func feed(_ chunk: Data) -> [SseFrame] { + buffer.append(chunk) + + var frames: [SseFrame] = [] + + while let line = takeLine() { + if line.isEmpty { + if let frame = flushFrame() { + frames.append(frame) + } + continue + } + processLine(line) + } + + return frames + } + + private func takeLine() -> String? { + var index = buffer.startIndex + while index < buffer.endIndex { + let byte = buffer[index] + if byte == 0x0A { + let lineData = buffer[buffer.startIndex.. String { + String(data: data, encoding: .utf8) ?? "" + } + + private func processLine(_ line: String) { + if line.first == ":" { + return + } + + let field: String + let value: String + if let colonIndex = line.firstIndex(of: ":") { + field = String(line[line.startIndex.. SseFrame? { + defer { resetPending() } + + guard hasPendingFields else { return nil } + guard !pendingDataLines.isEmpty else { return nil } + + let data = pendingDataLines.joined(separator: "\n") + return SseFrame( + event: pendingEvent, + id: pendingId, + data: data, + retry: pendingRetry + ) + } + + private func resetPending() { + pendingEvent = nil + pendingId = nil + pendingRetry = nil + pendingDataLines = [] + hasPendingFields = false + } +} diff --git a/TablePro/Core/MCP/Wire/SseEncoder.swift b/TablePro/Core/MCP/Wire/SseEncoder.swift new file mode 100644 index 000000000..aa75adf9f --- /dev/null +++ b/TablePro/Core/MCP/Wire/SseEncoder.swift @@ -0,0 +1,58 @@ +import Foundation + +public enum SseEncoder { + public static func encode(_ frame: SseFrame) -> Data { + var output = "" + + if let event = frame.event { + output += "event: \(event)\n" + } + + if let id = frame.id { + output += "id: \(id)\n" + } + + if let retry = frame.retry { + output += "retry: \(retry)\n" + } + + let dataLines = splitLines(frame.data) + for line in dataLines { + output += "data: \(line)\n" + } + + output += "\n" + return Data(output.utf8) + } + + private static func splitLines(_ value: String) -> [String] { + var lines: [String] = [] + var current = "" + let characters = Array(value) + var index = 0 + while index < characters.count { + let char = characters[index] + if char == "\r" { + lines.append(current) + current = "" + let nextIndex = index + 1 + if nextIndex < characters.count, characters[nextIndex] == "\n" { + index = nextIndex + 1 + continue + } + index += 1 + continue + } + if char == "\n" { + lines.append(current) + current = "" + index += 1 + continue + } + current.append(char) + index += 1 + } + lines.append(current) + return lines + } +} diff --git a/TablePro/Core/MCP/Wire/SseFrame.swift b/TablePro/Core/MCP/Wire/SseFrame.swift new file mode 100644 index 000000000..eea89d726 --- /dev/null +++ b/TablePro/Core/MCP/Wire/SseFrame.swift @@ -0,0 +1,15 @@ +import Foundation + +public struct SseFrame: Sendable, Equatable { + public let event: String? + public let id: String? + public let data: String + public let retry: Int? + + public init(event: String? = nil, id: String? = nil, data: String, retry: Int? = nil) { + self.event = event + self.id = id + self.data = data + self.retry = retry + } +} diff --git a/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift b/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift index 7a60f1959..cd31312d8 100644 --- a/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift +++ b/TablePro/Core/Services/Infrastructure/LaunchIntentRouter.swift @@ -43,7 +43,7 @@ internal final class LaunchIntentRouter { } } catch let error as TabRouterError where error == .userCancelled { Self.logger.info("Intent cancelled by user") - } catch let error as MCPError where error.isUserCancelled { + } catch let error as MCPDataLayerError where error.isUserCancelled { Self.logger.info("Pairing cancelled by user") } catch is CancellationError { Self.logger.info("Intent cancelled") diff --git a/TablePro/Resources/Localizable.xcstrings b/TablePro/Resources/Localizable.xcstrings index d076f61e9..212c2afaf 100644 --- a/TablePro/Resources/Localizable.xcstrings +++ b/TablePro/Resources/Localizable.xcstrings @@ -10565,6 +10565,9 @@ } } } + }, + "Confirm Destructive Operation" : { + }, "Confirm passphrase" : { "localizations" : { @@ -10716,6 +10719,9 @@ } } } + }, + "Connect to a saved database" : { + }, "Connect to popular databases with full feature support" : { "localizations" : { @@ -13835,6 +13841,18 @@ } } } + }, + "Database name (uses connection's current database if omitted)" : { + + }, + "Database name (uses current if omitted)" : { + + }, + "Database name to switch to" : { + + }, + "Database Schema" : { + }, "Database Size" : { "localizations" : { @@ -15309,6 +15327,9 @@ } } } + }, + "Describe Table" : { + }, "Description" : { "localizations" : { @@ -15687,6 +15708,9 @@ } } } + }, + "Disconnect from a database" : { + }, "Disconnected" : { "localizations" : { @@ -16594,6 +16618,9 @@ } } } + }, + "Earliest executed_at to include, Unix epoch seconds (inclusive, optional)" : { + }, "Edit" : { "localizations" : { @@ -17979,6 +18006,9 @@ } } } + }, + "Execute a destructive DDL query (DROP, TRUNCATE, ALTER...DROP) after explicit confirmation." : { + }, "Execute a query to view results as JSON" : { "localizations" : { @@ -18001,6 +18031,9 @@ } } } + }, + "Execute a SQL query. All queries are subject to the connection's safe mode policy. DROP/TRUNCATE/ALTER...DROP must use the confirm_destructive_operation tool." : { + }, "Execute All" : { "localizations" : { @@ -18920,7 +18953,6 @@ } }, "Export Data" : { - "extractionState" : "stale", "localizations" : { "tr" : { "stringUnit" : { @@ -19095,6 +19127,9 @@ } } } + }, + "Export format: csv, json, or sql" : { + }, "Export Formats" : { "localizations" : { @@ -19183,6 +19218,9 @@ } } } + }, + "Export query results or table data to CSV, JSON, or SQL" : { + }, "Export query results to %@" : { "localizations" : { @@ -20865,6 +20903,9 @@ } } } + }, + "File path inside the user's Downloads directory (returns inline data if omitted). Paths outside Downloads are rejected." : { + }, "Filename cannot be '.' or '..' or contain path traversal" : { "localizations" : { @@ -21399,6 +21440,9 @@ } } } + }, + "Focus an already-open tab by id (returned from list_recent_tabs)." : { + }, "Focus Border" : { "localizations" : { @@ -21421,6 +21465,9 @@ } } } + }, + "Focus Query Tab" : { + }, "Focus the query editor to insert" : { "localizations" : { @@ -22061,6 +22108,15 @@ } } } + }, + "Get Connection Status" : { + + }, + "Get detailed status of a database connection" : { + + }, + "Get detailed table structure: columns, indexes, foreign keys, and DDL" : { + }, "Get help writing queries, explaining schemas, or fixing errors." : { "localizations" : { @@ -22127,6 +22183,12 @@ } } } + }, + "Get Table DDL" : { + + }, + "Get the CREATE TABLE DDL statement for a table" : { + }, "GitHub Repository" : { "localizations" : { @@ -23620,6 +23682,9 @@ } } } + }, + "Include approximate row counts (default false)" : { + }, "Include column headers" : { "extractionState" : "stale", @@ -25951,6 +26016,9 @@ } } } + }, + "Latest executed_at to include, Unix epoch seconds (inclusive, optional)" : { + }, "Layout" : { "extractionState" : "stale", @@ -26507,6 +26575,39 @@ } } } + }, + "List all databases on the server" : { + + }, + "List all saved database connections with their status" : { + + }, + "List Connections" : { + + }, + "List currently open tabs across all TablePro windows. Returns connection, tab type, table name, and titles for each tab." : { + + }, + "List Databases" : { + + }, + "List of all saved database connections with metadata" : { + + }, + "List Recent Tabs" : { + + }, + "List Schemas" : { + + }, + "List schemas in a database" : { + + }, + "List Tables" : { + + }, + "List tables and views in a database" : { + }, "Load" : { "extractionState" : "stale", @@ -27512,6 +27613,9 @@ } } } + }, + "Maximum number of entries to return (default 50, max 500)" : { + }, "Maximum row limit" : { "localizations" : { @@ -27534,6 +27638,12 @@ } } } + }, + "Maximum rows to export (default 50000)" : { + + }, + "Maximum rows to return (default 500, max 10000)" : { + }, "Maximum time to wait for a query to complete. Set to 0 for no limit. Applied to new connections." : { "localizations" : { @@ -28387,6 +28497,9 @@ } } } + }, + "Must be exactly: I understand this is irreversible" : { + }, "My Server" : { "localizations" : { @@ -31642,6 +31755,9 @@ } } } + }, + "Only affects new saves. Re-save a password to update its sync." : { + }, "Open" : { "localizations" : { @@ -31708,6 +31824,12 @@ } } } + }, + "Open a table tab in TablePro for the given connection." : { + + }, + "Open a TablePro window for a saved connection (focuses if already open)." : { + }, "Open Claude Desktop, go to Settings > Developer" : { "localizations" : { @@ -31752,6 +31874,9 @@ } } } + }, + "Open Connection Window" : { + }, "Open containing folder" : { "localizations" : { @@ -32178,6 +32303,9 @@ } } } + }, + "Open Table Tab" : { + }, "Open Terminal" : { "localizations" : { @@ -35455,6 +35583,9 @@ } } } + }, + "Query history for %@" : { + }, "Query History:" : { "extractionState" : "stale", @@ -35624,6 +35755,9 @@ } } } + }, + "Query timeout in seconds (default 30, max 300)" : { + }, "Query timeout:" : { "localizations" : { @@ -36359,6 +36493,12 @@ } } } + }, + "Recent query history for a connection (supports ?limit=, ?search=, ?date_filter=)" : { + + }, + "Recent query history for this connection" : { + }, "Reconnect" : { "localizations" : { @@ -37836,6 +37976,9 @@ } } } + }, + "Restrict to a specific connection (UUID, optional)" : { + }, "Results" : { "localizations" : { @@ -39373,6 +39516,18 @@ } } } + }, + "Schema for %@" : { + + }, + "Schema name (for multi-schema databases)" : { + + }, + "Schema name (uses current if omitted)" : { + + }, + "Schema name to switch to" : { + }, "Schema Switch Failed" : { "localizations" : { @@ -39661,6 +39816,12 @@ } } } + }, + "Search Query History" : { + + }, + "Search saved query history. Returns matching entries with execution time, row count, and outcome." : { + }, "Search schemas..." : { "localizations" : { @@ -39727,6 +39888,9 @@ } } } + }, + "Search text (full-text matched against the query column)" : { + }, "Search..." : { "localizations" : { @@ -42558,6 +42722,9 @@ } } } + }, + "SQL or NoSQL query text" : { + }, "SQL Preview" : { "localizations" : { @@ -42602,6 +42769,9 @@ } } } + }, + "SQL query to export results from" : { + }, "SQL Server" : { "extractionState" : "stale", @@ -44268,6 +44438,15 @@ } } } + }, + "Switch Schema" : { + + }, + "Switch the active database on a connection" : { + + }, + "Switch the active schema on a connection" : { + }, "Switch to Inline Configuration" : { "localizations" : { @@ -44296,6 +44475,12 @@ } } } + }, + "Switch to this database before executing" : { + + }, + "Switch to this schema before executing" : { + }, "Sync" : { "extractionState" : "stale", @@ -44925,6 +45110,9 @@ } } } + }, + "Table name" : { + }, "Table Name" : { "extractionState" : "stale", @@ -44948,6 +45136,9 @@ } } } + }, + "Table name to open" : { + }, "Table Name:" : { "localizations" : { @@ -44970,6 +45161,9 @@ } } } + }, + "Table names to export (alternative to query)" : { + }, "Table: %@" : { "localizations" : { @@ -45080,6 +45274,12 @@ } } } + }, + "Tables, columns, indexes, and foreign keys for a connected database" : { + + }, + "Tables, columns, indexes, and foreign keys for the connected database" : { + }, "Tablespace" : { "extractionState" : "stale", @@ -45641,6 +45841,9 @@ } } } + }, + "The destructive query to execute" : { + }, "The encrypted file is corrupt or incomplete" : { "localizations" : { @@ -49602,6 +49805,18 @@ } } } + }, + "UUID of the active connection" : { + + }, + "UUID of the connection" : { + + }, + "UUID of the connection to disconnect" : { + + }, + "UUID of the saved connection" : { + }, "v%@" : { "localizations" : { diff --git a/TablePro/Views/Settings/Sections/MCPAuditLogView.swift b/TablePro/Views/Settings/Sections/MCPAuditLogView.swift index 8ebf5a824..b9425878c 100644 --- a/TablePro/Views/Settings/Sections/MCPAuditLogView.swift +++ b/TablePro/Views/Settings/Sections/MCPAuditLogView.swift @@ -12,6 +12,9 @@ struct MCPAuditLogView: View { @State private var searchText: String = "" @State private var isLoading = false + private let auditChanges = NotificationCenter.default + .publisher(for: .mcpAuditLogChanged) + var body: some View { VStack(alignment: .leading, spacing: 12) { searchBar @@ -36,6 +39,9 @@ struct MCPAuditLogView: View { } .padding() .task { await reload() } + .onReceive(auditChanges) { _ in + Task { await reload() } + } } private var searchBar: some View { diff --git a/TablePro/Views/Settings/Sections/MCPSection.swift b/TablePro/Views/Settings/Sections/MCPSection.swift index 644e59f0d..c462328c1 100644 --- a/TablePro/Views/Settings/Sections/MCPSection.swift +++ b/TablePro/Views/Settings/Sections/MCPSection.swift @@ -10,7 +10,7 @@ struct MCPSection: View { @State private var showRevealSheet = false @State private var revealedToken: MCPAuthToken? @State private var revealedPlaintext: String = "" - @State private var disconnectCandidate: MCPServer.SessionSnapshot? + @State private var disconnectCandidate: MCPServerManager.SessionSnapshot? var body: some View { Section(String(localized: "Integrations")) { @@ -244,7 +244,9 @@ private struct MCPSetupInstructions: View { .font(.callout) } - private var url: String { "http://127.0.0.1:\(port)/mcp" } + private var bridgeBinaryPath: String { + Bundle.main.bundleURL.appendingPathComponent("Contents/MacOS/tablepro-mcp").path + } private var steps: [String] { switch tool { @@ -273,7 +275,7 @@ private struct MCPSetupInstructions: View { { "mcpServers": { "tablepro": { - "url": "\(url)" + "command": "\(bridgeBinaryPath)" } } } @@ -285,7 +287,7 @@ private struct MCPSetupInstructions: View { private var command: String? { switch tool { - case .claudeCode: "claude mcp add tablepro --transport http \(url)" + case .claudeCode: "claude mcp add tablepro -- \(bridgeBinaryPath)" default: nil } } diff --git a/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift b/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift index 521eaf3ea..1335978de 100644 --- a/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift +++ b/TablePro/Views/Settings/Sections/PairingApprovalSheet.swift @@ -207,7 +207,7 @@ struct PairingApprovalSheet: View { private var actionBar: some View { HStack { Button(String(localized: "Deny"), role: .cancel) { - onComplete(.failure(MCPError.userCancelled)) + onComplete(.failure(MCPDataLayerError.userCancelled)) } .keyboardShortcut(.cancelAction) diff --git a/TableProTests/Core/MCP/Auth/MCPBearerTokenAuthenticatorTests.swift b/TableProTests/Core/MCP/Auth/MCPBearerTokenAuthenticatorTests.swift new file mode 100644 index 000000000..07ea832c6 --- /dev/null +++ b/TableProTests/Core/MCP/Auth/MCPBearerTokenAuthenticatorTests.swift @@ -0,0 +1,229 @@ +import Foundation +@testable import TablePro +import Testing + +actor FakeMCPTokenStore: MCPTokenStoreProtocol { + private var tokens: [String: MCPValidatedToken] = [:] + private var expired: Set = [] + private var revoked: Set = [] + + func register(_ plaintext: String, validated: MCPValidatedToken) { + tokens[plaintext] = validated + } + + func markExpired(_ plaintext: String) { + expired.insert(plaintext) + } + + func markRevoked(_ plaintext: String) { + revoked.insert(plaintext) + } + + func validateBearerToken(_ token: String) async -> Result { + if expired.contains(token) { + return .failure(.expired) + } + if revoked.contains(token) { + return .failure(.revoked) + } + if let validated = tokens[token] { + return .success(validated) + } + return .failure(.unknownToken) + } +} + +@Suite("MCP Bearer Token Authenticator") +struct MCPBearerTokenAuthenticatorTests { + private func makePrincipal(label: String = "test", scopes: Set = [.toolsRead]) -> MCPValidatedToken { + MCPValidatedToken( + tokenId: UUID(), + label: label, + scopes: scopes, + issuedAt: Date(timeIntervalSince1970: 1_000_000), + expiresAt: nil + ) + } + + private func makeAuthenticator( + store: FakeMCPTokenStore, + clock: MCPTestClock = MCPTestClock() + ) -> (MCPBearerTokenAuthenticator, MCPRateLimiter) { + let limiter = MCPRateLimiter(clock: clock) + let authenticator = MCPBearerTokenAuthenticator(tokenStore: store, rateLimiter: limiter) + return (authenticator, limiter) + } + + @Test("Missing header returns 401 with bearer challenge") + func missingHeader() async { + let store = FakeMCPTokenStore() + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: nil, + clientAddress: .loopback + ) + guard case .deny(let reason) = decision else { + Issue.record("Expected deny, got \(decision)") + return + } + #expect(reason.httpStatus == 401) + #expect(reason.challenge?.contains("Bearer") == true) + } + + @Test("Empty header returns 401") + func emptyHeader() async { + let store = FakeMCPTokenStore() + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: "", + clientAddress: .loopback + ) + guard case .deny(let reason) = decision else { + Issue.record("Expected deny") + return + } + #expect(reason.httpStatus == 401) + } + + @Test("Bad scheme returns 401") + func badScheme() async { + let store = FakeMCPTokenStore() + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: "Basic abc123", + clientAddress: .loopback + ) + guard case .deny(let reason) = decision else { + Issue.record("Expected deny") + return + } + #expect(reason.httpStatus == 401) + } + + @Test("Valid token returns allow with principal") + func validToken() async { + let store = FakeMCPTokenStore() + let plaintext = "tp_validtoken123" + await store.register(plaintext, validated: makePrincipal()) + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: "Bearer \(plaintext)", + clientAddress: .loopback + ) + guard case .allow(let principal) = decision else { + Issue.record("Expected allow, got \(decision)") + return + } + #expect(principal.scopes.contains(.toolsRead)) + #expect(principal.tokenFingerprint.count == 8) + #expect(!principal.tokenFingerprint.contains(plaintext)) + } + + @Test("Bearer scheme is case-insensitive") + func bearerCaseInsensitive() async { + let store = FakeMCPTokenStore() + let plaintext = "tp_token" + await store.register(plaintext, validated: makePrincipal()) + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: "bEaReR \(plaintext)", + clientAddress: .loopback + ) + guard case .allow = decision else { + Issue.record("Expected allow") + return + } + } + + @Test("Expired token returns 401 expired") + func expiredToken() async { + let store = FakeMCPTokenStore() + let plaintext = "tp_expired" + await store.register(plaintext, validated: makePrincipal()) + await store.markExpired(plaintext) + let (authenticator, _) = makeAuthenticator(store: store) + let decision = await authenticator.authenticate( + authorizationHeader: "Bearer \(plaintext)", + clientAddress: .loopback + ) + guard case .deny(let reason) = decision else { + Issue.record("Expected deny") + return + } + #expect(reason.httpStatus == 401) + #expect(reason.logMessage == "token_expired") + } + + @Test("Repeated bad token leads to rate limited 429") + func repeatedBadTokenRateLimited() async { + let store = FakeMCPTokenStore() + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let authenticator = MCPBearerTokenAuthenticator(tokenStore: store, rateLimiter: limiter) + + let badToken = "tp_unknown" + for _ in 0..<5 { + _ = await authenticator.authenticate( + authorizationHeader: "Bearer \(badToken)", + clientAddress: .loopback + ) + } + let final = await authenticator.authenticate( + authorizationHeader: "Bearer \(badToken)", + clientAddress: .loopback + ) + guard case .deny(let reason) = final else { + Issue.record("Expected deny") + return + } + #expect(reason.httpStatus == 429) + } + + @Test("Successful auth resets rate limit bucket") + func successResetsRateLimit() async { + let store = FakeMCPTokenStore() + let plaintext = "tp_good" + await store.register(plaintext, validated: makePrincipal()) + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let authenticator = MCPBearerTokenAuthenticator(tokenStore: store, rateLimiter: limiter) + + let goodHeader = "Bearer \(plaintext)" + for _ in 0..<3 { + _ = await authenticator.authenticate( + authorizationHeader: goodHeader, + clientAddress: .loopback + ) + } + let fingerprint = MCPBearerTokenAuthenticator.fingerprint(of: plaintext) + let key = MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: fingerprint) + let locked = await limiter.isLocked(key: key) + #expect(locked == false) + } + + @Test("Different addresses with same token are isolated by rate limiter") + func addressIsolation() async { + let store = FakeMCPTokenStore() + let plaintext = "tp_token" + await store.register(plaintext, validated: makePrincipal()) + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let authenticator = MCPBearerTokenAuthenticator(tokenStore: store, rateLimiter: limiter) + + for _ in 0..<5 { + _ = await authenticator.authenticate( + authorizationHeader: "Bearer wrong", + clientAddress: .loopback + ) + } + + let decision = await authenticator.authenticate( + authorizationHeader: "Bearer \(plaintext)", + clientAddress: .remote("10.0.0.1") + ) + guard case .allow = decision else { + Issue.record("Expected allow on different address, got \(decision)") + return + } + } +} diff --git a/TableProTests/Core/MCP/Helpers/MCPProtocolHandlerTestSupport.swift b/TableProTests/Core/MCP/Helpers/MCPProtocolHandlerTestSupport.swift new file mode 100644 index 000000000..4bf2afd95 --- /dev/null +++ b/TableProTests/Core/MCP/Helpers/MCPProtocolHandlerTestSupport.swift @@ -0,0 +1,48 @@ +import Foundation +@testable import TablePro + +enum MCPProtocolHandlerTestSupport { + static func makeContext( + method: String, + params: JsonValue? = nil, + principalScopes: Set = [.toolsRead, .toolsWrite], + requestId: JsonRpcId = .number(1) + ) async -> MCPRequestContext { + let sessionStore = MCPSessionStore() + let progressSink = StubProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [], + sessionStore: sessionStore, + progressSink: progressSink, + clock: MCPSystemClock() + ) + + let session = MCPSession() + try? await session.transitionToReady() + + let principal = MCPProtocolTestSupport.makePrincipal(scopes: principalScopes) + let request = JsonRpcRequest(id: requestId, method: method, params: params) + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: .request(request), + sessionId: session.id, + principal: principal + ) + + let cancellation = MCPCancellationToken() + let progress = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: session.id + ) + + return MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: progress, + cancellation: cancellation, + clock: MCPSystemClock() + ) + } +} diff --git a/TableProTests/Core/MCP/Helpers/MCPProtocolTestStubs.swift b/TableProTests/Core/MCP/Helpers/MCPProtocolTestStubs.swift new file mode 100644 index 000000000..17203d727 --- /dev/null +++ b/TableProTests/Core/MCP/Helpers/MCPProtocolTestStubs.swift @@ -0,0 +1,232 @@ +import Foundation +@testable import TablePro + +actor RecordingResponderSink: MCPResponderSink { + struct WriteJsonRecord { + let data: Data + let status: HttpStatus + let sessionId: MCPSessionId? + let extraHeaders: [(String, String)] + } + + private(set) var jsonWrites: [WriteJsonRecord] = [] + private(set) var acceptedCount: Int = 0 + private(set) var sseHeaderCount: Int = 0 + private(set) var sseFrames: [SseFrame] = [] + private(set) var closed: Bool = false + private(set) var sseRegistrations: [MCPSessionId] = [] + + private var continuation: CheckedContinuation? + private var completed: Bool = false + + func writeJson( + _ data: Data, + status: HttpStatus, + sessionId: MCPSessionId?, + extraHeaders: [(String, String)] + ) async { + jsonWrites.append(WriteJsonRecord( + data: data, + status: status, + sessionId: sessionId, + extraHeaders: extraHeaders + )) + } + + func writeAccepted() async { + acceptedCount += 1 + } + + func writeSseStreamHeaders(sessionId: MCPSessionId) async { + sseHeaderCount += 1 + } + + func writeSseFrame(_ frame: SseFrame) async { + sseFrames.append(frame) + } + + func closeConnection() async { + closed = true + if !completed { + completed = true + continuation?.resume() + continuation = nil + } + } + + func registerSseConnection(sessionId: MCPSessionId) async { + sseRegistrations.append(sessionId) + } + + func waitForCompletion() async { + if completed { return } + await withCheckedContinuation { (cont: CheckedContinuation) in + if completed { + cont.resume() + return + } + continuation = cont + } + } + + func firstJsonMessage() throws -> JsonRpcMessage? { + guard let record = jsonWrites.first else { return nil } + return try JsonRpcCodec.decode(record.data) + } +} + +actor StubProgressSink: MCPProgressSink { + private(set) var notifications: [(notification: JsonRpcNotification, sessionId: MCPSessionId)] = [] + + func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async { + notifications.append((notification, sessionId)) + } + + func count() -> Int { + notifications.count + } + + func methods() -> [String] { + notifications.map(\.notification.method) + } +} + +struct StubMethodHandler: MCPMethodHandler { + enum Behavior: Sendable { + case respondImmediately(JsonValue) + case throwProtocolError(MCPProtocolError) + case waitForCancellation + case slowSuccess(milliseconds: UInt64, JsonValue) + } + + static let method = "test/stub" + static let requiredScopes: Set = [] + static let allowedSessionStates: Set = [.uninitialized, .ready] + + let behavior: Behavior + let observedCancel: ObservedFlag + let started: ObservedFlag + + init(behavior: Behavior = .respondImmediately(.object(["ok": .bool(true)]))) { + self.behavior = behavior + self.observedCancel = ObservedFlag() + self.started = ObservedFlag() + } + + func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + await started.set() + switch behavior { + case .respondImmediately(let result): + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + case .throwProtocolError(let error): + throw error + case .waitForCancellation: + while true { + if await context.cancellation.isCancelled() { + await observedCancel.set() + throw CancellationError() + } + try await Task.sleep(nanoseconds: 1_000_000) + } + case .slowSuccess(let ms, let result): + try await Task.sleep(nanoseconds: ms * 1_000_000) + return MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: result) + } + } +} + +actor ObservedFlag { + private var triggered: Bool = false + + func set() { + triggered = true + } + + func value() -> Bool { + triggered + } +} + +struct ConfigurableHandler: MCPMethodHandler { + static var method: String { T.method } + static var requiredScopes: Set { T.requiredScopes } + static var allowedSessionStates: Set { T.allowedSessionStates } + + let inner: T + + func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + try await inner.handle(params: params, context: context) + } +} + +struct ScopedToolsCallHandler: MCPMethodHandler { + static let method = "tools/call" + static let requiredScopes: Set = [.toolsWrite] + static let allowedSessionStates: Set = [.ready] + + func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: .object([:])) + } +} + +struct StubToolsListHandler: MCPMethodHandler { + static let method = "tools/list" + static let requiredScopes: Set = [] + static let allowedSessionStates: Set = [.ready] + + func handle(params: JsonValue?, context: MCPRequestContext) async throws -> JsonRpcMessage { + MCPMethodHandlerHelpers.successResponse(id: context.requestId, result: .object(["tools": .array([])])) + } +} + +enum MCPProtocolTestSupport { + static func makePrincipal(scopes: Set = [.toolsRead, .toolsWrite]) -> MCPPrincipal { + MCPPrincipal( + tokenFingerprint: "test-fp", + scopes: scopes, + metadata: MCPPrincipalMetadata( + label: "test", + issuedAt: Date(timeIntervalSince1970: 1_700_000_000), + expiresAt: nil + ) + ) + } + + static func makeExchange( + message: JsonRpcMessage, + sessionId: MCPSessionId? = nil, + principal: MCPPrincipal? = nil, + receivedAt: Date = Date(timeIntervalSince1970: 1_700_000_000) + ) -> (MCPInboundExchange, RecordingResponderSink) { + let sink = RecordingResponderSink() + let requestId: JsonRpcId? + switch message { + case .request(let request): + requestId = request.id + default: + requestId = nil + } + let responder = MCPExchangeResponder(sink: sink, requestId: requestId) + let context = MCPInboundContext( + sessionId: sessionId, + principal: principal ?? makePrincipal(), + clientAddress: .loopback, + receivedAt: receivedAt, + mcpProtocolVersion: "2025-03-26" + ) + let exchange = MCPInboundExchange(message: message, context: context, responder: responder) + return (exchange, sink) + } + + static func makeRequest( + id: JsonRpcId = .number(1), + method: String, + params: JsonValue? = nil + ) -> JsonRpcMessage { + .request(JsonRpcRequest(id: id, method: method, params: params)) + } + + static func makeNotification(method: String, params: JsonValue? = nil) -> JsonRpcMessage { + .notification(JsonRpcNotification(method: method, params: params)) + } +} diff --git a/TableProTests/Core/MCP/Helpers/MCPTestClock.swift b/TableProTests/Core/MCP/Helpers/MCPTestClock.swift new file mode 100644 index 000000000..5fb8342a4 --- /dev/null +++ b/TableProTests/Core/MCP/Helpers/MCPTestClock.swift @@ -0,0 +1,62 @@ +import Foundation +@testable import TablePro + +public actor MCPTestClock: MCPClock { + private var currentDate: Date + private var pendingSleeps: [PendingSleep] = [] + + private struct PendingSleep { + let dueAt: Date + let continuation: CheckedContinuation + } + + public init(start: Date = Date(timeIntervalSince1970: 1_700_000_000)) { + self.currentDate = start + } + + public func now() -> Date { + currentDate + } + + public func sleep(for duration: Duration) async throws { + let dueAt = currentDate.addingTimeInterval(Self.seconds(of: duration)) + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + pendingSleeps.append(PendingSleep(dueAt: dueAt, continuation: continuation)) + } + } + + public func advance(by duration: Duration) async { + let target = currentDate.addingTimeInterval(Self.seconds(of: duration)) + currentDate = target + + let due = pendingSleeps.filter { $0.dueAt <= target } + pendingSleeps.removeAll { $0.dueAt <= target } + for sleep in due { + sleep.continuation.resume() + } + + await Task.yield() + } + + public func setNow(_ date: Date) async { + currentDate = date + let due = pendingSleeps.filter { $0.dueAt <= date } + pendingSleeps.removeAll { $0.dueAt <= date } + for sleep in due { + sleep.continuation.resume() + } + } + + public func cancelAllSleeps() { + let cancelled = pendingSleeps + pendingSleeps.removeAll() + for sleep in cancelled { + sleep.continuation.resume(throwing: CancellationError()) + } + } + + private static func seconds(of duration: Duration) -> TimeInterval { + let components = duration.components + return TimeInterval(components.seconds) + TimeInterval(components.attoseconds) / 1.0e18 + } +} diff --git a/TableProTests/Core/MCP/Helpers/MCPTransportTestStubs.swift b/TableProTests/Core/MCP/Helpers/MCPTransportTestStubs.swift new file mode 100644 index 000000000..043b4804d --- /dev/null +++ b/TableProTests/Core/MCP/Helpers/MCPTransportTestStubs.swift @@ -0,0 +1,101 @@ +import Foundation +@testable import TablePro + +actor StubAlwaysAllowAuthenticator: MCPAuthenticator { + private let principal: MCPPrincipal + + init(scopes: Set = [.toolsRead, .toolsWrite]) { + self.principal = MCPPrincipal( + tokenFingerprint: "stubtoken", + scopes: scopes, + metadata: MCPPrincipalMetadata( + label: "stub", + issuedAt: Date(timeIntervalSince1970: 1_700_000_000), + expiresAt: nil + ) + ) + } + + func authenticate( + authorizationHeader: String?, + clientAddress: MCPClientAddress + ) async -> MCPAuthDecision { + .allow(principal) + } +} + +actor StubBearerAuthenticator: MCPAuthenticator { + private let validToken: String + private let principal: MCPPrincipal + private var attemptsByAddress: [MCPClientAddress: Int] = [:] + private let maxAttempts: Int + + init(validToken: String, maxAttempts: Int = 5) { + self.validToken = validToken + self.maxAttempts = maxAttempts + self.principal = MCPPrincipal( + tokenFingerprint: "fingerprint", + scopes: [.toolsRead, .toolsWrite], + metadata: MCPPrincipalMetadata( + label: "test", + issuedAt: Date(timeIntervalSince1970: 1_700_000_000), + expiresAt: nil + ) + ) + } + + func authenticate( + authorizationHeader: String?, + clientAddress: MCPClientAddress + ) async -> MCPAuthDecision { + let attempts = attemptsByAddress[clientAddress] ?? 0 + if attempts >= maxAttempts { + return .deny(.rateLimited(retryAfterSeconds: 30)) + } + + guard let raw = authorizationHeader, !raw.isEmpty else { + attemptsByAddress[clientAddress] = attempts + 1 + return .deny(.unauthenticated(reason: "missing")) + } + + let lowered = raw.lowercased() + guard lowered.hasPrefix("bearer ") else { + attemptsByAddress[clientAddress] = attempts + 1 + return .deny(.unauthenticated(reason: "bad scheme")) + } + let token = String(raw.dropFirst("bearer ".count)).trimmingCharacters(in: .whitespaces) + + if token == validToken { + attemptsByAddress[clientAddress] = 0 + return .allow(principal) + } + + attemptsByAddress[clientAddress] = attempts + 1 + return .deny(.tokenInvalid(reason: "bad token")) + } +} + +struct NullProgressSink: MCPProgressSink { + func sendNotification(_ notification: JsonRpcNotification, toSession sessionId: MCPSessionId) async {} +} + +actor StubExchangeConsumer { + private var task: Task? + + func start( + transport: MCPHttpServerTransport, + responder: @escaping @Sendable (MCPInboundExchange) async -> Void + ) async { + let stream = transport.exchanges + task = Task { + for await exchange in stream { + await responder(exchange) + } + } + } + + func stop() { + task?.cancel() + task = nil + } +} diff --git a/TableProTests/Core/MCP/Integration/MCPBridgeIntegrationTests.swift b/TableProTests/Core/MCP/Integration/MCPBridgeIntegrationTests.swift new file mode 100644 index 000000000..37bf1edba --- /dev/null +++ b/TableProTests/Core/MCP/Integration/MCPBridgeIntegrationTests.swift @@ -0,0 +1,731 @@ +import Foundation +import Network +@testable import TablePro +import XCTest + +final class MCPBridgeIntegrationTests: XCTestCase { + fileprivate static let mcpVersion = "2024-11-05" + fileprivate static let bearerToken = "integration-token" + + func testHappyPathInitializeAndToolsListFlowsThroughBridge() async throws { + let harness = try await BridgeHarness.start(authenticator: StubAlwaysAllowAuthenticator()) + defer { harness.shutdown() } + + let consumer = StubExchangeConsumer() + await consumer.start(transport: harness.serverTransport) { exchange in + switch exchange.message { + case .request(let request): + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse( + id: request.id, + result: .object(["echo": .string(request.method)]) + ) + ) + await exchange.responder.respond(response, sessionId: exchange.context.sessionId) + default: + await exchange.responder.respondError(.invalidRequest(detail: "unsupported"), requestId: nil) + } + } + defer { Task { await consumer.stop() } } + + let initRequest = JsonRpcMessage.request( + JsonRpcRequest(id: .number(1), method: "initialize", params: nil) + ) + try await harness.writeFromHost(initRequest) + + let firstResponse = try await harness.readNextResponse() + guard case .successResponse(let success) = firstResponse else { + XCTFail("Expected successResponse for initialize, got \(firstResponse)") + return + } + XCTAssertEqual(success.id, .number(1)) + XCTAssertEqual(success.result["echo"]?.stringValue, "initialize") + + let toolsRequest = JsonRpcMessage.request( + JsonRpcRequest(id: .number(2), method: "tools/list", params: nil) + ) + try await harness.writeFromHost(toolsRequest) + + let secondResponse = try await harness.readNextResponse() + guard case .successResponse(let toolsSuccess) = secondResponse else { + XCTFail("Expected successResponse for tools/list, got \(secondResponse)") + return + } + XCTAssertEqual(toolsSuccess.id, .number(2)) + XCTAssertEqual(toolsSuccess.result["echo"]?.stringValue, "tools/list") + } + + func testIdleSessionEvictionReturnsSessionNotFoundError() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 1_700_000_000)) + let policy = MCPSessionPolicy( + idleTimeout: .seconds(60), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + let harness = try await BridgeHarness.start( + authenticator: StubAlwaysAllowAuthenticator(), + clock: clock, + sessionPolicy: policy + ) + defer { harness.shutdown() } + + let consumer = StubExchangeConsumer() + await consumer.start(transport: harness.serverTransport) { exchange in + switch exchange.message { + case .request(let request): + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: request.id, result: .object(["ok": .bool(true)])) + ) + await exchange.responder.respond(response, sessionId: exchange.context.sessionId) + default: + await exchange.responder.respondError(.invalidRequest(detail: "unsupported"), requestId: nil) + } + } + defer { Task { await consumer.stop() } } + + let initRequest = JsonRpcMessage.request( + JsonRpcRequest(id: .number(10), method: "initialize", params: nil) + ) + try await harness.writeFromHost(initRequest) + + let initResponse = try await harness.readNextResponse() + guard case .successResponse = initResponse else { + XCTFail("Expected initialize success, got \(initResponse)") + return + } + let initialSessionCount = await harness.sessionStore.count() + XCTAssertEqual(initialSessionCount, 1) + + await clock.advance(by: .seconds(120)) + await harness.sessionStore.runCleanupPass() + let postCleanupCount = await harness.sessionStore.count() + XCTAssertEqual(postCleanupCount, 0) + + let followUp = JsonRpcMessage.request( + JsonRpcRequest(id: .number(11), method: "tools/call", params: nil) + ) + try await harness.writeFromHost(followUp) + + let response = try await harness.readNextResponse() + guard case .errorResponse(let envelope) = response else { + XCTFail("Expected errorResponse, got \(response)") + return + } + XCTAssertEqual(envelope.id, .number(11)) + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.sessionNotFound) + } + + func testServerReturning404WithGarbageBodyIsWrappedAsJsonRpcError() async throws { + let badServer = try await BadHttpServer.start { _ in + BadHttpResponse( + status: 404, + headers: [("Content-Type", "application/json")], + body: Data("{\"error\":\"Session not found\"}".utf8) + ) + } + defer { badServer.stop() } + + guard let url = URL(string: "http://127.0.0.1:\(badServer.port)/mcp") else { + XCTFail("Failed to build URL") + return + } + let configuration = MCPStreamableHttpClientConfiguration( + endpoint: url, + bearerToken: Self.bearerToken, + tlsCertFingerprint: nil, + requestTimeout: .seconds(5), + serverInitiatedStream: false + ) + let client = MCPStreamableHttpClientTransport(configuration: configuration, errorLogger: nil) + defer { Task { await client.close() } } + + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(42), method: "tools/list", params: nil) + ) + try await client.send(request) + + let received = try await Self.firstInbound(of: client, timeout: 3.0) + guard case .errorResponse(let envelope) = received else { + XCTFail("Expected errorResponse, got \(received)") + return + } + XCTAssertEqual(envelope.id, .number(42)) + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.sessionNotFound) + + let encoded = try JsonRpcCodec.encode(received) + let roundTripped = try JsonRpcCodec.decode(encoded) + XCTAssertEqual(roundTripped, received) + } + + func testMalformedRequestReturnsValidJsonRpcErrorEnvelope() async throws { + let harness = try await BridgeHarness.start(authenticator: StubAlwaysAllowAuthenticator()) + defer { harness.shutdown() } + + let consumer = StubExchangeConsumer() + await consumer.start(transport: harness.serverTransport) { exchange in + await exchange.responder.respondError(.invalidRequest(detail: "should-not-reach"), requestId: nil) + } + defer { Task { await consumer.stop() } } + + guard let url = URL(string: "http://127.0.0.1:\(harness.serverPort)/mcp") else { + XCTFail("Failed to build URL") + return + } + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + request.setValue(Self.mcpVersion, forHTTPHeaderField: "mcp-protocol-version") + request.setValue("Bearer \(Self.bearerToken)", forHTTPHeaderField: "Authorization") + request.httpBody = Data("{\"not\":\"json-rpc\"}".utf8) + + let (data, response) = try await URLSession.shared.data(for: request) + let httpResponse = try XCTUnwrap(response as? HTTPURLResponse) + + XCTAssertGreaterThanOrEqual(httpResponse.statusCode, 400) + XCTAssertLessThan(httpResponse.statusCode, 500) + XCTAssertFalse(data.isEmpty, "Server must return a body for malformed requests") + + let decoded = try JsonRpcCodec.decode(data) + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected JSON-RPC errorResponse envelope, got \(decoded)") + return + } + XCTAssertTrue( + envelope.error.code == JsonRpcErrorCode.invalidRequest + || envelope.error.code == JsonRpcErrorCode.parseError + || envelope.error.code == JsonRpcErrorCode.methodNotFound, + "Unexpected error code \(envelope.error.code)" + ) + + let plainErrorShape = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + if let asObject = plainErrorShape { + XCTAssertNotNil(asObject["jsonrpc"], "Body must include jsonrpc field; got plain dict \(asObject)") + XCTAssertNotNil(asObject["error"], "Body must include error field") + } + } + + private static func firstInbound( + of transport: MCPStreamableHttpClientTransport, + timeout: TimeInterval + ) async throws -> JsonRpcMessage { + try await withThrowingTaskGroup(of: JsonRpcMessage?.self) { group in + group.addTask { + var iterator = transport.inbound.makeAsyncIterator() + return try await iterator.next() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + return nil + } + guard let result = try await group.next(), let value = result else { + group.cancelAll() + throw IntegrationTestError.timeout + } + group.cancelAll() + return value + } + } +} + +private enum IntegrationTestError: Error { + case timeout + case serverDidNotStart + case readClosed +} + +private struct PipePair { + let hostInput: FileHandle + let bridgeStdin: FileHandle + let bridgeStdout: FileHandle + let hostOutput: FileHandle + + let stdinPipe: Pipe + let stdoutPipe: Pipe + + static func make() -> PipePair { + let stdinPipe = Pipe() + let stdoutPipe = Pipe() + return PipePair( + hostInput: stdinPipe.fileHandleForWriting, + bridgeStdin: stdinPipe.fileHandleForReading, + bridgeStdout: stdoutPipe.fileHandleForWriting, + hostOutput: stdoutPipe.fileHandleForReading, + stdinPipe: stdinPipe, + stdoutPipe: stdoutPipe + ) + } + + func closeAll() { + try? hostInput.close() + try? bridgeStdin.close() + try? bridgeStdout.close() + try? hostOutput.close() + } +} + +private final class IntegrationBridgeLogger: MCPBridgeLogger, @unchecked Sendable { + func log(_ level: MCPBridgeLogLevel, _ message: String) {} +} + +private actor TestBridgeProxy { + private let host: any MCPMessageTransport + private let upstream: any MCPMessageTransport + private let logger: any MCPBridgeLogger + private var task: Task? + + init(host: any MCPMessageTransport, upstream: any MCPMessageTransport, logger: any MCPBridgeLogger) { + self.host = host + self.upstream = upstream + self.logger = logger + } + + func start() { + task = Task { [host, upstream, logger] in + await withTaskGroup(of: Void.self) { group in + group.addTask { + await Self.forward(from: host, to: upstream, direction: "host→upstream", logger: logger) + } + group.addTask { + await Self.forward(from: upstream, to: host, direction: "upstream→host", logger: logger) + } + await group.waitForAll() + } + } + } + + func stop() { + task?.cancel() + task = nil + } + + private static func forward( + from source: any MCPMessageTransport, + to destination: any MCPMessageTransport, + direction: String, + logger: any MCPBridgeLogger + ) async { + do { + for try await message in source.inbound { + do { + try await destination.send(message) + } catch { + logger.log(.warning, "[\(direction)] send failed: \(error.localizedDescription)") + } + } + logger.log(.info, "[\(direction)] inbound stream closed") + } catch { + logger.log(.error, "[\(direction)] inbound failed: \(error.localizedDescription)") + } + await destination.close() + } +} + +private actor LineQueue { + private var pending: [Data] = [] + private var waiters: [CheckedContinuation] = [] + private var finished = false + + func push(_ line: Data) { + if let waiter = waiters.first { + waiters.removeFirst() + waiter.resume(returning: line) + return + } + pending.append(line) + } + + func finish() { + finished = true + let toResume = waiters + waiters.removeAll() + for waiter in toResume { + waiter.resume(returning: nil) + } + } + + func next() async -> Data? { + if !pending.isEmpty { + return pending.removeFirst() + } + if finished { + return nil + } + return await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } +} + +private final class BridgeHarness: @unchecked Sendable { + let serverTransport: MCPHttpServerTransport + let sessionStore: MCPSessionStore + let serverPort: UInt16 + let clientTransport: MCPStreamableHttpClientTransport + let stdioTransport: MCPStdioMessageTransport + private let proxy: TestBridgeProxy + private let pipes: PipePair + private let lineQueue = LineQueue() + private var readerTask: Task? + private let stateLock = NSLock() + + private init( + serverTransport: MCPHttpServerTransport, + sessionStore: MCPSessionStore, + serverPort: UInt16, + clientTransport: MCPStreamableHttpClientTransport, + stdioTransport: MCPStdioMessageTransport, + proxy: TestBridgeProxy, + pipes: PipePair + ) { + self.serverTransport = serverTransport + self.sessionStore = sessionStore + self.serverPort = serverPort + self.clientTransport = clientTransport + self.stdioTransport = stdioTransport + self.proxy = proxy + self.pipes = pipes + } + + static func start( + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock(), + sessionPolicy: MCPSessionPolicy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + ) async throws -> BridgeHarness { + let store = MCPSessionStore(policy: sessionPolicy, clock: clock) + let configuration = MCPHttpServerConfiguration.loopback(port: 0) + let serverTransport = MCPHttpServerTransport( + configuration: configuration, + sessionStore: store, + authenticator: authenticator, + clock: clock + ) + + let stateStream = serverTransport.listenerState + let stateTask = Task { + for await state in stateStream { + if case .running(let port) = state { + return port + } + if case .failed = state { + return nil + } + } + return nil + } + + try await serverTransport.start() + guard let port = await stateTask.value, port != 0 else { + await serverTransport.stop() + throw IntegrationTestError.serverDidNotStart + } + + guard let url = URL(string: "http://127.0.0.1:\(port)/mcp") else { + await serverTransport.stop() + throw IntegrationTestError.serverDidNotStart + } + let logger = IntegrationBridgeLogger() + let clientConfig = MCPStreamableHttpClientConfiguration( + endpoint: url, + bearerToken: MCPBridgeIntegrationTests.bearerToken, + tlsCertFingerprint: nil, + requestTimeout: .seconds(5), + serverInitiatedStream: false + ) + let clientTransport = MCPStreamableHttpClientTransport( + configuration: clientConfig, + errorLogger: logger + ) + + let pipes = PipePair.make() + let stdioTransport = MCPStdioMessageTransport( + stdin: pipes.bridgeStdin, + stdout: pipes.bridgeStdout, + errorLogger: logger + ) + + let proxy = TestBridgeProxy(host: stdioTransport, upstream: clientTransport, logger: logger) + await proxy.start() + + let harness = BridgeHarness( + serverTransport: serverTransport, + sessionStore: store, + serverPort: port, + clientTransport: clientTransport, + stdioTransport: stdioTransport, + proxy: proxy, + pipes: pipes + ) + harness.startReader() + return harness + } + + func writeFromHost(_ message: JsonRpcMessage) async throws { + let line = try JsonRpcCodec.encodeLine(message) + try pipes.hostInput.write(contentsOf: line) + } + + func readNextResponse(timeout: TimeInterval = 4.0) async throws -> JsonRpcMessage { + let line = try await readNextLine(timeout: timeout) + return try JsonRpcCodec.decode(line) + } + + private func readNextLine(timeout: TimeInterval) async throws -> Data { + let queue = lineQueue + return try await withThrowingTaskGroup(of: Data?.self) { group in + group.addTask { + await queue.next() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + return nil + } + guard let first = try await group.next(), let value = first else { + group.cancelAll() + throw IntegrationTestError.timeout + } + group.cancelAll() + return value + } + } + + fileprivate func startReader() { + stateLock.lock() + if readerTask != nil { + stateLock.unlock() + return + } + let handle = pipes.hostOutput + let queue = lineQueue + readerTask = Task.detached(priority: .userInitiated) { + var buffer = Data() + do { + for try await byte in handle.bytes { + if Task.isCancelled { return } + if byte == 0x0A { + var line = buffer + buffer.removeAll(keepingCapacity: true) + if line.last == 0x0D { + line.removeLast() + } + if !line.isEmpty { + await queue.push(line) + } + } else { + buffer.append(byte) + } + } + } catch { + // pipe closed or read error; finish the queue + } + await queue.finish() + } + stateLock.unlock() + } + + func shutdown() { + stateLock.lock() + readerTask?.cancel() + readerTask = nil + stateLock.unlock() + let queue = lineQueue + Task { await queue.finish() } + Task { await proxy.stop() } + Task { await stdioTransport.close() } + Task { await clientTransport.close() } + Task { await serverTransport.stop() } + pipes.closeAll() + } +} + +private struct BadHttpResponse: Sendable { + let status: Int + let headers: [(String, String)] + let body: Data +} + +private actor BadHttpServerState { + var responder: (@Sendable (Data) -> BadHttpResponse)? + + func setResponder(_ responder: @escaping @Sendable (Data) -> BadHttpResponse) { + self.responder = responder + } + + func respond(_ data: Data) -> BadHttpResponse { + responder?(data) ?? BadHttpResponse(status: 500, headers: [], body: Data()) + } +} + +private final class BadHttpServer: @unchecked Sendable { + private let state = BadHttpServerState() + private var listener: NWListener? + private let lock = NSLock() + private var assignedPort: UInt16 = 0 + private var connections: [NWConnection] = [] + + var port: UInt16 { + lock.lock() + defer { lock.unlock() } + return assignedPort + } + + static func start(_ responder: @escaping @Sendable (Data) -> BadHttpResponse) async throws -> BadHttpServer { + let server = BadHttpServer() + await server.state.setResponder(responder) + try await server.startListener() + return server + } + + private func startListener() async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + do { + let params = NWParameters.tcp + params.allowLocalEndpointReuse = true + let listener = try NWListener(using: params) + lock.lock() + self.listener = listener + lock.unlock() + listener.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .ready: + if let port = listener.port?.rawValue { + self.lock.lock() + self.assignedPort = port + self.lock.unlock() + } + continuation.resume() + case .failed(let error): + continuation.resume(throwing: error) + default: + break + } + } + listener.newConnectionHandler = { [weak self] connection in + self?.handle(connection) + } + listener.start(queue: .global(qos: .userInitiated)) + } catch { + continuation.resume(throwing: error) + } + } + } + + func stop() { + lock.lock() + let listener = self.listener + let connections = self.connections + self.listener = nil + self.connections = [] + lock.unlock() + listener?.cancel() + for connection in connections { + connection.cancel() + } + } + + private func handle(_ connection: NWConnection) { + lock.lock() + connections.append(connection) + lock.unlock() + connection.stateUpdateHandler = { [weak self] state in + switch state { + case .ready: + self?.readLoop(connection: connection, accumulated: Data()) + case .failed, .cancelled: + break + default: + break + } + } + connection.start(queue: .global(qos: .userInitiated)) + } + + private func readLoop(connection: NWConnection, accumulated: Data) { + connection.receive(minimumIncompleteLength: 1, maximumLength: 64 * 1_024) { [weak self] data, _, isComplete, _ in + guard let self else { return } + var buffer = accumulated + if let data { + buffer.append(data) + } + + if let bodyStart = Self.findHeaderEnd(buffer) { + let contentLength = Self.contentLength(buffer.prefix(bodyStart)) + let bodyAvailable = buffer.count - bodyStart + if bodyAvailable < contentLength { + if isComplete { + connection.cancel() + return + } + self.readLoop(connection: connection, accumulated: buffer) + return + } + let body = buffer.subdata(in: bodyStart..<(bodyStart + contentLength)) + Task { + let response = await self.state.respond(body) + let raw = Self.serialize(response) + connection.send(content: raw, completion: .contentProcessed { _ in + connection.cancel() + }) + } + return + } + + if isComplete { + connection.cancel() + return + } + self.readLoop(connection: connection, accumulated: buffer) + } + } + + private static func findHeaderEnd(_ data: Data) -> Int? { + guard let range = data.range(of: Data("\r\n\r\n".utf8)) else { return nil } + return range.upperBound + } + + private static func contentLength(_ headerData: Data) -> Int { + guard let headerString = String(data: headerData, encoding: .utf8) else { return 0 } + for line in headerString.components(separatedBy: "\r\n") { + guard let colon = line.firstIndex(of: ":") else { continue } + let key = line[line.startIndex.. Data { + var output = "HTTP/1.1 \(response.status) \(reasonPhrase(for: response.status))\r\n" + var headers = response.headers + if !headers.contains(where: { $0.0.lowercased() == "content-length" }) { + headers.append(("Content-Length", "\(response.body.count)")) + } + if !headers.contains(where: { $0.0.lowercased() == "connection" }) { + headers.append(("Connection", "close")) + } + for (key, value) in headers { + output.append("\(key): \(value)\r\n") + } + output.append("\r\n") + var data = Data(output.utf8) + data.append(response.body) + return data + } + + private static func reasonPhrase(for status: Int) -> String { + switch status { + case 200: return "OK" + case 400: return "Bad Request" + case 401: return "Unauthorized" + case 404: return "Not Found" + case 500: return "Internal Server Error" + default: return "Status" + } + } +} diff --git a/TableProTests/Core/MCP/MCPAuthGuardTests.swift b/TableProTests/Core/MCP/MCPAuthGuardTests.swift deleted file mode 100644 index be9540422..000000000 --- a/TableProTests/Core/MCP/MCPAuthGuardTests.swift +++ /dev/null @@ -1,128 +0,0 @@ -// -// MCPAuthGuardTests.swift -// TableProTests -// - -import Foundation -import Testing - -@testable import TablePro - -@Suite("MCP Auth Guard external access", .serialized) -@MainActor -struct MCPAuthGuardTests { - private let storage = ConnectionStorage.shared - - private func withConnection( - externalAccess: ExternalAccessLevel, - aiPolicy: AIConnectionPolicy = .alwaysAllow, - body: (UUID) async throws -> Void - ) async throws { - let original = storage.loadConnections() - defer { storage.saveConnections(original) } - - let connection = DatabaseConnection( - name: "MCP Test", - type: .mysql, - aiPolicy: aiPolicy, - externalAccess: externalAccess - ) - storage.saveConnections([connection]) - try await body(connection.id) - } - - @Test("Read query passes when externalAccess is readOnly") - func readQueryReadOnly() async throws { - try await withConnection(externalAccess: .readOnly) { connectionId in - let guardian = MCPAuthGuard() - try await guardian.checkExternalWritePermission( - connectionId: connectionId, - sql: "SELECT * FROM users", - databaseType: .mysql - ) - } - } - - @Test("Write query is blocked when externalAccess is readOnly") - func writeQueryBlockedReadOnly() async throws { - try await withConnection(externalAccess: .readOnly) { connectionId in - let guardian = MCPAuthGuard() - do { - try await guardian.checkExternalWritePermission( - connectionId: connectionId, - sql: "UPDATE users SET name='x' WHERE id=1", - databaseType: .mysql - ) - Issue.record("Expected MCPError.forbidden for write on read-only connection") - } catch let error as MCPError { - if case .forbidden = error { - return - } - Issue.record("Expected forbidden, got \(error)") - } - } - } - - @Test("Write query passes when externalAccess is readWrite") - func writeQueryAllowedReadWrite() async throws { - try await withConnection(externalAccess: .readWrite) { connectionId in - let guardian = MCPAuthGuard() - try await guardian.checkExternalWritePermission( - connectionId: connectionId, - sql: "INSERT INTO users (id) VALUES (1)", - databaseType: .mysql - ) - } - } - - @Test("Connection access blocked when externalAccess is blocked") - func connectionAccessBlocked() async throws { - try await withConnection(externalAccess: .blocked) { connectionId in - let guardian = MCPAuthGuard() - do { - try await guardian.checkConnectionAccess( - connectionId: connectionId, - sessionId: "session-1" - ) - Issue.record("Expected MCPError.forbidden for blocked connection") - } catch let error as MCPError { - if case .forbidden = error { - return - } - Issue.record("Expected forbidden, got \(error)") - } - } - } - - @Test("Connection access allowed when externalAccess is readOnly") - func connectionAccessAllowedReadOnly() async throws { - try await withConnection(externalAccess: .readOnly) { connectionId in - let guardian = MCPAuthGuard() - try await guardian.checkConnectionAccess( - connectionId: connectionId, - sessionId: "session-1" - ) - } - } - - @Test("Missing connection rejects external write check") - func missingConnectionRejectsExternalWrite() async { - let guardian = MCPAuthGuard() - let unknownId = UUID() - do { - try await guardian.checkExternalWritePermission( - connectionId: unknownId, - sql: "UPDATE foo SET bar=1", - databaseType: .mysql - ) - Issue.record("Expected MCPError.forbidden for missing connection") - } catch let error as MCPError { - if case .forbidden = error { - return - } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Unexpected error type: \(error)") - } - } -} diff --git a/TableProTests/Core/MCP/MCPPairingServiceTests.swift b/TableProTests/Core/MCP/MCPPairingServiceTests.swift index 725508f9e..1ac1ccb61 100644 --- a/TableProTests/Core/MCP/MCPPairingServiceTests.swift +++ b/TableProTests/Core/MCP/MCPPairingServiceTests.swift @@ -55,7 +55,7 @@ struct MCPPairingServiceTests { _ = try store.consume(code: "code-3", verifier: verifier) - #expect(throws: MCPError.self) { + #expect(throws: MCPDataLayerError.self) { try store.consume(code: "code-3", verifier: verifier) } } @@ -67,7 +67,7 @@ struct MCPPairingServiceTests { do { _ = try store.consume(code: "missing", verifier: "any") Issue.record("Expected notFound error") - } catch let error as MCPError { + } catch let error as MCPDataLayerError { guard case .notFound = error else { Issue.record("Expected notFound, got \(error)") return @@ -87,7 +87,7 @@ struct MCPPairingServiceTests { do { _ = try store.consume(code: "code-4", verifier: verifier, now: Date.now) Issue.record("Expected expired error") - } catch let error as MCPError { + } catch let error as MCPDataLayerError { guard case .expired = error else { Issue.record("Expected expired, got \(error)") return @@ -106,7 +106,7 @@ struct MCPPairingServiceTests { do { _ = try store.consume(code: "code-5", verifier: "attacker-verifier") Issue.record("Expected forbidden error") - } catch let error as MCPError { + } catch let error as MCPDataLayerError { guard case .forbidden = error else { Issue.record("Expected forbidden, got \(error)") return @@ -195,7 +195,7 @@ struct MCPPairingServiceTests { record: record(plaintext: "tp_x", challenge: "challenge", expiresIn: 60) ) Issue.record("Expected forbidden error after exceeding maxPendingCodes") - } catch let error as MCPError { + } catch let error as MCPDataLayerError { guard case .forbidden = error else { Issue.record("Expected forbidden, got \(error)") return diff --git a/TableProTests/Core/MCP/MCPRateLimiterTests.swift b/TableProTests/Core/MCP/MCPRateLimiterTests.swift deleted file mode 100644 index 95830f5c4..000000000 --- a/TableProTests/Core/MCP/MCPRateLimiterTests.swift +++ /dev/null @@ -1,216 +0,0 @@ -// -// MCPRateLimiterTests.swift -// TableProTests -// - -import Foundation -@testable import TablePro -import Testing - -@Suite("MCP Rate Limiter") -struct MCPRateLimiterTests { - private func makeLimiter() -> MCPRateLimiter { - MCPRateLimiter() - } - - private func expectAllowed(_ result: MCPRateLimiter.AuthRateResult, message: String = "") { - guard case .allowed = result else { - Issue.record("Expected .allowed but got \(result). \(message)") - return - } - } - - @discardableResult - private func expectRateLimited(_ result: MCPRateLimiter.AuthRateResult, message: String = "") -> Duration? { - guard case .rateLimited(let retryAfter) = result else { - Issue.record("Expected .rateLimited but got \(result). \(message)") - return nil - } - return retryAfter - } - - @Test("First request is allowed") - func firstRequestAllowed() async { - let limiter = makeLimiter() - let result = await limiter.checkAndRecord(ip: "1.2.3.4", success: false) - expectAllowed(result) - } - - @Test("Success clears failure record") - func successClearsFailureRecord() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "1.2.3.4", success: false) - _ = await limiter.checkAndRecord(ip: "1.2.3.4", success: true) - let result = await limiter.checkAndRecord(ip: "1.2.3.4", success: false) - expectAllowed(result, message: "Counter should have been reset by success") - } - - @Test("Unknown IP is allowed") - func unknownIpAllowed() async { - let limiter = makeLimiter() - let result = await limiter.checkAndRecord(ip: "never-seen-before", success: false) - expectAllowed(result) - } - - @Test("isLockedOut for unknown IP returns allowed") - func isLockedOutUnknownIp() async { - let limiter = makeLimiter() - let result = await limiter.isLockedOut(ip: "unknown") - expectAllowed(result) - } - - @Test("Second failure triggers 1s lockout") - func secondFailureLockout() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.0.1", success: false) - let result = await limiter.checkAndRecord(ip: "10.0.0.1", success: false) - - guard let retryAfter = expectRateLimited(result, message: "Second failure should lock out") else { return } - let seconds = retryAfter.components.seconds - #expect(seconds >= 0 && seconds <= 2) - } - - @Test("Third failure triggers 5s lockout after previous lockout expires") - func thirdFailureLockout() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.0.2", success: false) - _ = await limiter.checkAndRecord(ip: "10.0.0.2", success: false) - - try? await Task.sleep(for: .seconds(1.1)) - - let result = await limiter.checkAndRecord(ip: "10.0.0.2", success: false) - guard let retryAfter = expectRateLimited(result, message: "Third failure should lock out for ~5s") else { return } - let seconds = retryAfter.components.seconds - #expect(seconds >= 4 && seconds <= 6) - } - - @Test("Fourth failure triggers 30s lockout") - func fourthFailureLockout() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.0.3", success: false) - _ = await limiter.checkAndRecord(ip: "10.0.0.3", success: false) - - try? await Task.sleep(for: .seconds(1.1)) - _ = await limiter.checkAndRecord(ip: "10.0.0.3", success: false) - - try? await Task.sleep(for: .seconds(5.1)) - let result = await limiter.checkAndRecord(ip: "10.0.0.3", success: false) - - guard let retryAfter = expectRateLimited(result, message: "Fourth failure should lock out for ~30s") else { return } - let seconds = retryAfter.components.seconds - #expect(seconds >= 28 && seconds <= 32) - } - - @Test("Repeated failures while locked return remaining lockout time") - func repeatedFailuresWhileLocked() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.0.4", success: false) - let lockResult = await limiter.checkAndRecord(ip: "10.0.0.4", success: false) - - guard let initialRetry = expectRateLimited(lockResult) else { return } - - let retryResult = await limiter.checkAndRecord(ip: "10.0.0.4", success: false) - guard let remainingRetry = expectRateLimited(retryResult, message: "Should still be locked") else { return } - - #expect(remainingRetry <= initialRetry) - } - - @Test("isLockedOut returns rateLimited during lockout") - func isLockedOutDuringLockout() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.1.1", success: false) - _ = await limiter.checkAndRecord(ip: "10.0.1.1", success: false) - - let result = await limiter.isLockedOut(ip: "10.0.1.1") - expectRateLimited(result, message: "Should be locked out after 2 failures") - } - - @Test("isLockedOut returns allowed when not locked") - func isLockedOutWhenNotLocked() async { - let limiter = makeLimiter() - let result = await limiter.isLockedOut(ip: "fresh-ip") - expectAllowed(result) - } - - @Test("Different IPs have independent counters") - func independentCounters() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "ip-a", success: false) - _ = await limiter.checkAndRecord(ip: "ip-a", success: false) - - let lockedResult = await limiter.isLockedOut(ip: "ip-a") - expectRateLimited(lockedResult, message: "IP-A should be locked") - - let resultB = await limiter.checkAndRecord(ip: "ip-b", success: false) - expectAllowed(resultB, message: "IP-B should be independent of IP-A") - } - - @Test("Locking one IP does not affect another") - func lockingIsolation() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "ip-a", success: false) - _ = await limiter.checkAndRecord(ip: "ip-a", success: false) - - let lockedResult = await limiter.isLockedOut(ip: "ip-a") - expectRateLimited(lockedResult, message: "IP-A should be locked") - - let resultB = await limiter.checkAndRecord(ip: "ip-b", success: false) - expectAllowed(resultB, message: "IP-B should not be affected by IP-A lockout") - } - - @Test("Success after failure resets counter") - func successResetsCounter() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.2.1", success: false) - _ = await limiter.checkAndRecord(ip: "10.0.2.1", success: true) - - let firstFail = await limiter.checkAndRecord(ip: "10.0.2.1", success: false) - expectAllowed(firstFail, message: "Counter should reset after success, so first failure again is allowed") - - let secondFail = await limiter.checkAndRecord(ip: "10.0.2.1", success: false) - expectRateLimited(secondFail, message: "Second failure after reset should lock out again") - } - - @Test("Empty IP string works") - func emptyIpString() async { - let limiter = makeLimiter() - let result = await limiter.checkAndRecord(ip: "", success: false) - expectAllowed(result, message: "First failure for empty IP should be allowed") - } - - @Test("Success on first call returns allowed without prior record") - func successOnFirstCall() async { - let limiter = makeLimiter() - let result = await limiter.checkAndRecord(ip: "10.0.3.1", success: true) - expectAllowed(result) - } - - @Test("Rapid sequential failures while locked do not escalate") - func rapidSequentialFailuresWhileLocked() async { - let limiter = makeLimiter() - let ip = "10.0.3.2" - - let result1 = await limiter.checkAndRecord(ip: ip, success: false) - expectAllowed(result1, message: "Failure 1 should be allowed") - - let result2 = await limiter.checkAndRecord(ip: ip, success: false) - guard let retry2 = expectRateLimited(result2, message: "Failure 2 should trigger lockout") else { return } - #expect(retry2.components.seconds >= 0 && retry2.components.seconds <= 2) - - let result3 = await limiter.checkAndRecord(ip: ip, success: false) - guard let retry3 = expectRateLimited(result3, message: "Failure 3 while locked returns remaining time") else { return } - #expect(retry3 <= retry2) - - let result4 = await limiter.checkAndRecord(ip: ip, success: false) - guard let retry4 = expectRateLimited(result4, message: "Failure 4 while locked returns remaining time") else { return } - #expect(retry4 <= retry3) - } - - @Test("isLockedOut returns allowed after single failure with no lockout") - func isLockedOutAfterSingleFailure() async { - let limiter = makeLimiter() - _ = await limiter.checkAndRecord(ip: "10.0.4.1", success: false) - let result = await limiter.isLockedOut(ip: "10.0.4.1") - expectAllowed(result, message: "Single failure sets no lockout") - } -} diff --git a/TableProTests/Core/MCP/MCPRouterTests.swift b/TableProTests/Core/MCP/MCPRouterTests.swift deleted file mode 100644 index f2c392eae..000000000 --- a/TableProTests/Core/MCP/MCPRouterTests.swift +++ /dev/null @@ -1,167 +0,0 @@ -// -// MCPRouterTests.swift -// TableProTests -// - -import Foundation -@testable import TablePro -import Testing - -@Suite("MCP Router") -struct MCPRouterTests { - private final class StubHandler: MCPRouteHandler, @unchecked Sendable { - let methods: [HTTPRequest.Method] - let path: String - private let result: MCPRouter.RouteResult - private(set) var invocationCount: Int = 0 - private(set) var lastRequest: HTTPRequest? - - init(methods: [HTTPRequest.Method], path: String, result: MCPRouter.RouteResult = .accepted) { - self.methods = methods - self.path = path - self.result = result - } - - func handle(_ request: HTTPRequest) async -> MCPRouter.RouteResult { - invocationCount += 1 - lastRequest = request - return result - } - } - - private func makeRequest( - method: HTTPRequest.Method, - path: String, - body: Data? = nil - ) -> HTTPRequest { - HTTPRequest(method: method, path: path, headers: [:], body: body, remoteIP: nil) - } - - @Test("OPTIONS preflight returns noContent regardless of path") - func optionsPreflightAlwaysNoContent() async { - let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) - let router = MCPRouter(routes: [mcpHandler]) - - let optionsAtMcp = makeRequest(method: .options, path: "/mcp") - let result1 = await router.handle(optionsAtMcp) - guard case .noContent = result1 else { - Issue.record("Expected .noContent for OPTIONS /mcp, got \(result1)") - return - } - - let optionsAtUnknown = makeRequest(method: .options, path: "/unknown/path") - let result2 = await router.handle(optionsAtUnknown) - guard case .noContent = result2 else { - Issue.record("Expected .noContent for OPTIONS /unknown, got \(result2)") - return - } - - #expect(mcpHandler.invocationCount == 0) - } - - @Test("POST /mcp dispatches to MCP protocol handler") - func postMcpDispatchesToProtocolHandler() async { - let mcpHandler = StubHandler(methods: [.get, .post, .delete], path: "/mcp", result: .accepted) - let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) - let router = MCPRouter(routes: [mcpHandler, exchangeHandler]) - - let request = makeRequest(method: .post, path: "/mcp") - _ = await router.handle(request) - - #expect(mcpHandler.invocationCount == 1) - #expect(exchangeHandler.invocationCount == 0) - } - - @Test("POST /v1/integrations/exchange dispatches to exchange handler") - func postExchangeDispatchesToExchangeHandler() async { - let mcpHandler = StubHandler(methods: [.get, .post, .delete], path: "/mcp", result: .accepted) - let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) - let router = MCPRouter(routes: [mcpHandler, exchangeHandler]) - - let request = makeRequest(method: .post, path: "/v1/integrations/exchange") - _ = await router.handle(request) - - #expect(exchangeHandler.invocationCount == 1) - #expect(mcpHandler.invocationCount == 0) - } - - @Test("Path with query string still matches canonical route") - func queryStringMatchesCanonicalPath() async { - let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) - let router = MCPRouter(routes: [mcpHandler]) - - let request = makeRequest(method: .post, path: "/mcp?session=abc") - _ = await router.handle(request) - - #expect(mcpHandler.invocationCount == 1) - } - - @Test("Unknown path returns 404 httpError") - func unknownPathReturnsNotFound() async { - let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) - let router = MCPRouter(routes: [mcpHandler]) - - let request = makeRequest(method: .post, path: "/totally/unknown") - let result = await router.handle(request) - - guard case .httpError(let status, _) = result else { - Issue.record("Expected .httpError, got \(result)") - return - } - #expect(status == 404) - #expect(mcpHandler.invocationCount == 0) - } - - @Test("Method mismatch on registered path returns 404") - func methodMismatchReturnsNotFound() async { - let exchangeHandler = StubHandler(methods: [.post], path: "/v1/integrations/exchange", result: .accepted) - let router = MCPRouter(routes: [exchangeHandler]) - - let request = makeRequest(method: .get, path: "/v1/integrations/exchange") - let result = await router.handle(request) - - guard case .httpError(let status, _) = result else { - Issue.record("Expected .httpError, got \(result)") - return - } - #expect(status == 404) - #expect(exchangeHandler.invocationCount == 0) - } - - @Test(".well-known requests return 404 immediately") - func wellKnownReturnsNotFound() async { - let mcpHandler = StubHandler(methods: [.get], path: "/.well-known/oauth", result: .accepted) - let router = MCPRouter(routes: [mcpHandler]) - - let request = makeRequest(method: .get, path: "/.well-known/oauth") - let result = await router.handle(request) - - guard case .httpError(let status, _) = result else { - Issue.record("Expected .httpError, got \(result)") - return - } - #expect(status == 404) - #expect(mcpHandler.invocationCount == 0) - } - - @Test("Handler receives the original request") - func handlerReceivesOriginalRequest() async { - let mcpHandler = StubHandler(methods: [.post], path: "/mcp", result: .accepted) - let router = MCPRouter(routes: [mcpHandler]) - - let body = Data("{\"hello\":\"world\"}".utf8) - let request = HTTPRequest( - method: .post, - path: "/mcp", - headers: ["content-type": "application/json"], - body: body, - remoteIP: "10.0.0.1" - ) - _ = await router.handle(request) - - #expect(mcpHandler.lastRequest?.path == "/mcp") - #expect(mcpHandler.lastRequest?.method == .post) - #expect(mcpHandler.lastRequest?.body == body) - #expect(mcpHandler.lastRequest?.remoteIP == "10.0.0.1") - } -} diff --git a/TableProTests/Core/MCP/MCPTokenStoreTests.swift b/TableProTests/Core/MCP/MCPTokenStoreTests.swift index 757a9791c..e3cd12b3b 100644 --- a/TableProTests/Core/MCP/MCPTokenStoreTests.swift +++ b/TableProTests/Core/MCP/MCPTokenStoreTests.swift @@ -20,7 +20,7 @@ struct MCPTokenStoreTests { tokenHash: "fakehash", salt: "fakesalt", permissions: .readOnly, - allowedConnectionIds: nil, + connectionAccess: .all, createdAt: Date.now, lastUsedAt: nil, expiresAt: expiresAt, @@ -183,23 +183,31 @@ struct MCPTokenStoreTests { #expect(result.token.expiresAt != nil) } - @Test("generate with nil connectionIds stores nil") - func generateWithNilConnectionIds() async { + @Test("generate with .all access stores .all") + func generateWithAllAccess() async { let store = makeStore() - let result = await store.generate(name: "test", permissions: .readOnly, allowedConnectionIds: nil) + let result = await store.generate(name: "test", permissions: .readOnly, connectionAccess: .all) await store.delete(tokenId: result.token.id) - #expect(result.token.allowedConnectionIds == nil) + #expect(result.token.connectionAccess == .all) } - @Test("generate with specific connectionIds stores them") - func generateWithSpecificConnectionIds() async { + @Test("generate with .limited stores the connection ids") + func generateWithLimitedAccess() async { let ids: Set = [UUID(), UUID()] let store = makeStore() - let result = await store.generate(name: "test", permissions: .readOnly, allowedConnectionIds: ids) + let result = await store.generate( + name: "test", + permissions: .readOnly, + connectionAccess: .limited(ids) + ) await store.delete(tokenId: result.token.id) - #expect(result.token.allowedConnectionIds == ids) + if case .limited(let stored) = result.token.connectionAccess { + #expect(stored == ids) + } else { + Issue.record("Expected .limited connection access") + } } @Test("validate returns token for valid bearer") @@ -281,6 +289,26 @@ struct MCPTokenStoreTests { #expect(revokedToken.isActive == false) } + @Test("revoke fires registered revocation observers with token id") + func revokeNotifiesObservers() async { + let store = makeStore() + let result = await store.generate(name: "observed", permissions: .readOnly) + + let receivedBox = Lock(value: [String]()) + let observed = receivedBox + await store.addRevocationObserver { tokenIdString in + await observed.append(tokenIdString) + } + + await store.revoke(tokenId: result.token.id) + try? await Task.sleep(for: .milliseconds(50)) + await store.delete(tokenId: result.token.id) + try? await Task.sleep(for: .milliseconds(50)) + + let received = await receivedBox.snapshot() + #expect(received.contains(result.token.id.uuidString)) + } + @Test("delete removes token from list") func deleteRemovesTokenFromList() async { let store = makeStore() @@ -363,3 +391,19 @@ struct MCPTokenStoreTests { #expect(result1.plaintext != result2.plaintext) } } + +private actor Lock: Sendable { + private var value: Value + + init(value: Value) { + self.value = value + } + + func append(_ element: T) where Value == [T] { + value.append(element) + } + + func snapshot() -> Value { + value + } +} diff --git a/TableProTests/Core/MCP/MCPToolHandlerExportTests.swift b/TableProTests/Core/MCP/MCPToolHandlerExportTests.swift deleted file mode 100644 index b1be3ae00..000000000 --- a/TableProTests/Core/MCP/MCPToolHandlerExportTests.swift +++ /dev/null @@ -1,183 +0,0 @@ -// -// MCPToolHandlerExportTests.swift -// TableProTests -// - -import Foundation -import Testing - -@testable import TablePro - -@Suite("MCP Tool Handler — export_data validation", .serialized) -@MainActor -struct MCPToolHandlerExportTests { - private let storage = ConnectionStorage.shared - - private func makeHandler() -> MCPToolHandler { - MCPToolHandler(bridge: MCPConnectionBridge(), authGuard: MCPAuthGuard()) - } - - private func withConnections( - _ connections: [DatabaseConnection], - body: () async throws -> Void - ) async throws { - let original = storage.loadConnections() - defer { storage.saveConnections(original) } - storage.saveConnections(connections) - try await body() - } - - @Test("export_data rejects table name with SQL injection payload") - func exportDataRejectsInjectionInTableName() async throws { - let handler = makeHandler() - let connection = DatabaseConnection( - name: "Target", - type: .mysql, - aiPolicy: .alwaysAllow, - externalAccess: .readWrite - ) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "export_data", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "format": .string("csv"), - "tables": .array([.string("users; DROP TABLE users;--")]) - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams for malicious table name") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("export_data rejects table name with quote payload") - func exportDataRejectsQuotePayload() async throws { - let handler = makeHandler() - let connection = DatabaseConnection( - name: "Target", - type: .mysql, - aiPolicy: .alwaysAllow, - externalAccess: .readWrite - ) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "export_data", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "format": .string("csv"), - "tables": .array([.string("users`; DROP TABLE x;--")]) - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams for backtick injection") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("validateExportTableName accepts simple identifiers") - func validateExportTableNameAcceptsSimple() throws { - try MCPToolHandler.validateExportTableName("users") - try MCPToolHandler.validateExportTableName("users_v2") - try MCPToolHandler.validateExportTableName("public.users") - try MCPToolHandler.validateExportTableName("schema.table_name_42") - } - - @Test("validateExportTableName rejects spaces") - func validateExportTableNameRejectsSpaces() { - do { - try MCPToolHandler.validateExportTableName("users x") - Issue.record("Expected throw for table name with space") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName rejects semicolon") - func validateExportTableNameRejectsSemicolon() { - do { - try MCPToolHandler.validateExportTableName("users;DROP TABLE x") - Issue.record("Expected throw for table name with semicolon") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName rejects empty string") - func validateExportTableNameRejectsEmpty() { - do { - try MCPToolHandler.validateExportTableName("") - Issue.record("Expected throw for empty table name") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName rejects leading dot") - func validateExportTableNameRejectsLeadingDot() { - do { - try MCPToolHandler.validateExportTableName(".users") - Issue.record("Expected throw for table name with leading dot") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("export_data rejects output_path outside Downloads") - func exportDataRejectsPathOutsideDownloads() async throws { - let handler = makeHandler() - let connection = DatabaseConnection( - name: "Target", - type: .mysql, - aiPolicy: .alwaysAllow, - externalAccess: .readWrite - ) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "export_data", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "format": .string("csv"), - "query": .string("SELECT 1"), - "output_path": .string("/tmp/escape.csv") - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams for path outside Downloads") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } -} diff --git a/TableProTests/Core/MCP/MCPToolHandlerIntegrationTests.swift b/TableProTests/Core/MCP/MCPToolHandlerIntegrationTests.swift deleted file mode 100644 index 73eb404be..000000000 --- a/TableProTests/Core/MCP/MCPToolHandlerIntegrationTests.swift +++ /dev/null @@ -1,701 +0,0 @@ -// -// MCPToolHandlerIntegrationTests.swift -// TableProTests -// - -import Foundation -import Testing - -@testable import TablePro - -@Suite("MCP Tool Handler — integration tools", .serialized) -@MainActor -struct MCPToolHandlerIntegrationTests { - private let storage = ConnectionStorage.shared - - private func makeHandler() -> MCPToolHandler { - MCPToolHandler(bridge: MCPConnectionBridge(), authGuard: MCPAuthGuard()) - } - - private func makeToken( - permissions: TokenPermissions = .readWrite, - allowedConnectionIds: Set? = nil - ) -> MCPAuthToken { - MCPAuthToken( - id: UUID(), - name: "test-token", - prefix: "tp_test1", - tokenHash: "fakehash", - salt: "fakesalt", - permissions: permissions, - allowedConnectionIds: allowedConnectionIds, - createdAt: Date.now, - lastUsedAt: nil, - expiresAt: nil, - isActive: true - ) - } - - private func withConnections( - _ connections: [DatabaseConnection], - body: () async throws -> Void - ) async throws { - let original = storage.loadConnections() - defer { storage.saveConnections(original) } - storage.saveConnections(connections) - try await body() - } - - @Test("list_connections omits connections with externalAccess == .blocked") - func listConnectionsFiltersBlocked() async throws { - let handler = makeHandler() - let blocked = DatabaseConnection(name: "Blocked Prod", type: .mysql, externalAccess: .blocked) - let visible = DatabaseConnection(name: "Visible Staging", type: .mysql, externalAccess: .readOnly) - try await withConnections([blocked, visible]) { - let result = try await handler.handleToolCall( - name: "list_connections", - arguments: nil, - sessionId: "test-session", - token: nil - ) - #expect(result.isError == nil) - let payload = result.content.first?.text ?? "" - #expect(!payload.contains(blocked.id.uuidString)) - #expect(payload.contains(visible.id.uuidString)) - } - } - - @Test("list_recent_tabs returns tabs JSON object") - func listRecentTabsShape() async throws { - let handler = makeHandler() - let result = try await handler.handleToolCall( - name: "list_recent_tabs", - arguments: .object(["limit": .int(5)]), - sessionId: "test-session", - token: nil - ) - #expect(result.isError == nil) - #expect(result.content.first?.type == "text") - let payload = result.content.first?.text ?? "" - #expect(payload.contains("\"tabs\"")) - } - - @Test("blockedExternalConnectionIds returns ids of connections with externalAccess == .blocked") - func blockedExternalConnectionIdsHelper() async throws { - let blocked = DatabaseConnection(name: "Blocked", type: .mysql, externalAccess: .blocked) - let readOnly = DatabaseConnection(name: "ReadOnly", type: .mysql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - let readWrite = DatabaseConnection(name: "ReadWrite", type: .mysql, externalAccess: .readWrite) - try await withConnections([blocked, readOnly, readWrite]) { - let ids = MCPToolHandler.blockedExternalConnectionIds() - #expect(ids.contains(blocked.id)) - #expect(!ids.contains(readOnly.id)) - #expect(!ids.contains(readWrite.id)) - } - } - - @Test("list_recent_tabs requires read scope only") - func listRecentTabsScope() async throws { - let handler = makeHandler() - let token = makeToken(permissions: .readOnly) - let result = try await handler.handleToolCall( - name: "list_recent_tabs", - arguments: nil, - sessionId: "test-session", - token: token - ) - #expect(result.isError == nil) - } - - @Test("search_query_history rejects missing query parameter") - func searchQueryHistoryRequiresQuery() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "search_query_history", - arguments: nil, - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams when query is missing") - } catch let error as MCPError { - if case .invalidParams = error { - return - } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("search_query_history rejects invalid connection_id UUID") - func searchQueryHistoryRejectsInvalidUUID() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object([ - "query": .string("SELECT"), - "connection_id": .string("not-a-uuid") - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams for malformed UUID") - } catch let error as MCPError { - if case .invalidParams = error { - return - } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("search_query_history with empty query returns entries object") - func searchQueryHistoryEmptyQuery() async throws { - let handler = makeHandler() - let result = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object(["query": .string(""), "limit": .int(1)]), - sessionId: "test-session", - token: nil - ) - #expect(result.isError == nil) - let payload = result.content.first?.text ?? "" - #expect(payload.contains("\"entries\"")) - } - - @Test("search_query_history rejects since greater than until") - func searchQueryHistoryRejectsInvertedWindow() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object([ - "query": .string(""), - "since": .double(2_000), - "until": .double(1_000) - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams when since > until") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("search_query_history rejects connection_id whose externalAccess is .blocked") - func searchQueryHistoryRejectsBlockedConnection() async throws { - let handler = makeHandler() - let blocked = DatabaseConnection(name: "Blocked Prod", type: .mysql, externalAccess: .blocked) - try await withConnections([blocked]) { - do { - _ = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object([ - "query": .string(""), - "connection_id": .string(blocked.id.uuidString) - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for blocked connection") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("search_query_history filters out blocked connections when iterating without connection_id") - func searchQueryHistoryFiltersBlockedFromUnscopedQuery() async throws { - let handler = makeHandler() - let blocked = DatabaseConnection(name: "Blocked", type: .mysql, externalAccess: .blocked) - let visible = DatabaseConnection(name: "Visible", type: .mysql, externalAccess: .readOnly) - let marker = UUID().uuidString - - try await withConnections([blocked, visible]) { - let blockedEntry = QueryHistoryEntry( - query: "SELECT blocked_\(marker)", - connectionId: blocked.id, - databaseName: "db", - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - let visibleEntry = QueryHistoryEntry( - query: "SELECT visible_\(marker)", - connectionId: visible.id, - databaseName: "db", - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(blockedEntry) - _ = await QueryHistoryStorage.shared.addHistory(visibleEntry) - - let result = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object(["query": .string(marker)]), - sessionId: "test-session", - token: nil - ) - #expect(result.isError == nil) - let payload = result.content.first?.text ?? "" - #expect(payload.contains("visible_\(marker)")) - #expect(!payload.contains("blocked_\(marker)")) - } - } - - @Test("search_query_history pushes token allowlist into SQL so older allowed entries surface") - func searchQueryHistoryAllowlistOverFlood() async throws { - let handler = makeHandler() - let allowedConn = DatabaseConnection(name: "Allowed", type: .mysql) - let otherConn = DatabaseConnection(name: "Other", type: .mysql) - let marker = UUID().uuidString - let now = Date() - - try await withConnections([allowedConn, otherConn]) { - let oldAllowed = QueryHistoryEntry( - query: "SELECT old_allowed_\(marker)", - connectionId: allowedConn.id, - databaseName: "db", - executedAt: now.addingTimeInterval(-3_600), - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(oldAllowed) - - for index in 0..<20 { - let recentOther = QueryHistoryEntry( - query: "SELECT recent_other_\(marker)_\(index)", - connectionId: otherConn.id, - databaseName: "db", - executedAt: now.addingTimeInterval(Double(index)), - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(recentOther) - } - - let token = makeToken(allowedConnectionIds: [allowedConn.id]) - let result = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object(["query": .string(marker), "limit": .int(5)]), - sessionId: "test-session", - token: token - ) - #expect(result.isError == nil) - let payload = result.content.first?.text ?? "" - #expect(payload.contains("old_allowed_\(marker)")) - #expect(!payload.contains("recent_other_\(marker)")) - } - } - - @Test("QueryHistoryStorage.fetchHistory restricts results to allowedConnectionIds") - func fetchHistoryAllowlistFilters() async throws { - let allowedId = UUID() - let otherId = UUID() - let marker = UUID().uuidString - - let allowedEntry = QueryHistoryEntry( - query: "SELECT allowed_\(marker)", - connectionId: allowedId, - databaseName: "db", - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - let otherEntry = QueryHistoryEntry( - query: "SELECT other_\(marker)", - connectionId: otherId, - databaseName: "db", - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(allowedEntry) - _ = await QueryHistoryStorage.shared.addHistory(otherEntry) - - let entries = await QueryHistoryStorage.shared.fetchHistory( - limit: 100, - searchText: marker, - allowedConnectionIds: [allowedId] - ) - - #expect(entries.contains { $0.query.contains("allowed_\(marker)") }) - #expect(!entries.contains { $0.query.contains("other_\(marker)") }) - } - - @Test("QueryHistoryStorage.fetchHistory returns empty when allowedConnectionIds is empty") - func fetchHistoryEmptyAllowlistReturnsEmpty() async throws { - let connectionId = UUID() - let marker = UUID().uuidString - let entry = QueryHistoryEntry( - query: "SELECT empty_allowlist_\(marker)", - connectionId: connectionId, - databaseName: "db", - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(entry) - - let entries = await QueryHistoryStorage.shared.fetchHistory( - limit: 100, - searchText: marker, - allowedConnectionIds: [] - ) - - #expect(entries.isEmpty) - } - - @Test("search_query_history with since/until filters by executed_at window") - func searchQueryHistorySinceUntilFilters() async throws { - let handler = makeHandler() - let connId = UUID() - let now = Date() - let oneHourAgo = now.addingTimeInterval(-3_600) - let twoHoursAgo = now.addingTimeInterval(-7_200) - let marker = UUID().uuidString - - let outside = QueryHistoryEntry( - query: "SELECT outside_\(marker)", - connectionId: connId, - databaseName: "testdb", - executedAt: twoHoursAgo, - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - let inside = QueryHistoryEntry( - query: "SELECT inside_\(marker)", - connectionId: connId, - databaseName: "testdb", - executedAt: oneHourAgo, - executionTime: 0.01, - rowCount: 1, - wasSuccessful: true - ) - _ = await QueryHistoryStorage.shared.addHistory(outside) - _ = await QueryHistoryStorage.shared.addHistory(inside) - - let result = try await handler.handleToolCall( - name: "search_query_history", - arguments: .object([ - "query": .string(marker), - "connection_id": .string(connId.uuidString), - "since": .double(now.addingTimeInterval(-5_400).timeIntervalSince1970), - "until": .double(now.timeIntervalSince1970) - ]), - sessionId: "test-session", - token: nil - ) - #expect(result.isError == nil) - let payload = result.content.first?.text ?? "" - #expect(payload.contains("inside_\(marker)")) - #expect(!payload.contains("outside_\(marker)")) - } - - @Test("switch_database against a readOnly connection returns forbidden") - func switchDatabaseDeniedByReadOnlyExternalAccess() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "ReadOnly", type: .mysql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "switch_database", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "database": .string("postgres") - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for readOnly externalAccess") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("switch_schema against a readOnly connection returns forbidden") - func switchSchemaDeniedByReadOnlyExternalAccess() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "ReadOnly", type: .postgresql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "switch_schema", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "schema": .string("public") - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for readOnly externalAccess") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("export_data against a readOnly connection returns forbidden") - func exportDataDeniedByReadOnlyExternalAccess() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "ReadOnly", type: .mysql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "export_data", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "format": .string("csv"), - "tables": .array([.string("users")]) - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for readOnly externalAccess") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("open_connection_window against a readOnly connection returns forbidden") - func openConnectionWindowDeniedByReadOnlyExternalAccess() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "ReadOnly", type: .mysql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "open_connection_window", - arguments: .object(["connection_id": .string(connection.id.uuidString)]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for readOnly externalAccess") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("open_table_tab against a readOnly connection returns forbidden") - func openTableTabDeniedByReadOnlyExternalAccess() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "ReadOnly", type: .mysql, aiPolicy: .alwaysAllow, externalAccess: .readOnly) - try await withConnections([connection]) { - do { - _ = try await handler.handleToolCall( - name: "open_table_tab", - arguments: .object([ - "connection_id": .string(connection.id.uuidString), - "table_name": .string("users") - ]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.forbidden for readOnly externalAccess") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("ExternalAccessLevel.satisfies follows blocked < readOnly < readWrite ordering") - func externalAccessLevelSatisfiesOrdering() { - #expect(ExternalAccessLevel.readWrite.satisfies(.readWrite)) - #expect(ExternalAccessLevel.readWrite.satisfies(.readOnly)) - #expect(ExternalAccessLevel.readOnly.satisfies(.readOnly)) - #expect(!ExternalAccessLevel.readOnly.satisfies(.readWrite)) - #expect(!ExternalAccessLevel.blocked.satisfies(.readOnly)) - #expect(!ExternalAccessLevel.blocked.satisfies(.readWrite)) - } - - @Test("open_connection_window rejects missing connection_id") - func openConnectionWindowRequiresConnectionId() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "open_connection_window", - arguments: nil, - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("open_connection_window rejects unknown connection") - func openConnectionWindowRejectsUnknown() async throws { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "open_connection_window", - arguments: .object(["connection_id": .string(UUID().uuidString)]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.notFound for unknown connection") - } catch let error as MCPError { - if case .notFound = error { return } - Issue.record("Expected notFound, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("open_connection_window denies read-only token") - func openConnectionWindowReadOnlyDenied() async throws { - let handler = makeHandler() - let token = makeToken(permissions: .readOnly) - do { - _ = try await handler.handleToolCall( - name: "open_connection_window", - arguments: .object(["connection_id": .string(UUID().uuidString)]), - sessionId: "test-session", - token: token - ) - Issue.record("Expected MCPError.forbidden for read-only token") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("open_connection_window respects token connection allowlist") - func openConnectionWindowAllowlist() async throws { - let handler = makeHandler() - let connection = DatabaseConnection(name: "Test", type: .mysql) - try await withConnections([connection]) { - let token = makeToken( - permissions: .readWrite, - allowedConnectionIds: [UUID()] - ) - do { - _ = try await handler.handleToolCall( - name: "open_connection_window", - arguments: .object(["connection_id": .string(connection.id.uuidString)]), - sessionId: "test-session", - token: token - ) - Issue.record("Expected MCPError.forbidden for disallowed connection") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - } - - @Test("open_table_tab requires table_name") - func openTableTabRequiresTableName() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "open_table_tab", - arguments: .object(["connection_id": .string(UUID().uuidString)]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.invalidParams") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("focus_query_tab returns notFound when tab is not open") - func focusQueryTabNotFound() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "focus_query_tab", - arguments: .object(["tab_id": .string(UUID().uuidString)]), - sessionId: "test-session", - token: nil - ) - Issue.record("Expected MCPError.notFound") - } catch let error as MCPError { - if case .notFound = error { return } - Issue.record("Expected notFound, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("focus_query_tab requires read-write token") - func focusQueryTabRequiresWriteScope() async { - let handler = makeHandler() - let token = makeToken(permissions: .readOnly) - do { - _ = try await handler.handleToolCall( - name: "focus_query_tab", - arguments: .object(["tab_id": .string(UUID().uuidString)]), - sessionId: "test-session", - token: token - ) - Issue.record("Expected MCPError.forbidden for read-only token") - } catch let error as MCPError { - if case .forbidden = error { return } - Issue.record("Expected forbidden, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("Unknown tool name throws methodNotFound") - func unknownToolThrows() async { - let handler = makeHandler() - do { - _ = try await handler.handleToolCall( - name: "totally_made_up_tool", - arguments: nil, - sessionId: "test-session", - token: nil - ) - Issue.record("Expected methodNotFound") - } catch let error as MCPError { - if case .methodNotFound = error { return } - Issue.record("Expected methodNotFound, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } -} diff --git a/TableProTests/Core/MCP/MCPToolHandlerSecurityTests.swift b/TableProTests/Core/MCP/MCPToolHandlerSecurityTests.swift deleted file mode 100644 index 36c595ec5..000000000 --- a/TableProTests/Core/MCP/MCPToolHandlerSecurityTests.swift +++ /dev/null @@ -1,87 +0,0 @@ -import Foundation -import Testing - -@testable import TablePro - -@Suite("MCP Tool Handler — identifier validation hardening") -struct MCPToolHandlerSecurityTests { - @Test("validateExportTableName rejects double-dot") - func rejectsDoubleDot() { - do { - try MCPToolHandler.validateExportTableName("schema..table") - Issue.record("Expected throw for double-dot table name") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName rejects trailing dot") - func rejectsTrailingDot() { - do { - try MCPToolHandler.validateExportTableName("schema.") - Issue.record("Expected throw for trailing-dot table name") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName rejects only dots") - func rejectsOnlyDots() { - do { - try MCPToolHandler.validateExportTableName("..") - Issue.record("Expected throw for dots-only table name") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("validateExportTableName accepts schema-qualified identifiers") - func acceptsValidQualified() throws { - try MCPToolHandler.validateExportTableName("public.users") - try MCPToolHandler.validateExportTableName("db.schema.table") - } - - @Test("quoteQualifiedIdentifier throws on empty component") - func quoteThrowsOnEmptyComponent() { - let quoter: (String) -> String = { "\"\($0)\"" } - do { - _ = try MCPToolHandler.quoteQualifiedIdentifier("schema..table", quoter: quoter) - Issue.record("Expected throw for empty component in qualified identifier") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("quoteQualifiedIdentifier throws on leading dot") - func quoteThrowsOnLeadingDot() { - let quoter: (String) -> String = { "\"\($0)\"" } - do { - _ = try MCPToolHandler.quoteQualifiedIdentifier(".table", quoter: quoter) - Issue.record("Expected throw for leading-dot identifier") - } catch let error as MCPError { - if case .invalidParams = error { return } - Issue.record("Expected invalidParams, got \(error)") - } catch { - Issue.record("Expected MCPError, got \(error)") - } - } - - @Test("quoteQualifiedIdentifier quotes each segment for valid identifiers") - func quoteQuotesValidSegments() throws { - let quoter: (String) -> String = { "\"\($0)\"" } - let result = try MCPToolHandler.quoteQualifiedIdentifier("public.users", quoter: quoter) - #expect(result == "\"public\".\"users\"") - } -} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/InitializeHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/InitializeHandlerTests.swift new file mode 100644 index 000000000..eb946ae0f --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/InitializeHandlerTests.swift @@ -0,0 +1,214 @@ +import Foundation +@testable import TablePro +import XCTest + +final class InitializeHandlerTests: XCTestCase { + func testHandlerMethodIsInitialize() { + XCTAssertEqual(InitializeHandler.method, "initialize") + } + + func testHandlerRequiresNoScopes() { + XCTAssertTrue(InitializeHandler.requiredScopes.isEmpty) + } + + func testHandlerOnlyAllowsUninitializedState() { + XCTAssertEqual(InitializeHandler.allowedSessionStates, [.uninitialized]) + } + + func testHappyPathReturnsServerInfoAndCapabilities() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string("2025-11-25"), + "clientInfo": .object([ + "name": .string("test-client"), + "version": .string("1.2.3") + ]), + "capabilities": .object([:]) + ]) + + let response = try await handler.handle(params: params, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response, got \(response)") + return + } + + guard case .object(let result) = success.result else { + XCTFail("Expected object result") + return + } + + XCTAssertEqual(result["protocolVersion"]?.stringValue, "2025-11-25") + + guard let serverInfo = result["serverInfo"], case .object(let serverInfoDict) = serverInfo else { + XCTFail("Expected serverInfo object") + return + } + XCTAssertEqual(serverInfoDict["name"]?.stringValue, "tablepro") + XCTAssertNotNil(serverInfoDict["version"]?.stringValue) + + guard let capabilities = result["capabilities"], case .object(let capDict) = capabilities else { + XCTFail("Expected capabilities object") + return + } + XCTAssertNotNil(capDict["tools"]) + XCTAssertNotNil(capDict["resources"]) + XCTAssertNotNil(capDict["prompts"]) + XCTAssertNotNil(capDict["logging"]) + XCTAssertNotNil(capDict["completions"]) + } + + func testEchoesBackEachSupportedProtocolVersion() async throws { + for version in ["2025-03-26", "2025-06-18", "2025-11-25"] { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string(version), + "clientInfo": .object(["name": .string("client")]) + ]) + + let response = try await handler.handle(params: params, context: context) + guard case .successResponse(let success) = response, + case .object(let result) = success.result else { + XCTFail("Expected success object for version \(version)") + return + } + XCTAssertEqual(result["protocolVersion"]?.stringValue, version) + + let negotiated = await context.session.negotiatedProtocolVersion + XCTAssertEqual(negotiated, version) + } + } + + func testRecordsClientInfoOnSession() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string("2025-06-18"), + "clientInfo": .object([ + "name": .string("acme-cli"), + "version": .string("9.9.9") + ]), + "capabilities": .object(["x": .bool(true)]) + ]) + + _ = try await handler.handle(params: params, context: context) + + let info = await context.session.clientInfo + XCTAssertEqual(info?.name, "acme-cli") + XCTAssertEqual(info?.version, "9.9.9") + + let negotiated = await context.session.negotiatedProtocolVersion + XCTAssertEqual(negotiated, "2025-06-18") + + let recordedCapabilities = await context.session.clientCapabilities + XCTAssertEqual(recordedCapabilities, .object(["x": .bool(true)])) + } + + func testMissingClientInfoFallsBackToUnknown() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + + _ = try await handler.handle(params: nil, context: context) + + let info = await context.session.clientInfo + XCTAssertEqual(info?.name, "unknown") + XCTAssertNil(info?.version) + } + + func testRejectsRepeatedInitializeOnSameSession() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string("2025-11-25"), + "clientInfo": .object(["name": .string("first")]) + ]) + + _ = try await handler.handle(params: params, context: context) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected handler to throw on second initialize") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidRequest) + } + } + + func testUnknownProtocolVersionDowngradesToLatest() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string("1999-01-01"), + "clientInfo": .object(["name": .string("vintage")]) + ]) + + let response = try await handler.handle(params: params, context: context) + guard case .successResponse(let success) = response, + case .object(let result) = success.result else { + XCTFail("Expected success object") + return + } + XCTAssertEqual(result["protocolVersion"]?.stringValue, InitializeHandler.supportedProtocolVersion) + XCTAssertEqual(InitializeHandler.supportedProtocolVersion, "2025-11-25") + + let negotiated = await context.session.negotiatedProtocolVersion + XCTAssertEqual(negotiated, "2025-11-25") + } + + func testNewerUnknownProtocolVersionDowngradesToLatest() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + let params: JsonValue = .object([ + "protocolVersion": .string("2099-01-01"), + "clientInfo": .object(["name": .string("future")]) + ]) + + let response = try await handler.handle(params: params, context: context) + guard case .successResponse(let success) = response, + case .object(let result) = success.result else { + XCTFail("Expected success object") + return + } + XCTAssertEqual(result["protocolVersion"]?.stringValue, "2025-11-25") + } + + func testMissingProtocolVersionFallsBackToSupported() async throws { + let context = try await makeContext() + let handler = InitializeHandler() + + _ = try await handler.handle(params: .object([:]), context: context) + + let negotiated = await context.session.negotiatedProtocolVersion + XCTAssertEqual(negotiated, InitializeHandler.supportedProtocolVersion) + } + + private func makeContext() async throws -> MCPRequestContext { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + let progressSink = StubProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler()], + sessionStore: store, + progressSink: progressSink + ) + let request = MCPProtocolTestSupport.makeRequest(method: "initialize") + let (exchange, _) = MCPProtocolTestSupport.makeExchange(message: request, sessionId: sessionId) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + return MCPRequestContext( + exchange: exchange, + session: session, + principal: MCPProtocolTestSupport.makePrincipal(), + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: MCPSystemClock() + ) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/LoggingSetLevelHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/LoggingSetLevelHandlerTests.swift new file mode 100644 index 000000000..3cdfbf61a --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/LoggingSetLevelHandlerTests.swift @@ -0,0 +1,100 @@ +import Foundation +@testable import TablePro +import XCTest + +final class LoggingSetLevelHandlerTests: XCTestCase { + func testMethodIsLoggingSetLevel() { + XCTAssertEqual(LoggingSetLevelHandler.method, "logging/setLevel") + } + + func testRequiresNoScopes() { + XCTAssertTrue(LoggingSetLevelHandler.requiredScopes.isEmpty) + } + + func testAcceptsKnownLevels() async throws { + for level in ["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["level": .string(level)]) + let response = try await handler.handle(params: params, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response for level \(level)") + return + } + XCTAssertEqual(success.result, .object([:])) + } + } + + func testAcceptsUppercaseLevels() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["level": .string("WARNING")]) + let response = try await handler.handle(params: params, context: context) + + guard case .successResponse = response else { + XCTFail("Expected success response for uppercase level") + return + } + } + + func testRejectsUnknownLevel() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["level": .string("verbose")]) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + func testRejectsMissingLevel() async throws { + let (handler, context) = try await makeContext() + + do { + _ = try await handler.handle(params: .object([:]), context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + private func makeContext( + clock: any MCPClock = MCPSystemClock() + ) async throws -> (LoggingSetLevelHandler, MCPRequestContext) { + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + try await session.transitionToReady() + let progressSink = StubProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [LoggingSetLevelHandler()], + sessionStore: store, + progressSink: progressSink, + clock: clock + ) + let request = MCPProtocolTestSupport.makeRequest(method: "logging/setLevel") + let principal = MCPProtocolTestSupport.makePrincipal(scopes: []) + let sessionId = await session.id + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId, + principal: principal + ) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: clock + ) + return (LoggingSetLevelHandler(), context) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/PingHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/PingHandlerTests.swift new file mode 100644 index 000000000..62bf1bbaf --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/PingHandlerTests.swift @@ -0,0 +1,76 @@ +import Foundation +@testable import TablePro +import XCTest + +final class PingHandlerTests: XCTestCase { + func testHandlerMethodIsPing() { + XCTAssertEqual(PingHandler.method, "ping") + } + + func testHandlerRequiresNoScopes() { + XCTAssertTrue(PingHandler.requiredScopes.isEmpty) + } + + func testHandlerAllowsReadyAndUninitializedStates() { + XCTAssertTrue(PingHandler.allowedSessionStates.contains(.ready)) + } + + func testReturnsEmptyResult() async throws { + let (handler, context, _) = try await makeContext() + + let response = try await handler.handle(params: nil, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response, got \(response)") + return + } + XCTAssertEqual(success.result, .object([:])) + } + + func testTouchesSessionLastActivity() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 1_700_000_000)) + let (handler, context, session) = try await makeContext(clock: clock) + + let initialActivity = await session.lastActivityAt + await clock.advance(by: .seconds(120)) + + _ = try await handler.handle(params: nil, context: context) + + let after = await session.lastActivityAt + XCTAssertGreaterThan(after, initialActivity) + XCTAssertEqual(after, Date(timeIntervalSince1970: 1_700_000_000 + 120)) + } + + private func makeContext( + clock: any MCPClock = MCPSystemClock() + ) async throws -> (PingHandler, MCPRequestContext, MCPSession) { + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + let sessionId = await session.id + let progressSink = StubProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [PingHandler()], + sessionStore: store, + progressSink: progressSink, + clock: clock + ) + let request = MCPProtocolTestSupport.makeRequest(method: "ping") + let (exchange, _) = MCPProtocolTestSupport.makeExchange(message: request, sessionId: sessionId) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: MCPProtocolTestSupport.makePrincipal(), + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: clock + ) + return (PingHandler(), context, session) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/PromptsListHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/PromptsListHandlerTests.swift new file mode 100644 index 000000000..915d247f7 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/PromptsListHandlerTests.swift @@ -0,0 +1,68 @@ +import Foundation +@testable import TablePro +import XCTest + +final class PromptsListHandlerTests: XCTestCase { + func testMethodIsPromptsList() { + XCTAssertEqual(PromptsListHandler.method, "prompts/list") + } + + func testRequiresNoScopes() { + XCTAssertTrue(PromptsListHandler.requiredScopes.isEmpty) + } + + func testAllowedInReadyState() { + XCTAssertEqual(PromptsListHandler.allowedSessionStates, [.ready]) + } + + func testReturnsEmptyList() async throws { + let (handler, context) = try await makeContext() + let response = try await handler.handle(params: nil, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response, got \(response)") + return + } + + XCTAssertEqual(success.result, .object(["prompts": .array([])])) + } + + private func makeContext( + clock: any MCPClock = MCPSystemClock() + ) async throws -> (PromptsListHandler, MCPRequestContext) { + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + try await session.transitionToReady() + let progressSink = StubProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [PromptsListHandler()], + sessionStore: store, + progressSink: progressSink, + clock: clock + ) + let request = MCPProtocolTestSupport.makeRequest(method: "prompts/list") + let principal = MCPProtocolTestSupport.makePrincipal(scopes: []) + let sessionId = await session.id + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId, + principal: principal + ) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: clock + ) + return (PromptsListHandler(), context) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/ResourcesListHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/ResourcesListHandlerTests.swift new file mode 100644 index 000000000..01364de1c --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/ResourcesListHandlerTests.swift @@ -0,0 +1,91 @@ +import Foundation +@testable import TablePro +import XCTest + +final class ResourcesListHandlerTests: XCTestCase { + func testMethodIsResourcesList() { + XCTAssertEqual(ResourcesListHandler.method, "resources/list") + } + + func testRequiresResourcesReadScope() { + XCTAssertEqual(ResourcesListHandler.requiredScopes, [.resourcesRead]) + } + + func testAllowedInReadyState() { + XCTAssertEqual(ResourcesListHandler.allowedSessionStates, [.ready]) + } + + func testReturnsConnectionsResource() async throws { + let (handler, context) = try await makeContext() + let response = try await handler.handle(params: nil, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response, got \(response)") + return + } + + let resources = success.result["resources"]?.arrayValue + XCTAssertNotNil(resources) + let uris = resources?.compactMap { $0["uri"]?.stringValue } ?? [] + XCTAssertTrue(uris.contains("tablepro://connections")) + } + + func testEntriesIncludeNameAndMimeType() async throws { + let (handler, context) = try await makeContext() + let response = try await handler.handle(params: nil, context: context) + + guard case .successResponse(let success) = response, + let resources = success.result["resources"]?.arrayValue, + let connections = resources.first(where: { $0["uri"]?.stringValue == "tablepro://connections" }) + else { + XCTFail("Expected connections resource") + return + } + + XCTAssertNotNil(connections["name"]?.stringValue) + XCTAssertEqual(connections["mimeType"]?.stringValue, "application/json") + } + + private func makeContext( + clock: any MCPClock = MCPSystemClock() + ) async throws -> (ResourcesListHandler, MCPRequestContext) { + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + try await session.transitionToReady() + let progressSink = StubProgressSink() + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + let dispatcher = MCPProtocolDispatcher( + handlers: [ResourcesListHandler(services: services)], + sessionStore: store, + progressSink: progressSink, + clock: clock + ) + let request = MCPProtocolTestSupport.makeRequest(method: "resources/list") + let principal = MCPProtocolTestSupport.makePrincipal(scopes: [.resourcesRead]) + let sessionId = await session.id + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId, + principal: principal + ) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: clock + ) + return (ResourcesListHandler(services: services), context) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/ResourcesReadHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/ResourcesReadHandlerTests.swift new file mode 100644 index 000000000..9b789812a --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/ResourcesReadHandlerTests.swift @@ -0,0 +1,133 @@ +import Foundation +@testable import TablePro +import XCTest + +final class ResourcesReadHandlerTests: XCTestCase { + func testMethodIsResourcesRead() { + XCTAssertEqual(ResourcesReadHandler.method, "resources/read") + } + + func testRequiresResourcesReadScope() { + XCTAssertEqual(ResourcesReadHandler.requiredScopes, [.resourcesRead]) + } + + func testReadsConnectionsList() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["uri": .string("tablepro://connections")]) + + let response = try await handler.handle(params: params, context: context) + + guard case .successResponse(let success) = response else { + XCTFail("Expected success response, got \(response)") + return + } + + let contents = success.result["contents"]?.arrayValue + XCTAssertEqual(contents?.count, 1) + let entry = contents?.first + XCTAssertEqual(entry?["uri"]?.stringValue, "tablepro://connections") + XCTAssertEqual(entry?["mimeType"]?.stringValue, "application/json") + XCTAssertNotNil(entry?["text"]?.stringValue) + } + + func testMissingUriThrowsInvalidParams() async throws { + let (handler, context) = try await makeContext() + do { + _ = try await handler.handle(params: .object([:]), context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + func testInvalidUriThrowsInvalidParams() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["uri": .string("not a url at all spaces")]) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + func testNonTableproSchemeRejected() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["uri": .string("https://example.com/foo")]) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + func testUnknownPathReturnsMethodNotFound() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["uri": .string("tablepro://unknown/resource")]) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.methodNotFound) + } + } + + func testInvalidUuidInSchemaPathRejected() async throws { + let (handler, context) = try await makeContext() + let params: JsonValue = .object(["uri": .string("tablepro://connections/not-a-uuid/schema")]) + + do { + _ = try await handler.handle(params: params, context: context) + XCTFail("Expected MCPProtocolError") + } catch let error as MCPProtocolError { + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + } + } + + private func makeContext( + clock: any MCPClock = MCPSystemClock() + ) async throws -> (ResourcesReadHandler, MCPRequestContext) { + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + try await session.transitionToReady() + let progressSink = StubProgressSink() + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + let dispatcher = MCPProtocolDispatcher( + handlers: [ResourcesReadHandler(services: services)], + sessionStore: store, + progressSink: progressSink, + clock: clock + ) + let request = MCPProtocolTestSupport.makeRequest(method: "resources/read") + let principal = MCPProtocolTestSupport.makePrincipal(scopes: [.resourcesRead]) + let sessionId = await session.id + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId, + principal: principal + ) + let token = MCPCancellationToken() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: progressSink, + sessionId: sessionId + ) + let context = MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: emitter, + cancellation: token, + clock: clock + ) + return (ResourcesReadHandler(services: services), context) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/ToolsCallHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/ToolsCallHandlerTests.swift new file mode 100644 index 000000000..602dc87df --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/ToolsCallHandlerTests.swift @@ -0,0 +1,141 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ToolsCallHandler") +struct ToolsCallHandlerTests { + @Test("Unknown tool returns method not found") + func unknownTool() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object([ + "name": .string("nonexistent_tool"), + "arguments": .object([:]) + ]) + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + @Test("Missing tool name returns invalid params") + func missingToolName() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object(["arguments": .object([:])]) + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + @Test("Non-object params return invalid params") + func nonObjectParams() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .string("oops") + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + @Test("Insufficient scope returns forbidden") + func insufficientScope() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext( + method: "tools/call", + principalScopes: [] + ) + let params: JsonValue = .object([ + "name": .string("list_connections"), + "arguments": .object([:]) + ]) + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + @Test("list_connections returns content array") + func listConnectionsHappyPath() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object([ + "name": .string("list_connections"), + "arguments": .object([:]) + ]) + + let response = try await handler.handle(params: params, context: context) + guard case .successResponse(let success) = response else { + Issue.record("expected success, got \(response)") + return + } + let content = success.result["content"]?.arrayValue + #expect(content != nil) + #expect(content?.first?["type"]?.stringValue == "text") + } + + @Test("list_connections includes structuredContent for 2025-11-25 clients") + func listConnectionsExposesStructuredContent() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object([ + "name": .string("list_connections"), + "arguments": .object([:]) + ]) + + let response = try await handler.handle(params: params, context: context) + guard case .successResponse(let success) = response else { + Issue.record("expected success, got \(response)") + return + } + let structured = success.result["structuredContent"] + #expect(structured != nil) + if case .object = structured { + // ok + } else { + Issue.record("expected structuredContent to be an object") + } + } + + @Test("get_table_ddl with missing connection_id returns invalid params") + func getTableDdlMissingId() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object([ + "name": .string("get_table_ddl"), + "arguments": .object([ + "table": .string("users") + ]) + ]) + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + @Test("list_tables with malformed connection_id returns invalid params") + func listTablesMalformedId() async throws { + let handler = makeHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let params: JsonValue = .object([ + "name": .string("list_tables"), + "arguments": .object([ + "connection_id": .string("not-a-uuid") + ]) + ]) + + await #expect(throws: MCPProtocolError.self) { + _ = try await handler.handle(params: params, context: context) + } + } + + private func makeHandler() -> ToolsCallHandler { + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + return ToolsCallHandler(services: services) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Handlers/ToolsListHandlerTests.swift b/TableProTests/Core/MCP/Protocol/Handlers/ToolsListHandlerTests.swift new file mode 100644 index 000000000..6e3ea4418 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Handlers/ToolsListHandlerTests.swift @@ -0,0 +1,133 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ToolsListHandler") +struct ToolsListHandlerTests { + @Test("Lists all 19 tools from the registry") + func listsAllRegisteredTools() async throws { + let response = try await runToolsList() + let names = response["tools"]?.arrayValue?.compactMap { $0["name"]?.stringValue } ?? [] + + let expected: Set = [ + "list_connections", + "get_connection_status", + "list_databases", + "list_schemas", + "list_tables", + "describe_table", + "get_table_ddl", + "list_recent_tabs", + "search_query_history", + "focus_query_tab", + "connect", + "disconnect", + "switch_database", + "switch_schema", + "execute_query", + "export_data", + "confirm_destructive_operation", + "open_table_tab", + "open_connection_window" + ] + + #expect(Set(names) == expected) + #expect(names.count == 19) + } + + @Test("Each tool has name, description, and inputSchema") + func eachToolHasShapeFields() async throws { + let response = try await runToolsList() + let tools = response["tools"]?.arrayValue ?? [] + + for tool in tools { + let name = tool["name"]?.stringValue + let description = tool["description"]?.stringValue + let schema = tool["inputSchema"] + #expect(name != nil) + #expect(description?.isEmpty == false) + #expect(schema != nil) + } + } + + @Test("Each input schema is a JSON Schema object") + func inputSchemasAreObjects() async throws { + let response = try await runToolsList() + let tools = response["tools"]?.arrayValue ?? [] + + for tool in tools { + guard case .object(let schema) = tool["inputSchema"] else { + Issue.record("inputSchema not an object for tool \(tool["name"]?.stringValue ?? "?")") + continue + } + #expect(schema["type"]?.stringValue == "object") + #expect(schema["properties"] != nil) + #expect(schema["required"] != nil) + } + } + + @Test("Each tool exposes annotations with hints") + func toolsExposeAnnotations() async throws { + let response = try await runToolsList() + let tools = response["tools"]?.arrayValue ?? [] + + for tool in tools { + guard let name = tool["name"]?.stringValue else { + Issue.record("missing tool name") + continue + } + guard case .object(let annotations) = tool["annotations"] else { + Issue.record("missing annotations for tool \(name)") + continue + } + #expect(annotations["title"]?.stringValue?.isEmpty == false) + #expect(annotations["readOnlyHint"]?.boolValue != nil) + #expect(annotations["destructiveHint"]?.boolValue != nil) + #expect(annotations["idempotentHint"]?.boolValue != nil) + #expect(annotations["openWorldHint"]?.boolValue != nil) + } + } + + @Test("Read tools advertise readOnlyHint=true") + func readToolsAreReadOnly() async throws { + let response = try await runToolsList() + let tools = response["tools"]?.arrayValue ?? [] + + let readOnlyExpected: Set = [ + "list_connections", + "get_connection_status", + "list_databases", + "list_schemas", + "list_tables", + "describe_table", + "get_table_ddl", + "list_recent_tabs", + "search_query_history" + ] + for tool in tools { + guard let name = tool["name"]?.stringValue, readOnlyExpected.contains(name) else { continue } + #expect(tool["annotations"]?["readOnlyHint"]?.boolValue == true) + } + } + + @Test("confirm_destructive_operation advertises destructiveHint=true") + func destructiveToolFlagged() async throws { + let response = try await runToolsList() + let tools = response["tools"]?.arrayValue ?? [] + let target = tools.first { $0["name"]?.stringValue == "confirm_destructive_operation" } + #expect(target != nil) + #expect(target?["annotations"]?["destructiveHint"]?.boolValue == true) + } + + private func runToolsList() async throws -> JsonValue { + let handler = ToolsListHandler() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/list") + let message = try await handler.handle(params: nil, context: context) + + guard case .successResponse(let response) = message else { + Issue.record("expected success response, got \(message)") + return .null + } + return response.result + } +} diff --git a/TableProTests/Core/MCP/Protocol/MCPArgumentDecoderTests.swift b/TableProTests/Core/MCP/Protocol/MCPArgumentDecoderTests.swift new file mode 100644 index 000000000..65cb1e6db --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/MCPArgumentDecoderTests.swift @@ -0,0 +1,167 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP Argument Decoder") +struct MCPArgumentDecoderTests { + @Test("requireString returns string when present") + func requireStringPresent() throws { + let args: JsonValue = .object(["name": .string("hello")]) + let value = try MCPArgumentDecoder.requireString(args, key: "name") + #expect(value == "hello") + } + + @Test("requireString throws when missing") + func requireStringMissing() { + let args: JsonValue = .object([:]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.requireString(args, key: "name") + } + } + + @Test("requireString throws when wrong type") + func requireStringWrongType() { + let args: JsonValue = .object(["name": .int(5)]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.requireString(args, key: "name") + } + } + + @Test("optionalString returns nil when missing") + func optionalStringMissing() { + let args: JsonValue = .object([:]) + let value = MCPArgumentDecoder.optionalString(args, key: "name") + #expect(value == nil) + } + + @Test("optionalString returns value when present") + func optionalStringPresent() { + let args: JsonValue = .object(["name": .string("foo")]) + let value = MCPArgumentDecoder.optionalString(args, key: "name") + #expect(value == "foo") + } + + @Test("requireUuid parses a valid UUID string") + func requireUuidValid() throws { + let id = UUID() + let args: JsonValue = .object(["connection_id": .string(id.uuidString)]) + let value = try MCPArgumentDecoder.requireUuid(args, key: "connection_id") + #expect(value == id) + } + + @Test("requireUuid throws on malformed string") + func requireUuidInvalid() { + let args: JsonValue = .object(["connection_id": .string("not-a-uuid")]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.requireUuid(args, key: "connection_id") + } + } + + @Test("requireUuid throws when missing") + func requireUuidMissing() { + let args: JsonValue = .object([:]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.requireUuid(args, key: "connection_id") + } + } + + @Test("optionalUuid returns nil when missing") + func optionalUuidMissing() throws { + let args: JsonValue = .object([:]) + let value = try MCPArgumentDecoder.optionalUuid(args, key: "connection_id") + #expect(value == nil) + } + + @Test("optionalUuid throws on invalid value") + func optionalUuidInvalid() { + let args: JsonValue = .object(["connection_id": .string("bad")]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.optionalUuid(args, key: "connection_id") + } + } + + @Test("requireInt returns value") + func requireIntPresent() throws { + let args: JsonValue = .object(["count": .int(7)]) + let value = try MCPArgumentDecoder.requireInt(args, key: "count") + #expect(value == 7) + } + + @Test("requireInt throws when missing") + func requireIntMissing() { + let args: JsonValue = .object([:]) + #expect(throws: MCPProtocolError.self) { + _ = try MCPArgumentDecoder.requireInt(args, key: "count") + } + } + + @Test("optionalInt returns default when missing") + func optionalIntMissing() { + let args: JsonValue = .object([:]) + let value = MCPArgumentDecoder.optionalInt(args, key: "count", default: 42) + #expect(value == 42) + } + + @Test("optionalInt clamps within range") + func optionalIntClamps() { + let args: JsonValue = .object(["count": .int(1_000)]) + let value = MCPArgumentDecoder.optionalInt(args, key: "count", default: nil, clamp: 1...100) + #expect(value == 100) + } + + @Test("optionalInt clamps lower bound") + func optionalIntClampLower() { + let args: JsonValue = .object(["count": .int(-5)]) + let value = MCPArgumentDecoder.optionalInt(args, key: "count", default: nil, clamp: 1...100) + #expect(value == 1) + } + + @Test("optionalInt returns default when missing without clamp") + func optionalIntDefault() { + let args: JsonValue = .object([:]) + let value = MCPArgumentDecoder.optionalInt(args, key: "count", default: 5) + #expect(value == 5) + } + + @Test("optionalBool returns default when missing") + func optionalBoolDefault() { + let args: JsonValue = .object([:]) + #expect(MCPArgumentDecoder.optionalBool(args, key: "flag", default: true)) + #expect(!MCPArgumentDecoder.optionalBool(args, key: "flag", default: false)) + } + + @Test("optionalBool returns value when present") + func optionalBoolPresent() { + let args: JsonValue = .object(["flag": .bool(true)]) + #expect(MCPArgumentDecoder.optionalBool(args, key: "flag", default: false)) + } + + @Test("optionalDouble returns int as double") + func optionalDoubleFromInt() { + let args: JsonValue = .object(["value": .int(3)]) + #expect(MCPArgumentDecoder.optionalDouble(args, key: "value") == 3.0) + } + + @Test("optionalStringArray returns nil when missing") + func optionalStringArrayMissing() { + let args: JsonValue = .object([:]) + let value = MCPArgumentDecoder.optionalStringArray(args, key: "tables") + #expect(value == nil) + } + + @Test("optionalStringArray returns nil when empty") + func optionalStringArrayEmpty() { + let args: JsonValue = .object(["tables": .array([])]) + let value = MCPArgumentDecoder.optionalStringArray(args, key: "tables") + #expect(value == nil) + } + + @Test("optionalStringArray collects strings") + func optionalStringArrayCollects() { + let args: JsonValue = .object([ + "tables": .array([.string("a"), .string("b"), .int(3)]) + ]) + let value = MCPArgumentDecoder.optionalStringArray(args, key: "tables") + #expect(value == ["a", "b"]) + } +} diff --git a/TableProTests/Core/MCP/Protocol/MCPCancellationTokenTests.swift b/TableProTests/Core/MCP/Protocol/MCPCancellationTokenTests.swift new file mode 100644 index 000000000..0470062cd --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/MCPCancellationTokenTests.swift @@ -0,0 +1,117 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPCancellationTokenTests: XCTestCase { + func testNewTokenIsNotCancelled() async { + let token = MCPCancellationToken() + let cancelled = await token.isCancelled() + XCTAssertFalse(cancelled) + } + + func testIsCancelledAfterCancel() async { + let token = MCPCancellationToken() + await token.cancel() + let cancelled = await token.isCancelled() + XCTAssertTrue(cancelled) + } + + func testOnCancelHandlerRunsWhenCancelFires() async { + let token = MCPCancellationToken() + let flag = ObservedFlag() + await token.onCancel { + await flag.set() + } + + let beforeCancel = await flag.value() + XCTAssertFalse(beforeCancel) + + await token.cancel() + + let afterCancel = await flag.value() + XCTAssertTrue(afterCancel) + } + + func testOnCancelRegisteredAfterCancelRunsImmediately() async { + let token = MCPCancellationToken() + await token.cancel() + + let flag = ObservedFlag() + await token.onCancel { + await flag.set() + } + + let value = await flag.value() + XCTAssertTrue(value) + } + + func testMultipleOnCancelHandlersAllInvoked() async { + let token = MCPCancellationToken() + let flagA = ObservedFlag() + let flagB = ObservedFlag() + let flagC = ObservedFlag() + + await token.onCancel { await flagA.set() } + await token.onCancel { await flagB.set() } + await token.onCancel { await flagC.set() } + + await token.cancel() + + let valueA = await flagA.value() + let valueB = await flagB.value() + let valueC = await flagC.value() + XCTAssertTrue(valueA) + XCTAssertTrue(valueB) + XCTAssertTrue(valueC) + } + + func testCancelTwiceIsIdempotent() async { + let token = MCPCancellationToken() + let counter = HandlerInvocationCounter() + await token.onCancel { + await counter.increment() + } + + await token.cancel() + await token.cancel() + + let count = await counter.value() + XCTAssertEqual(count, 1) + + let cancelled = await token.isCancelled() + XCTAssertTrue(cancelled) + } + + func testThrowIfCancelledThrowsAfterCancel() async { + let token = MCPCancellationToken() + await token.cancel() + do { + try await token.throwIfCancelled() + XCTFail("Expected CancellationError to be thrown") + } catch is CancellationError { + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testThrowIfCancelledDoesNotThrowWhenNotCancelled() async { + let token = MCPCancellationToken() + do { + try await token.throwIfCancelled() + } catch { + XCTFail("Unexpected error: \(error)") + } + } +} + +private actor HandlerInvocationCounter { + private var invocations: Int = 0 + + func increment() { + invocations += 1 + } + + func value() -> Int { + invocations + } +} diff --git a/TableProTests/Core/MCP/Protocol/MCPInflightRegistryTests.swift b/TableProTests/Core/MCP/Protocol/MCPInflightRegistryTests.swift new file mode 100644 index 000000000..1373b13e5 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/MCPInflightRegistryTests.swift @@ -0,0 +1,154 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPInflightRegistryTests: XCTestCase { + func testCancelByRequestIdAndSessionIdCancelsToken() async { + let registry = MCPInflightRegistry() + let token = MCPCancellationToken() + let sessionId = MCPSessionId("session-1") + let requestId = JsonRpcId.number(42) + + await registry.register(requestId: requestId, sessionId: sessionId, token: token) + await registry.cancel(requestId: requestId, sessionId: sessionId) + + let cancelled = await token.isCancelled() + XCTAssertTrue(cancelled) + } + + func testRegisterSameKeyTwiceLatestWins() async { + let registry = MCPInflightRegistry() + let firstToken = MCPCancellationToken() + let secondToken = MCPCancellationToken() + let sessionId = MCPSessionId("session-2") + let requestId = JsonRpcId.string("req-x") + + await registry.register(requestId: requestId, sessionId: sessionId, token: firstToken) + await registry.register(requestId: requestId, sessionId: sessionId, token: secondToken) + + await registry.cancel(requestId: requestId, sessionId: sessionId) + + let firstCancelled = await firstToken.isCancelled() + let secondCancelled = await secondToken.isCancelled() + + XCTAssertFalse(firstCancelled) + XCTAssertTrue(secondCancelled) + } + + func testCancelNonexistentEntryIsNoop() async { + let registry = MCPInflightRegistry() + let sessionId = MCPSessionId("session-3") + let requestId = JsonRpcId.number(99) + + await registry.cancel(requestId: requestId, sessionId: sessionId) + let count = await registry.count() + XCTAssertEqual(count, 0) + } + + func testRemoveDropsEntryAndSubsequentCancelIsNoop() async { + let registry = MCPInflightRegistry() + let token = MCPCancellationToken() + let sessionId = MCPSessionId("session-4") + let requestId = JsonRpcId.number(7) + + await registry.register(requestId: requestId, sessionId: sessionId, token: token) + await registry.remove(requestId: requestId, sessionId: sessionId) + + let countAfterRemove = await registry.count() + XCTAssertEqual(countAfterRemove, 0) + + await registry.cancel(requestId: requestId, sessionId: sessionId) + let cancelled = await token.isCancelled() + XCTAssertFalse(cancelled) + } + + func testEntriesAreScopedBySessionId() async { + let registry = MCPInflightRegistry() + let tokenA = MCPCancellationToken() + let tokenB = MCPCancellationToken() + let sessionA = MCPSessionId("session-A") + let sessionB = MCPSessionId("session-B") + let requestId = JsonRpcId.number(1) + + await registry.register(requestId: requestId, sessionId: sessionA, token: tokenA) + await registry.register(requestId: requestId, sessionId: sessionB, token: tokenB) + + await registry.cancel(requestId: requestId, sessionId: sessionA) + + let cancelledA = await tokenA.isCancelled() + let cancelledB = await tokenB.isCancelled() + + XCTAssertTrue(cancelledA) + XCTAssertFalse(cancelledB) + } + + func testCancelAllMatchingTokenIdCancelsOnlyMatching() async { + let registry = MCPInflightRegistry() + let tokenA = MCPCancellationToken() + let tokenB = MCPCancellationToken() + let tokenC = MCPCancellationToken() + let session = MCPSessionId("session-revoked") + let revokedTokenId = UUID() + let otherTokenId = UUID() + + await registry.register( + requestId: .number(1), + sessionId: session, + token: tokenA, + tokenId: revokedTokenId + ) + await registry.register( + requestId: .number(2), + sessionId: session, + token: tokenB, + tokenId: revokedTokenId + ) + await registry.register( + requestId: .number(3), + sessionId: session, + token: tokenC, + tokenId: otherTokenId + ) + + let cancelledSessions = await registry.cancelAll(matchingTokenId: revokedTokenId) + XCTAssertEqual(cancelledSessions, [session]) + + let cancelledA = await tokenA.isCancelled() + let cancelledB = await tokenB.isCancelled() + let cancelledC = await tokenC.isCancelled() + XCTAssertTrue(cancelledA) + XCTAssertTrue(cancelledB) + XCTAssertFalse(cancelledC) + + let count = await registry.count() + XCTAssertEqual(count, 1) + } + + func testCountReflectsActiveRegistrations() async { + let registry = MCPInflightRegistry() + let session = MCPSessionId("session-count") + + await registry.register( + requestId: .number(1), + sessionId: session, + token: MCPCancellationToken() + ) + await registry.register( + requestId: .number(2), + sessionId: session, + token: MCPCancellationToken() + ) + await registry.register( + requestId: .number(3), + sessionId: session, + token: MCPCancellationToken() + ) + + let count = await registry.count() + XCTAssertEqual(count, 3) + + await registry.remove(requestId: .number(2), sessionId: session) + let countAfter = await registry.count() + XCTAssertEqual(countAfter, 2) + } +} diff --git a/TableProTests/Core/MCP/Protocol/MCPProgressEmitterTests.swift b/TableProTests/Core/MCP/Protocol/MCPProgressEmitterTests.swift new file mode 100644 index 000000000..a7830f8e1 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/MCPProgressEmitterTests.swift @@ -0,0 +1,160 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPProgressEmitterTests: XCTestCase { + func testEmitWithoutProgressTokenIsNoop() async { + let sink = StubProgressSink() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: sink, + sessionId: MCPSessionId("session-1") + ) + + await emitter.emit(progress: 0.5) + await emitter.emit(progress: 1.0, total: 1.0, message: "done") + + let count = await sink.count() + XCTAssertEqual(count, 0) + } + + func testEmitWithProgressTokenSendsNotification() async { + let sink = StubProgressSink() + let token = JsonValue.string("progress-token-1") + let emitter = MCPProgressEmitter( + progressToken: token, + target: sink, + sessionId: MCPSessionId("session-2") + ) + + await emitter.emit(progress: 0.42) + + let notifications = await sink.notifications + XCTAssertEqual(notifications.count, 1) + + guard let first = notifications.first else { + XCTFail("Expected at least one notification") + return + } + XCTAssertEqual(first.notification.method, "notifications/progress") + XCTAssertEqual(first.sessionId, MCPSessionId("session-2")) + + guard case .object(let params) = first.notification.params else { + XCTFail("Expected object params") + return + } + XCTAssertEqual(params["progressToken"], token) + XCTAssertEqual(params["progress"], .double(0.42)) + XCTAssertNil(params["total"]) + XCTAssertNil(params["message"]) + } + + func testEmitIncludesTotalAndMessageWhenProvided() async { + let sink = StubProgressSink() + let token = JsonValue.int(123) + let emitter = MCPProgressEmitter( + progressToken: token, + target: sink, + sessionId: MCPSessionId("session-3") + ) + + await emitter.emit(progress: 5.0, total: 10.0, message: "halfway there") + + let notifications = await sink.notifications + XCTAssertEqual(notifications.count, 1) + guard let first = notifications.first, + case .object(let params) = first.notification.params else { + XCTFail("Expected notification with object params") + return + } + XCTAssertEqual(params["progressToken"], token) + XCTAssertEqual(params["progress"], .double(5.0)) + XCTAssertEqual(params["total"], .double(10.0)) + XCTAssertEqual(params["message"], .string("halfway there")) + } + + func testMultipleEmitsQueueInOrder() async { + let sink = StubProgressSink() + let token = JsonValue.string("queue-token") + let emitter = MCPProgressEmitter( + progressToken: token, + target: sink, + sessionId: MCPSessionId("session-4") + ) + + await emitter.emit(progress: 0.1) + await emitter.emit(progress: 0.2) + await emitter.emit(progress: 0.3, message: "third") + + let notifications = await sink.notifications + XCTAssertEqual(notifications.count, 3) + + XCTAssertEqual(progressValue(in: notifications[0].notification), 0.1) + XCTAssertEqual(progressValue(in: notifications[1].notification), 0.2) + XCTAssertEqual(progressValue(in: notifications[2].notification), 0.3) + XCTAssertEqual(messageValue(in: notifications[2].notification), "third") + } + + func testEmitNotificationSendsCustomMethod() async { + let sink = StubProgressSink() + let emitter = MCPProgressEmitter( + progressToken: nil, + target: sink, + sessionId: MCPSessionId("session-5") + ) + + await emitter.emitNotification(method: "custom/event", params: .object(["x": .int(1)])) + + let notifications = await sink.notifications + XCTAssertEqual(notifications.count, 1) + XCTAssertEqual(notifications.first?.notification.method, "custom/event") + } + + func testHasProgressTokenReflectsState() async { + let sink = StubProgressSink() + let withToken = MCPProgressEmitter( + progressToken: .string("t"), + target: sink, + sessionId: MCPSessionId("s") + ) + let withoutToken = MCPProgressEmitter( + progressToken: nil, + target: sink, + sessionId: MCPSessionId("s") + ) + + let hasA = await withToken.hasProgressToken + let hasB = await withoutToken.hasProgressToken + + XCTAssertTrue(hasA) + XCTAssertFalse(hasB) + } + + func testExtractProgressTokenReadsMetaField() { + let params: JsonValue = .object([ + "_meta": .object(["progressToken": .string("abc-123")]) + ]) + + let token = MCPProgressEmitter.extractProgressToken(from: params) + XCTAssertEqual(token, .string("abc-123")) + } + + func testExtractProgressTokenReturnsNilWhenAbsent() { + let withoutMeta: JsonValue = .object(["foo": .int(1)]) + let withMetaButNoToken: JsonValue = .object(["_meta": .object([:])]) + + XCTAssertNil(MCPProgressEmitter.extractProgressToken(from: withoutMeta)) + XCTAssertNil(MCPProgressEmitter.extractProgressToken(from: withMetaButNoToken)) + XCTAssertNil(MCPProgressEmitter.extractProgressToken(from: nil)) + } + + private func progressValue(in notification: JsonRpcNotification) -> Double? { + guard case .object(let params) = notification.params else { return nil } + return params["progress"]?.doubleValue + } + + private func messageValue(in notification: JsonRpcNotification) -> String? { + guard case .object(let params) = notification.params else { return nil } + return params["message"]?.stringValue + } +} diff --git a/TableProTests/Core/MCP/Protocol/MCPProtocolDispatcherTests.swift b/TableProTests/Core/MCP/Protocol/MCPProtocolDispatcherTests.swift new file mode 100644 index 000000000..9d9184806 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/MCPProtocolDispatcherTests.swift @@ -0,0 +1,412 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPProtocolDispatcherTests: XCTestCase { + func testMethodNotFoundReturnsErrorResponse() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler(), PingHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let request = MCPProtocolTestSupport.makeRequest( + id: .number(1), + method: "unknown/method" + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.methodNotFound) + XCTAssertEqual(envelope.id, .number(1)) + } + + func testUninitializedSessionRejectsNonInitializeMethods() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler(), StubToolsListHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let request = MCPProtocolTestSupport.makeRequest( + id: .number(2), + method: "tools/list" + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.invalidRequest) + } + + func testInitializeCreatesSessionAndNotificationTransitionsToReady() async throws { + let store = MCPSessionStore() + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let eventStream = await store.events + let collectorTask = Task { + for await event in eventStream { + if case .created(let id) = event { + return id + } + } + return nil + } + + let initRequest = MCPProtocolTestSupport.makeRequest( + id: .number(10), + method: "initialize", + params: .object([ + "protocolVersion": .string("2025-03-26"), + "clientInfo": .object(["name": .string("client-x")]), + "capabilities": .object([:]) + ]) + ) + let (initExchange, initSink) = MCPProtocolTestSupport.makeExchange(message: initRequest) + + await dispatcher.dispatch(initExchange) + await initSink.waitForCompletion() + + let initResponse = try await initSink.firstJsonMessage() + guard case .successResponse = initResponse else { + XCTFail("Expected success response, got \(String(describing: initResponse))") + return + } + + let sessionCount = await store.count() + XCTAssertEqual(sessionCount, 1) + + guard let createdId = await collectorTask.value else { + XCTFail("Expected the dispatcher to have created a session") + return + } + guard let session = await store.session(id: createdId) else { + XCTFail("Expected to find created session in store") + return + } + let sessionId = await session.id + + let stateAfterInitialize = await session.state + XCTAssertEqual(stateAfterInitialize, .initializing) + + let initializedNotification = MCPProtocolTestSupport.makeNotification( + method: "notifications/initialized" + ) + let (notifExchange, notifSink) = MCPProtocolTestSupport.makeExchange( + message: initializedNotification, + sessionId: sessionId + ) + + await dispatcher.dispatch(notifExchange) + await notifSink.waitForCompletion() + + let stateAfterNotification = await session.state + XCTAssertEqual(stateAfterNotification, .ready) + + let acceptedCount = await notifSink.acceptedCount + XCTAssertEqual(acceptedCount, 1) + } + + func testAuthScopeCheckRejectsInsufficientScopes() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + try await session.transitionToReady() + + let dispatcher = MCPProtocolDispatcher( + handlers: [ScopedToolsCallHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let principal = MCPProtocolTestSupport.makePrincipal(scopes: [.toolsRead]) + let request = MCPProtocolTestSupport.makeRequest( + id: .number(3), + method: "tools/call" + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId, + principal: principal + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.forbidden) + } + + func testCancellationFlowDeliversCancelledError() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + try await session.transitionToReady() + + let stubHandler = StubMethodHandler(behavior: .waitForCancellation) + let dispatcher = MCPProtocolDispatcher( + handlers: [stubHandler], + sessionStore: store, + progressSink: StubProgressSink() + ) + let stubMethod = StubMethodHandler.method + + let requestId = JsonRpcId.number(7) + let request = MCPProtocolTestSupport.makeRequest(id: requestId, method: stubMethod) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId + ) + + let dispatchTask = Task { + await dispatcher.dispatch(exchange) + } + + try await waitUntil(timeoutMs: 2_000) { + await stubHandler.started.value() + } + + let cancelNotification = MCPProtocolTestSupport.makeNotification( + method: "notifications/cancelled", + params: .object(["requestId": .int(7)]) + ) + let (cancelExchange, cancelSink) = MCPProtocolTestSupport.makeExchange( + message: cancelNotification, + sessionId: sessionId + ) + + await dispatcher.dispatch(cancelExchange) + await cancelSink.waitForCompletion() + + await dispatchTask.value + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.requestCancelled) + + let observed = await stubHandler.observedCancel.value() + XCTAssertTrue(observed) + } + + func testInboundResponsesAreIgnored() async throws { + let store = MCPSessionStore() + let dispatcher = MCPProtocolDispatcher( + handlers: [PingHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: .number(99), result: .object([:])) + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange(message: response) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let acceptedCount = await sink.acceptedCount + XCTAssertEqual(acceptedCount, 1) + let jsonWrites = await sink.jsonWrites + XCTAssertTrue(jsonWrites.isEmpty) + } + + func testNotificationInitializedTransitionsSessionWithoutResponse() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + let dispatcher = MCPProtocolDispatcher( + handlers: [], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let stateBefore = await session.state + XCTAssertEqual(stateBefore, .initializing) + + let notification = MCPProtocolTestSupport.makeNotification( + method: "notifications/initialized" + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: notification, + sessionId: sessionId + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let stateAfter = await session.state + XCTAssertEqual(stateAfter, .ready) + + let acceptedCount = await sink.acceptedCount + XCTAssertEqual(acceptedCount, 1) + let writes = await sink.jsonWrites + XCTAssertTrue(writes.isEmpty) + } + + func testConcurrentRequestsInSameSessionAllComplete() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + try await session.transitionToReady() + + let dispatcher = MCPProtocolDispatcher( + handlers: [PingHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let count = 5 + var sinks: [RecordingResponderSink] = [] + sinks.reserveCapacity(count) + + await withTaskGroup(of: RecordingResponderSink.self) { group in + for index in 0..() + for sink in sinks { + let decoded = try await sink.firstJsonMessage() + guard case .successResponse(let success) = decoded else { + XCTFail("Expected success response, got \(String(describing: decoded))") + return + } + guard case .number(let value) = success.id else { + XCTFail("Expected numeric id, got \(success.id)") + return + } + seenIds.insert(value) + } + XCTAssertEqual(seenIds, Set((1...count).map { Int64($0) })) + } + + func testHandlerThrowingProtocolErrorYieldsErrorResponse() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let sessionId = await session.id + try await session.transitionToReady() + + let stubError = MCPProtocolError.invalidParams(detail: "bad shape") + let handler = StubMethodHandler(behavior: .throwProtocolError(stubError)) + let dispatcher = MCPProtocolDispatcher( + handlers: [handler], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let request = MCPProtocolTestSupport.makeRequest( + id: .number(11), + method: StubMethodHandler.method + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: sessionId + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.invalidParams) + } + + func testRequestWithoutSessionIdAndNonInitializeMethodFails() async throws { + let store = MCPSessionStore() + let dispatcher = MCPProtocolDispatcher( + handlers: [PingHandler()], + sessionStore: store, + progressSink: StubProgressSink() + ) + + let request = MCPProtocolTestSupport.makeRequest( + id: .number(20), + method: "ping" + ) + let (exchange, sink) = MCPProtocolTestSupport.makeExchange( + message: request, + sessionId: nil + ) + + await dispatcher.dispatch(exchange) + await sink.waitForCompletion() + + let decoded = try await sink.firstJsonMessage() + guard case .errorResponse(let envelope) = decoded else { + XCTFail("Expected error response, got \(String(describing: decoded))") + return + } + XCTAssertEqual(envelope.error.code, JsonRpcErrorCode.sessionNotFound) + } + + private func waitUntil( + timeoutMs: UInt64, + _ predicate: @Sendable () async -> Bool + ) async throws { + let deadline = Date().addingTimeInterval(Double(timeoutMs) / 1_000.0) + while Date() < deadline { + if await predicate() { return } + try await Task.sleep(nanoseconds: 10_000_000) + } + if await predicate() { return } + XCTFail("Timed out waiting for condition after \(timeoutMs)ms") + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationToolTests.swift new file mode 100644 index 000000000..30c8c2600 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ConfirmDestructiveOperationToolTests.swift @@ -0,0 +1,91 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ConfirmDestructiveOperationTool") +struct ConfirmDestructiveOperationToolTests { + @Test("Tool requires write scope") + func requiresWriteScope() { + #expect(ConfirmDestructiveOperationTool.requiredScopes == [.toolsWrite]) + #expect(ConfirmDestructiveOperationTool.name == "confirm_destructive_operation") + } + + @Test("Wrong confirmation phrase returns invalidParams") + func wrongConfirmationPhrase() async throws { + let tool = ConfirmDestructiveOperationTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + let connectionId = UUID() + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(connectionId.uuidString), + "query": .string("DROP TABLE users"), + "confirmation_phrase": .string("yes do it") + ]), + context: context, + services: services + ) + } + } + + @Test("Missing query returns invalidParams") + func missingQuery() async throws { + let tool = ConfirmDestructiveOperationTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "confirmation_phrase": .string("I understand this is irreversible") + ]), + context: context, + services: services + ) + } + } + + @Test("Multi-statement query is rejected before connection lookup") + func multiStatementRejected() async throws { + let tool = ConfirmDestructiveOperationTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + let connectionId = UUID() + + do { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(connectionId.uuidString), + "query": .string("DROP TABLE users; DROP TABLE other"), + "confirmation_phrase": .string("I understand this is irreversible") + ]), + context: context, + services: services + ) + Issue.record("Expected MCPProtocolError for multi-statement query") + } catch let error as MCPProtocolError { + #expect(error.code == JsonRpcErrorCode.invalidParams) + } + } + + @Test("Tool input schema declares required fields") + func inputSchemaRequiredFields() { + let schema = ConfirmDestructiveOperationTool.inputSchema + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required.contains("connection_id")) + #expect(required.contains("query")) + #expect(required.contains("confirmation_phrase")) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ConnectToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ConnectToolTests.swift new file mode 100644 index 000000000..be9518fa0 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ConnectToolTests.swift @@ -0,0 +1,54 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ConnectTool") +struct ConnectToolTests { + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = ConnectTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([:]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = ConnectTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid") + ]), + context: context, + services: services + ) + } + } + + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ConnectTool.name == "connect") + #expect(ConnectTool.requiredScopes == [.toolsRead]) + let schema = ConnectTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) + #expect(required == ["connection_id"]) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/DescribeTableToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/DescribeTableToolTests.swift new file mode 100644 index 000000000..104628384 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/DescribeTableToolTests.swift @@ -0,0 +1,64 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("DescribeTableTool") +struct DescribeTableToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(DescribeTableTool.name == "describe_table") + #expect(DescribeTableTool.requiredScopes == [.toolsRead]) + let schema = DescribeTableTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id", "table"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = DescribeTableTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["table": .string("users")]), + context: context, + services: services + ) + } + } + + @Test("Missing table returns invalidParams") + func missingTable() async throws { + let tool = DescribeTableTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string(UUID().uuidString)]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = DescribeTableTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid"), + "table": .string("users") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/DisconnectToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/DisconnectToolTests.swift new file mode 100644 index 000000000..94868ab05 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/DisconnectToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("DisconnectTool") +struct DisconnectToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(DisconnectTool.name == "disconnect") + #expect(DisconnectTool.requiredScopes == [.toolsWrite]) + let schema = DisconnectTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = DisconnectTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = DisconnectTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ExecuteQueryToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ExecuteQueryToolTests.swift new file mode 100644 index 000000000..95bc91522 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ExecuteQueryToolTests.swift @@ -0,0 +1,192 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ExecuteQueryTool") +struct ExecuteQueryToolTests { + @Test("Tool exposes correct metadata") + func metadata() { + #expect(ExecuteQueryTool.name == "execute_query") + #expect(ExecuteQueryTool.requiredScopes == [.toolsRead]) + let schema = ExecuteQueryTool.inputSchema + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required.contains("connection_id")) + #expect(required.contains("query")) + } + + @Test("Multi-statement query is rejected before connection lookup") + func multiStatementRejected() async throws { + let tool = ExecuteQueryTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + do { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "query": .string("SELECT 1; SELECT 2") + ]), + context: context, + services: services + ) + Issue.record("Expected MCPProtocolError for multi-statement query") + } catch let error as MCPProtocolError { + #expect(error.code == JsonRpcErrorCode.invalidParams) + } + } + + @Test("Query exceeding 100KB is rejected with invalidParams") + func queryTooLargeRejected() async throws { + let tool = ExecuteQueryTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + let oversized = String(repeating: "a", count: 102_401) + + do { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "query": .string(oversized) + ]), + context: context, + services: services + ) + Issue.record("Expected oversized query to be rejected") + } catch let error as MCPProtocolError { + #expect(error.code == JsonRpcErrorCode.invalidParams) + } + } + + @Test("Cancellation propagates as requestCancelled") + func cancellationPropagates() async throws { + let tool = ExecuteQueryTool() + let progressSink = StubProgressSink() + let context = await ExecuteQueryToolTestContext.make( + progressToken: nil, + progressSink: progressSink + ) + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await context.cancellation.cancel() + + do { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "query": .string("SELECT 1") + ]), + context: context, + services: services + ) + Issue.record("Expected cancelled error") + } catch let error as MCPProtocolError { + #expect(error.code == JsonRpcErrorCode.requestCancelled) + } + } + + @Test("Progress notifications fire when progressToken is set") + func progressEmittedWhenTokenPresent() async throws { + let tool = ExecuteQueryTool() + let progressSink = StubProgressSink() + let context = await ExecuteQueryToolTestContext.make( + progressToken: .string("progress-1"), + progressSink: progressSink + ) + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + _ = try? await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "query": .string("SELECT 1") + ]), + context: context, + services: services + ) + + let methods = await progressSink.methods() + #expect(methods.allSatisfy { $0 == "notifications/progress" }) + #expect(methods.count >= 1) + } + + @Test("Progress notifications are skipped when no progressToken") + func progressSkippedWithoutToken() async throws { + let tool = ExecuteQueryTool() + let progressSink = StubProgressSink() + let context = await ExecuteQueryToolTestContext.make( + progressToken: nil, + progressSink: progressSink + ) + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + _ = try? await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "query": .string("SELECT 1") + ]), + context: context, + services: services + ) + + let count = await progressSink.count() + #expect(count == 0) + } +} + +enum ExecuteQueryToolTestContext { + static func make( + progressToken: JsonValue?, + progressSink: StubProgressSink + ) async -> MCPRequestContext { + let sessionStore = MCPSessionStore() + let dispatcher = MCPProtocolDispatcher( + handlers: [], + sessionStore: sessionStore, + progressSink: progressSink, + clock: MCPSystemClock() + ) + + let session = MCPSession() + try? await session.transitionToReady() + let resolvedSessionId = await session.id + + let principal = MCPProtocolTestSupport.makePrincipal(scopes: [.toolsRead, .toolsWrite]) + let request = JsonRpcRequest(id: .number(1), method: "tools/call", params: nil) + let (exchange, _) = MCPProtocolTestSupport.makeExchange( + message: .request(request), + sessionId: resolvedSessionId, + principal: principal + ) + + let cancellation = MCPCancellationToken() + let progress = MCPProgressEmitter( + progressToken: progressToken, + target: progressSink, + sessionId: resolvedSessionId + ) + + return MCPRequestContext( + exchange: exchange, + session: session, + principal: principal, + dispatcher: dispatcher, + progress: progress, + cancellation: cancellation, + clock: MCPSystemClock() + ) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ExportDataToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ExportDataToolTests.swift new file mode 100644 index 000000000..9f43d26a9 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ExportDataToolTests.swift @@ -0,0 +1,83 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ExportDataTool") +struct ExportDataToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ExportDataTool.name == "export_data") + #expect(ExportDataTool.requiredScopes == [.toolsRead]) + let schema = ExportDataTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id", "format"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = ExportDataTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["format": .string("csv")]), + context: context, + services: services + ) + } + } + + @Test("Missing format returns invalidParams") + func missingFormat() async throws { + let tool = ExportDataTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string(UUID().uuidString)]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = ExportDataTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid"), + "format": .string("csv"), + "query": .string("SELECT 1") + ]), + context: context, + services: services + ) + } + } + + @Test("Neither query nor tables returns invalidParams") + func missingQueryAndTables() async throws { + let tool = ExportDataTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString), + "format": .string("csv") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/FocusQueryTabToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/FocusQueryTabToolTests.swift new file mode 100644 index 000000000..34a2edc6b --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/FocusQueryTabToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("FocusQueryTabTool") +struct FocusQueryTabToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(FocusQueryTabTool.name == "focus_query_tab") + #expect(FocusQueryTabTool.requiredScopes == [.toolsRead]) + let schema = FocusQueryTabTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["tab_id"]) + } + + @Test("Missing tab_id returns invalidParams") + func missingTabId() async throws { + let tool = FocusQueryTabTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed tab_id returns invalidParams") + func malformedTabId() async throws { + let tool = FocusQueryTabTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["tab_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/GetConnectionStatusToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/GetConnectionStatusToolTests.swift new file mode 100644 index 000000000..4434a9f59 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/GetConnectionStatusToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("GetConnectionStatusTool") +struct GetConnectionStatusToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(GetConnectionStatusTool.name == "get_connection_status") + #expect(GetConnectionStatusTool.requiredScopes == [.toolsRead]) + let schema = GetConnectionStatusTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = GetConnectionStatusTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = GetConnectionStatusTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/GetTableDdlToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/GetTableDdlToolTests.swift new file mode 100644 index 000000000..716c5d7ce --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/GetTableDdlToolTests.swift @@ -0,0 +1,64 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("GetTableDdlTool") +struct GetTableDdlToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(GetTableDdlTool.name == "get_table_ddl") + #expect(GetTableDdlTool.requiredScopes == [.toolsRead]) + let schema = GetTableDdlTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id", "table"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = GetTableDdlTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["table": .string("users")]), + context: context, + services: services + ) + } + } + + @Test("Missing table returns invalidParams") + func missingTable() async throws { + let tool = GetTableDdlTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string(UUID().uuidString)]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = GetTableDdlTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid"), + "table": .string("users") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ListConnectionsToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ListConnectionsToolTests.swift new file mode 100644 index 000000000..d8625f549 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ListConnectionsToolTests.swift @@ -0,0 +1,27 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ListConnectionsTool") +struct ListConnectionsToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ListConnectionsTool.name == "list_connections") + #expect(ListConnectionsTool.requiredScopes == [.toolsRead]) + let schema = ListConnectionsTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == []) + } + + @Test("Empty arguments returns a successful result") + func emptyArgumentsSucceed() async throws { + let tool = ListConnectionsTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + let result = try await tool.call(arguments: .object([:]), context: context, services: services) + #expect(result.isError == false) + #expect(result.content.isEmpty == false) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ListDatabasesToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ListDatabasesToolTests.swift new file mode 100644 index 000000000..d0fb0fd9b --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ListDatabasesToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ListDatabasesTool") +struct ListDatabasesToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ListDatabasesTool.name == "list_databases") + #expect(ListDatabasesTool.requiredScopes == [.toolsRead]) + let schema = ListDatabasesTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = ListDatabasesTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = ListDatabasesTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ListRecentTabsToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ListRecentTabsToolTests.swift new file mode 100644 index 000000000..d37fc0995 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ListRecentTabsToolTests.swift @@ -0,0 +1,27 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ListRecentTabsTool") +struct ListRecentTabsToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ListRecentTabsTool.name == "list_recent_tabs") + #expect(ListRecentTabsTool.requiredScopes == [.toolsRead]) + let schema = ListRecentTabsTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == []) + } + + @Test("Empty arguments returns a successful result") + func emptyArgumentsSucceed() async throws { + let tool = ListRecentTabsTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + let result = try await tool.call(arguments: .object([:]), context: context, services: services) + #expect(result.isError == false) + #expect(result.content.isEmpty == false) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ListSchemasToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ListSchemasToolTests.swift new file mode 100644 index 000000000..912b898cb --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ListSchemasToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ListSchemasTool") +struct ListSchemasToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ListSchemasTool.name == "list_schemas") + #expect(ListSchemasTool.requiredScopes == [.toolsRead]) + let schema = ListSchemasTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = ListSchemasTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = ListSchemasTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/ListTablesToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/ListTablesToolTests.swift new file mode 100644 index 000000000..d0ee51701 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/ListTablesToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("ListTablesTool") +struct ListTablesToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(ListTablesTool.name == "list_tables") + #expect(ListTablesTool.requiredScopes == [.toolsRead]) + let schema = ListTablesTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = ListTablesTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = ListTablesTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/OpenConnectionWindowToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/OpenConnectionWindowToolTests.swift new file mode 100644 index 000000000..870f3519b --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/OpenConnectionWindowToolTests.swift @@ -0,0 +1,42 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("OpenConnectionWindowTool") +struct OpenConnectionWindowToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(OpenConnectionWindowTool.name == "open_connection_window") + #expect(OpenConnectionWindowTool.requiredScopes == [.toolsRead]) + let schema = OpenConnectionWindowTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = OpenConnectionWindowTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = OpenConnectionWindowTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string("not-a-uuid")]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/OpenTableTabToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/OpenTableTabToolTests.swift new file mode 100644 index 000000000..12e30a160 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/OpenTableTabToolTests.swift @@ -0,0 +1,64 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("OpenTableTabTool") +struct OpenTableTabToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(OpenTableTabTool.name == "open_table_tab") + #expect(OpenTableTabTool.requiredScopes == [.toolsRead]) + let schema = OpenTableTabTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id", "table_name"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = OpenTableTabTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["table_name": .string("users")]), + context: context, + services: services + ) + } + } + + @Test("Missing table_name returns invalidParams") + func missingTableName() async throws { + let tool = OpenTableTabTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string(UUID().uuidString)]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = OpenTableTabTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid"), + "table_name": .string("users") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/SearchQueryHistoryToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/SearchQueryHistoryToolTests.swift new file mode 100644 index 000000000..55264853d --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/SearchQueryHistoryToolTests.swift @@ -0,0 +1,45 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("SearchQueryHistoryTool") +struct SearchQueryHistoryToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(SearchQueryHistoryTool.name == "search_query_history") + #expect(SearchQueryHistoryTool.requiredScopes == [.toolsRead]) + let schema = SearchQueryHistoryTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["query"]) + } + + @Test("Missing query returns invalidParams") + func missingQuery() async throws { + let tool = SearchQueryHistoryTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call(arguments: .object([:]), context: context, services: services) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = SearchQueryHistoryTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "query": .string("select"), + "connection_id": .string("not-a-uuid") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/SwitchDatabaseToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/SwitchDatabaseToolTests.swift new file mode 100644 index 000000000..4852ddc20 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/SwitchDatabaseToolTests.swift @@ -0,0 +1,58 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("SwitchDatabaseTool") +struct SwitchDatabaseToolTests { + @Test("Tool requires write scope") + func requiresWriteScope() { + #expect(SwitchDatabaseTool.requiredScopes == [.toolsWrite]) + #expect(SwitchDatabaseTool.name == "switch_database") + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = SwitchDatabaseTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["database": .string("foo")]), + context: context, + services: services + ) + } + } + + @Test("Missing database returns invalidParams") + func missingDatabase() async throws { + let tool = SwitchDatabaseTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices( + connectionBridge: MCPConnectionBridge(), + authPolicy: MCPAuthPolicy() + ) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string(UUID().uuidString) + ]), + context: context, + services: services + ) + } + } + + @Test("Schema lists both required parameters") + func schemaRequiredFields() { + let schema = SwitchDatabaseTool.inputSchema + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required.contains("connection_id")) + #expect(required.contains("database")) + } +} diff --git a/TableProTests/Core/MCP/Protocol/Tools/SwitchSchemaToolTests.swift b/TableProTests/Core/MCP/Protocol/Tools/SwitchSchemaToolTests.swift new file mode 100644 index 000000000..3f8493276 --- /dev/null +++ b/TableProTests/Core/MCP/Protocol/Tools/SwitchSchemaToolTests.swift @@ -0,0 +1,64 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("SwitchSchemaTool") +struct SwitchSchemaToolTests { + @Test("Tool exposes expected metadata") + func metadata() { + #expect(SwitchSchemaTool.name == "switch_schema") + #expect(SwitchSchemaTool.requiredScopes == [.toolsWrite]) + let schema = SwitchSchemaTool.inputSchema + #expect(schema["type"]?.stringValue == "object") + let required = schema["required"]?.arrayValue?.compactMap(\.stringValue) ?? [] + #expect(required == ["connection_id", "schema"]) + } + + @Test("Missing connection_id returns invalidParams") + func missingConnectionId() async throws { + let tool = SwitchSchemaTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["schema": .string("public")]), + context: context, + services: services + ) + } + } + + @Test("Missing schema returns invalidParams") + func missingSchema() async throws { + let tool = SwitchSchemaTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object(["connection_id": .string(UUID().uuidString)]), + context: context, + services: services + ) + } + } + + @Test("Malformed connection_id returns invalidParams") + func malformedConnectionId() async throws { + let tool = SwitchSchemaTool() + let context = await MCPProtocolHandlerTestSupport.makeContext(method: "tools/call") + let services = MCPToolServices(connectionBridge: MCPConnectionBridge(), authPolicy: MCPAuthPolicy()) + + await #expect(throws: MCPProtocolError.self) { + _ = try await tool.call( + arguments: .object([ + "connection_id": .string("not-a-uuid"), + "schema": .string("public") + ]), + context: context, + services: services + ) + } + } +} diff --git a/TableProTests/Core/MCP/RateLimit/MCPRateLimiterTests.swift b/TableProTests/Core/MCP/RateLimit/MCPRateLimiterTests.swift new file mode 100644 index 000000000..3f8a3beb8 --- /dev/null +++ b/TableProTests/Core/MCP/RateLimit/MCPRateLimiterTests.swift @@ -0,0 +1,143 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP Rate Limiter") +struct MCPRateLimiterNewTests { + private func standardKey() -> MCPRateLimitKey { + MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: "abcd1234") + } + + @Test("Five failures lock the key") + func fiveFailuresLock() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let key = standardKey() + + for _ in 0..<4 { + let verdict = await limiter.recordAttempt(key: key, success: false) + #expect(verdict == .allowed) + } + let final = await limiter.recordAttempt(key: key, success: false) + guard case .lockedUntil = final else { + Issue.record("Expected lockedUntil, got \(final)") + return + } + let locked = await limiter.isLocked(key: key) + #expect(locked == true) + } + + @Test("Lock expires after lockout duration") + func lockExpires() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter( + policy: MCPRateLimitPolicy( + maxFailedAttempts: 3, + windowDuration: .seconds(60), + lockoutDuration: .seconds(120) + ), + clock: clock + ) + let key = standardKey() + + for _ in 0..<3 { + _ = await limiter.recordAttempt(key: key, success: false) + } + let lockedNow = await limiter.isLocked(key: key) + #expect(lockedNow == true) + + await clock.advance(by: .seconds(121)) + let lockedLater = await limiter.isLocked(key: key) + #expect(lockedLater == false) + } + + @Test("Different keys are isolated") + func differentKeysIsolated() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let keyA = MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: "tokenA") + let keyB = MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: "tokenB") + + for _ in 0..<5 { + _ = await limiter.recordAttempt(key: keyA, success: false) + } + let lockedA = await limiter.isLocked(key: keyA) + let lockedB = await limiter.isLocked(key: keyB) + #expect(lockedA == true) + #expect(lockedB == false) + } + + @Test("Same address different principal does not share bucket") + func sameAddressDifferentPrincipal() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let attacker = MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: "bad") + let legitimate = MCPRateLimitKey(clientAddress: .loopback, principalFingerprint: "good") + + for _ in 0..<5 { + _ = await limiter.recordAttempt(key: attacker, success: false) + } + let allowed = await limiter.recordAttempt(key: legitimate, success: true) + #expect(allowed == .allowed) + } + + @Test("Success resets failure count") + func successResetsFailureCount() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter( + policy: MCPRateLimitPolicy( + maxFailedAttempts: 5, + windowDuration: .seconds(60), + lockoutDuration: .seconds(300) + ), + clock: clock + ) + let key = standardKey() + + for _ in 0..<3 { + _ = await limiter.recordAttempt(key: key, success: false) + } + _ = await limiter.recordAttempt(key: key, success: true) + + for _ in 0..<4 { + let verdict = await limiter.recordAttempt(key: key, success: false) + #expect(verdict == .allowed) + } + let locked = await limiter.isLocked(key: key) + #expect(locked == false) + } + + @Test("Failures outside window do not count") + func failuresOutsideWindowExpire() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter( + policy: MCPRateLimitPolicy( + maxFailedAttempts: 5, + windowDuration: .seconds(60), + lockoutDuration: .seconds(300) + ), + clock: clock + ) + let key = standardKey() + + for _ in 0..<4 { + _ = await limiter.recordAttempt(key: key, success: false) + } + await clock.advance(by: .seconds(120)) + let verdict = await limiter.recordAttempt(key: key, success: false) + #expect(verdict == .allowed) + } + + @Test("Reset clears the bucket") + func resetClearsBucket() async { + let clock = MCPTestClock() + let limiter = MCPRateLimiter(clock: clock) + let key = standardKey() + for _ in 0..<5 { + _ = await limiter.recordAttempt(key: key, success: false) + } + await limiter.reset(key: key) + let locked = await limiter.isLocked(key: key) + #expect(locked == false) + } +} diff --git a/TableProTests/Core/MCP/Session/MCPSessionStoreTests.swift b/TableProTests/Core/MCP/Session/MCPSessionStoreTests.swift new file mode 100644 index 000000000..f351c2556 --- /dev/null +++ b/TableProTests/Core/MCP/Session/MCPSessionStoreTests.swift @@ -0,0 +1,182 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP Session Store") +struct MCPSessionStoreTests { + @Test("Create then lookup returns same session") + func createThenLookup() async throws { + let store = MCPSessionStore() + let session = try await store.create() + let found = await store.session(id: session.id) + #expect(found != nil) + let count = await store.count() + #expect(count == 1) + } + + @Test("Touch updates session lastActivity to current clock time") + func touchUpdatesLastActivity() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 1_000_000)) + let store = MCPSessionStore(clock: clock) + let session = try await store.create() + + await clock.advance(by: .seconds(120)) + await store.touch(id: session.id) + + let activity = await session.lastActivityAt + let expected = Date(timeIntervalSince1970: 1_000_000 + 120) + #expect(activity == expected) + } + + @Test("Capacity overflow throws") + func capacityOverflow() async throws { + let policy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 2, + cleanupInterval: .seconds(60) + ) + let store = MCPSessionStore(policy: policy) + _ = try await store.create() + _ = try await store.create() + await #expect(throws: MCPSessionStoreError.self) { + _ = try await store.create() + } + } + + @Test("Idle eviction terminates expired sessions") + func idleEviction() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 1_000_000)) + let policy = MCPSessionPolicy( + idleTimeout: .seconds(300), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + let store = MCPSessionStore(policy: policy, clock: clock) + let active = try await store.create() + let stale = try await store.create() + + await clock.advance(by: .seconds(200)) + await store.touch(id: active.id) + + await clock.advance(by: .seconds(200)) + await store.runCleanupPass() + + let activeFound = await store.session(id: active.id) + let staleFound = await store.session(id: stale.id) + #expect(activeFound != nil) + #expect(staleFound == nil) + } + + @Test("Termination broadcasts to subscribers") + func terminationBroadcastsEvents() async throws { + let store = MCPSessionStore() + let stream = await store.events + + let session = try await store.create() + await store.terminate(id: session.id, reason: .clientRequested) + + var collected: [MCPSessionEvent] = [] + var iterator = stream.makeAsyncIterator() + if let event = await iterator.next() { + collected.append(event) + } + if let event = await iterator.next() { + collected.append(event) + } + + #expect(collected.count == 2) + guard case .created(let createdId) = collected[0] else { + Issue.record("Expected created event, got \(collected[0])") + return + } + guard case .terminated(let terminatedId, let reason) = collected[1] else { + Issue.record("Expected terminated event, got \(collected[1])") + return + } + #expect(createdId == session.id) + #expect(terminatedId == session.id) + #expect(reason == .clientRequested) + } + + @Test("Multiple subscribers receive same events") + func multipleSubscribersReceiveSameEvents() async throws { + let store = MCPSessionStore() + let streamA = await store.events + let streamB = await store.events + + let session = try await store.create() + await store.terminate(id: session.id, reason: .idleTimeout) + + var iteratorA = streamA.makeAsyncIterator() + var iteratorB = streamB.makeAsyncIterator() + + let firstA = await iteratorA.next() + let firstB = await iteratorB.next() + #expect(firstA != nil) + #expect(firstB != nil) + + let secondA = await iteratorA.next() + let secondB = await iteratorB.next() + guard case .terminated(_, let reasonA) = secondA else { + Issue.record("Expected terminated for A") + return + } + guard case .terminated(_, let reasonB) = secondB else { + Issue.record("Expected terminated for B") + return + } + #expect(reasonA == .idleTimeout) + #expect(reasonB == .idleTimeout) + } + + @Test("Terminate on missing id is a no-op") + func terminateMissingIsNoop() async { + let store = MCPSessionStore() + let unknown = MCPSessionId.generate() + await store.terminate(id: unknown, reason: .clientRequested) + let count = await store.count() + #expect(count == 0) + } + + @Test("Cleanup pass with no idle sessions does nothing") + func cleanupNoIdle() async throws { + let clock = MCPTestClock() + let policy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 8, + cleanupInterval: .seconds(60) + ) + let store = MCPSessionStore(policy: policy, clock: clock) + let session = try await store.create() + await clock.advance(by: .seconds(60)) + await store.runCleanupPass() + let found = await store.session(id: session.id) + #expect(found != nil) + } + + @Test("Idle eviction emits idleTimeout event") + func idleEvictionEmitsTimeoutEvent() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 2_000_000)) + let policy = MCPSessionPolicy( + idleTimeout: .seconds(60), + maxSessions: 4, + cleanupInterval: .seconds(15) + ) + let store = MCPSessionStore(policy: policy, clock: clock) + let stream = await store.events + let session = try await store.create() + + await clock.advance(by: .seconds(120)) + await store.runCleanupPass() + + var iterator = stream.makeAsyncIterator() + _ = await iterator.next() + let terminationEvent = await iterator.next() + guard case .terminated(let id, let reason) = terminationEvent else { + Issue.record("Expected terminated event, got \(String(describing: terminationEvent))") + return + } + #expect(id == session.id) + #expect(reason == .idleTimeout) + } +} diff --git a/TableProTests/Core/MCP/Session/MCPSessionTests.swift b/TableProTests/Core/MCP/Session/MCPSessionTests.swift new file mode 100644 index 000000000..20b3ef2df --- /dev/null +++ b/TableProTests/Core/MCP/Session/MCPSessionTests.swift @@ -0,0 +1,95 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP Session") +struct MCPSessionTests { + @Test("New session starts in initializing state") + func newSessionStartsInitializing() async { + let session = MCPSession() + let state = await session.state + #expect(state == .initializing) + } + + @Test("Transition initializing to ready succeeds") + func transitionInitializingToReady() async throws { + let session = MCPSession() + try await session.transitionToReady() + let state = await session.state + #expect(state == .ready) + } + + @Test("Cannot transition to ready twice") + func cannotTransitionToReadyTwice() async throws { + let session = MCPSession() + try await session.transitionToReady() + await #expect(throws: MCPSessionTransitionError.self) { + try await session.transitionToReady() + } + } + + @Test("Cannot transition to ready after termination") + func cannotTransitionAfterTermination() async { + let session = MCPSession() + await session.terminate(reason: .clientRequested) + await #expect(throws: MCPSessionTransitionError.self) { + try await session.transitionToReady() + } + } + + @Test("Touch updates last activity for non-terminated sessions") + func touchUpdatesLastActivity() async { + let start = Date(timeIntervalSince1970: 1_000_000) + let session = MCPSession(now: start) + let later = start.addingTimeInterval(30) + await session.touch(now: later) + let activity = await session.lastActivityAt + #expect(activity == later) + } + + @Test("Touch is ignored after termination") + func touchIgnoredAfterTermination() async { + let start = Date(timeIntervalSince1970: 1_000_000) + let session = MCPSession(now: start) + await session.terminate(reason: .idleTimeout) + let later = start.addingTimeInterval(60) + await session.touch(now: later) + let activity = await session.lastActivityAt + #expect(activity == start) + } + + @Test("recordInitialize stores client info and capabilities") + func recordInitializeStoresInfo() async { + let session = MCPSession() + let info = MCPClientInfo(name: "Claude", version: "1.0") + await session.recordInitialize( + clientInfo: info, + protocolVersion: "2024-11-05", + capabilities: .object(["sampling": .object([:])]) + ) + let stored = await session.clientInfo + let version = await session.negotiatedProtocolVersion + #expect(stored == info) + #expect(version == "2024-11-05") + } + + @Test("Snapshot reflects current state") + func snapshotReflectsState() async throws { + let session = MCPSession() + try await session.transitionToReady() + let info = MCPClientInfo(name: "TestClient", version: nil) + await session.recordInitialize(clientInfo: info, protocolVersion: "v1", capabilities: nil) + let snapshot = await session.snapshot() + #expect(snapshot.state == .ready) + #expect(snapshot.clientInfo == info) + } + + @Test("Termination is idempotent") + func terminationIsIdempotent() async { + let session = MCPSession() + await session.terminate(reason: .clientRequested) + await session.terminate(reason: .idleTimeout) + let state = await session.state + #expect(state == .terminated(reason: .clientRequested)) + } +} diff --git a/TableProTests/Core/MCP/Transport/MCPHttpServerConfigurationTests.swift b/TableProTests/Core/MCP/Transport/MCPHttpServerConfigurationTests.swift new file mode 100644 index 000000000..d722f9628 --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPHttpServerConfigurationTests.swift @@ -0,0 +1,72 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP HTTP Server Configuration") +struct MCPHttpServerConfigurationTests { + @Test("Loopback factory works without TLS") + func loopbackWithoutTls() { + let config = MCPHttpServerConfiguration.loopback(port: 23_508) + #expect(config.bindAddress == .loopback) + #expect(config.port == 23_508) + #expect(config.tls == nil) + #expect(config.limits.maxRequestBodyBytes == 10 * 1_024 * 1_024) + } + + @Test("Standard limits expose 10 MiB body cap and 16 KiB header cap") + func standardLimits() { + let limits = MCPHttpServerLimits.standard + #expect(limits.maxRequestBodyBytes == 10 * 1_024 * 1_024) + #expect(limits.maxHeaderBytes == 16 * 1_024) + #expect(limits.connectionTimeout == .seconds(30)) + } + + @Test("Custom limits are preserved") + func customLimits() { + let limits = MCPHttpServerLimits( + maxRequestBodyBytes: 1_024, + maxHeaderBytes: 512, + connectionTimeout: .seconds(5) + ) + let config = MCPHttpServerConfiguration.loopback(port: 5_000, limits: limits) + #expect(config.limits.maxRequestBodyBytes == 1_024) + #expect(config.limits.maxHeaderBytes == 512) + #expect(config.limits.connectionTimeout == .seconds(5)) + } + + @Test("Loopback factory custom port is preserved") + func customPort() { + let config = MCPHttpServerConfiguration.loopback(port: 65_500) + #expect(config.port == 65_500) + } + + @Test("Transport refuses to start anyInterface bind without TLS") + func remoteRequiresTls() async { + let store = MCPSessionStore() + let authenticator = StubAlwaysAllowAuthenticator() + let unsafe = MCPHttpServerConfiguration.unsafeMake( + bindAddress: .anyInterface, + port: 0, + tls: nil, + limits: .standard + ) + let transport = MCPHttpServerTransport( + configuration: unsafe, + sessionStore: store, + authenticator: authenticator + ) + var captured: Error? + do { + try await transport.start() + } catch { + captured = error + } + #expect(captured is MCPHttpServerError) + if case .tlsRequiredForRemoteAccess = captured as? MCPHttpServerError { + #expect(true) + } else { + Issue.record("Expected tlsRequiredForRemoteAccess, got \(String(describing: captured))") + } + await transport.stop() + } +} diff --git a/TableProTests/Core/MCP/Transport/MCPHttpServerTransportPairingTests.swift b/TableProTests/Core/MCP/Transport/MCPHttpServerTransportPairingTests.swift new file mode 100644 index 000000000..7e66f68d4 --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPHttpServerTransportPairingTests.swift @@ -0,0 +1,277 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP HTTP Server Transport Pairing") +struct MCPHttpServerTransportPairingTests { + private struct ExchangeError: Decodable { + let error: String + } + + private struct ExchangeResponse: Decodable { + let token: String + } + + private func makeTransport( + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock() + ) -> (MCPHttpServerTransport, MCPSessionStore) { + let policy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + let store = MCPSessionStore(policy: policy, clock: clock) + let config = MCPHttpServerConfiguration.loopback(port: 0) + let transport = MCPHttpServerTransport( + configuration: config, + sessionStore: store, + authenticator: authenticator, + clock: clock + ) + return (transport, store) + } + + private func startedTransport( + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock() + ) async throws -> (MCPHttpServerTransport, UInt16) { + let (transport, _) = makeTransport(authenticator: authenticator, clock: clock) + let stateStream = transport.listenerState + let stateTask = Task { + for await state in stateStream { + if case .running(let port) = state { + return port + } + if case .failed = state { + return nil + } + } + return nil + } + try await transport.start() + guard let port = await stateTask.value, port != 0 else { + await transport.stop() + throw PairingTestError.serverDidNotStart + } + return (transport, port) + } + + private func makeExchangeRequest( + port: UInt16, + body: Data?, + contentType: String = "application/json" + ) -> URLRequest { + guard let url = URL(string: "http://127.0.0.1:\(port)/v1/integrations/exchange") else { + fatalError("Failed to construct test URL") + } + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue(contentType, forHTTPHeaderField: "Content-Type") + if let body { + request.httpBody = body + } + return request + } + + private func insertPairingRecord( + code: String, + plaintextToken: String, + challenge: String, + expiresAt: Date + ) async throws { + try await MainActor.run { + try MCPPairingService.shared.store.insert( + code: code, + record: PairingExchangeRecord( + plaintextToken: plaintextToken, + challenge: challenge, + expiresAt: expiresAt + ) + ) + } + } + + private func clearPairingCode(_ code: String) async { + await MainActor.run { + _ = try? MCPPairingService.shared.store.consume(code: code, verifier: "__cleanup__") + } + } + + private func uniqueCode() -> String { + "test-code-\(UUID().uuidString)" + } + + private func challenge(for verifier: String) -> String { + PairingExchangeStore.sha256Base64Url(of: verifier) + } + + @Test("Empty body returns 400 with invalid JSON body error") + func emptyBodyReturnsBadRequest() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let request = makeExchangeRequest(port: port, body: Data()) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Invalid JSON body") + } + + @Test("Malformed JSON returns 400 with invalid JSON body error") + func malformedJsonReturnsBadRequest() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let body = Data("{not-json".utf8) + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Invalid JSON body") + } + + @Test("Missing code returns 400 with missing code error") + func missingCodeReturnsBadRequest() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let body = Data(#"{"code":"","code_verifier":"verifier"}"#.utf8) + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Missing code or code_verifier") + } + + @Test("Missing code_verifier returns 400 with missing code error") + func missingCodeVerifierReturnsBadRequest() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let body = Data(#"{"code":"abc","code_verifier":""}"#.utf8) + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Missing code or code_verifier") + } + + @Test("Unknown code returns 404 with not-found error") + func unknownCodeReturnsNotFound() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let synthetic = "synthetic-\(UUID().uuidString)" + let body = Data(#"{"code":"\#(synthetic)","code_verifier":"any-verifier"}"#.utf8) + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 404) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Pairing code not found") + } + + @Test("Successful exchange returns 200 with token in body") + func successfulExchangeReturnsToken() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let code = uniqueCode() + let verifier = "verifier-\(UUID().uuidString)" + let plaintext = "tp_test-token-\(UUID().uuidString)" + try await insertPairingRecord( + code: code, + plaintextToken: plaintext, + challenge: challenge(for: verifier), + expiresAt: Date.now.addingTimeInterval(60) + ) + defer { Task { await clearPairingCode(code) } } + + let payload = ["code": code, "code_verifier": verifier] + let body = try JSONSerialization.data(withJSONObject: payload, options: [.sortedKeys]) + + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 200) + let decoded = try JSONDecoder().decode(ExchangeResponse.self, from: data) + #expect(decoded.token == plaintext) + } + + @Test("Mismatched verifier returns 403 with challenge mismatch error") + func mismatchedVerifierReturnsForbidden() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let code = uniqueCode() + let realVerifier = "real-verifier-\(UUID().uuidString)" + try await insertPairingRecord( + code: code, + plaintextToken: "tp_test", + challenge: challenge(for: realVerifier), + expiresAt: Date.now.addingTimeInterval(60) + ) + defer { Task { await clearPairingCode(code) } } + + let payload = ["code": code, "code_verifier": "wrong-verifier"] + let body = try JSONSerialization.data(withJSONObject: payload, options: [.sortedKeys]) + + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 403) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Challenge mismatch") + } + + @Test("Expired pairing code is unredeemable") + func expiredCodeIsUnredeemable() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let code = uniqueCode() + let verifier = "verifier-\(UUID().uuidString)" + try await insertPairingRecord( + code: code, + plaintextToken: "tp_test", + challenge: challenge(for: verifier), + expiresAt: Date.now.addingTimeInterval(-60) + ) + defer { Task { await clearPairingCode(code) } } + + let payload = ["code": code, "code_verifier": verifier] + let body = try JSONSerialization.data(withJSONObject: payload, options: [.sortedKeys]) + + let request = makeExchangeRequest(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 410 || http.statusCode == 404) + let decoded = try JSONDecoder().decode(ExchangeError.self, from: data) + #expect(decoded.error == "Pairing code expired" || decoded.error == "Pairing code not found") + } +} + +private enum PairingTestError: Error { + case serverDidNotStart +} diff --git a/TableProTests/Core/MCP/Transport/MCPHttpServerTransportTests.swift b/TableProTests/Core/MCP/Transport/MCPHttpServerTransportTests.swift new file mode 100644 index 000000000..f4ad3d169 --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPHttpServerTransportTests.swift @@ -0,0 +1,591 @@ +import Foundation +@testable import TablePro +import Testing + +@Suite("MCP HTTP Server Transport") +struct MCPHttpServerTransportTests { + private static let mcpVersion = "2024-11-05" + + private func makeTransport( + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock(), + sessionPolicy: MCPSessionPolicy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + ) -> (MCPHttpServerTransport, MCPSessionStore) { + let store = MCPSessionStore(policy: sessionPolicy, clock: clock) + let config = MCPHttpServerConfiguration.loopback(port: 0) + let transport = MCPHttpServerTransport( + configuration: config, + sessionStore: store, + authenticator: authenticator, + clock: clock + ) + return (transport, store) + } + + private func startedTransport( + authenticator: any MCPAuthenticator, + clock: any MCPClock = MCPSystemClock(), + sessionPolicy: MCPSessionPolicy = MCPSessionPolicy( + idleTimeout: .seconds(900), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + ) async throws -> (MCPHttpServerTransport, MCPSessionStore, UInt16) { + let (transport, store) = makeTransport( + authenticator: authenticator, + clock: clock, + sessionPolicy: sessionPolicy + ) + + let stateStream = transport.listenerState + let stateTask = Task { + for await state in stateStream { + if case .running(let port) = state { + return port + } + if case .failed = state { + return nil + } + } + return nil + } + + try await transport.start() + guard let port = await stateTask.value, port != 0 else { + await transport.stop() + throw TestError.serverDidNotStart + } + return (transport, store, port) + } + + private func makePost( + port: UInt16, + body: Data, + sessionId: String? = nil, + authorization: String? = "Bearer test-token", + contentType: String = "application/json" + ) -> URLRequest { + guard let url = URL(string: "http://127.0.0.1:\(port)/mcp") else { + fatalError("Failed to construct test URL") + } + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.httpBody = body + request.setValue(contentType, forHTTPHeaderField: "Content-Type") + request.setValue(Self.mcpVersion, forHTTPHeaderField: "mcp-protocol-version") + if let sessionId { + request.setValue(sessionId, forHTTPHeaderField: "Mcp-Session-Id") + } + if let authorization { + request.setValue(authorization, forHTTPHeaderField: "Authorization") + } + return request + } + + private func makeOptions(port: UInt16, origin: String? = "http://localhost") -> URLRequest { + guard let url = URL(string: "http://127.0.0.1:\(port)/mcp") else { + fatalError("Failed to construct test URL") + } + var request = URLRequest(url: url) + request.httpMethod = "OPTIONS" + request.setValue("Bearer test-token", forHTTPHeaderField: "Authorization") + if let origin { + request.setValue(origin, forHTTPHeaderField: "Origin") + } + return request + } + + private func makeRequestBody(method: String, id: Int = 1) throws -> Data { + let request = JsonRpcRequest(id: .number(Int64(id)), method: method, params: nil) + return try JsonRpcCodec.encode(.request(request)) + } + + private func parseJsonRpcError(_ data: Data) throws -> (id: JsonRpcId?, code: Int, message: String) { + let decoded = try JsonRpcCodec.decode(data) + guard case .errorResponse(let envelope) = decoded else { + throw TestError.expectedErrorEnvelope + } + return (envelope.id, envelope.error.code, envelope.error.message) + } + + private func runEchoLoop( + transport: MCPHttpServerTransport, + consumer: StubExchangeConsumer, + successResult: JsonValue = .object(["ok": .bool(true)]) + ) async { + await consumer.start(transport: transport) { exchange in + switch exchange.message { + case .request(let request): + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: request.id, result: successResult) + ) + await exchange.responder.respond(response, sessionId: exchange.context.sessionId) + case .notification: + await exchange.responder.acknowledgeAccepted() + default: + await exchange.responder.respondError(.invalidRequest(detail: "unsupported"), requestId: nil) + } + } + } + + @Test("Initialize creates session and returns Mcp-Session-Id header") + func initializeCreatesSession() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "initialize") + let request = makePost(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: request) + let httpResponse = try #require(response as? HTTPURLResponse) + + #expect(httpResponse.statusCode == 200) + #expect(httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") != nil) + + let decoded = try JsonRpcCodec.decode(data) + guard case .successResponse = decoded else { + Issue.record("Expected success response, got \(decoded)") + return + } + } + + @Test("Tool call with valid session returns 200 and session header") + func toolCallWithValidSession() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let initBody = try makeRequestBody(method: "initialize", id: 1) + let (_, initResponse) = try await URLSession.shared.data(for: makePost(port: port, body: initBody)) + let initHttp = try #require(initResponse as? HTTPURLResponse) + let sessionId = try #require(initHttp.value(forHTTPHeaderField: "Mcp-Session-Id")) + + let toolBody = try makeRequestBody(method: "tools/call", id: 2) + let (toolData, toolResponse) = try await URLSession.shared.data( + for: makePost(port: port, body: toolBody, sessionId: sessionId) + ) + let toolHttp = try #require(toolResponse as? HTTPURLResponse) + + #expect(toolHttp.statusCode == 200) + let decoded = try JsonRpcCodec.decode(toolData) + guard case .successResponse = decoded else { + Issue.record("Expected success response, got \(decoded)") + return + } + } + + @Test("Tool call without session id returns 400 with JSON-RPC error envelope") + func toolCallMissingSessionId() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "tools/call", id: 7) + let (data, response) = try await URLSession.shared.data(for: makePost(port: port, body: body)) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.invalidRequest) + } + + @Test("Tool call with stale session id returns 404 with JSON-RPC error envelope") + func toolCallStaleSession() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "tools/call", id: 8) + let (data, response) = try await URLSession.shared.data( + for: makePost(port: port, body: body, sessionId: "nonexistent-session-id") + ) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 404) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.sessionNotFound) + } + + @Test("Missing Authorization returns 401 with WWW-Authenticate") + func missingAuthorization() async throws { + let auth = StubBearerAuthenticator(validToken: "valid") + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "initialize", id: 1) + let request = makePost(port: port, body: body, authorization: nil) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 401) + let challenge = http.value(forHTTPHeaderField: "Www-Authenticate") ?? http.value(forHTTPHeaderField: "WWW-Authenticate") + #expect(challenge?.contains("Bearer") == true) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code != 0) + } + + @Test("Bad bearer token returns 401 with JSON-RPC error envelope") + func badBearerToken() async throws { + let auth = StubBearerAuthenticator(validToken: "valid") + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "initialize", id: 1) + let request = makePost(port: port, body: body, authorization: "Bearer wrong-token") + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 401) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code != 0) + } + + @Test("Rate limit kicks in after repeated bad attempts and includes Retry-After") + func rateLimitAfterBadAttempts() async throws { + let auth = StubBearerAuthenticator(validToken: "valid", maxAttempts: 3) + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let body = try makeRequestBody(method: "initialize", id: 1) + + for _ in 0..<3 { + let request = makePost(port: port, body: body, authorization: "Bearer wrong-token") + _ = try await URLSession.shared.data(for: request) + } + + let request = makePost(port: port, body: body, authorization: "Bearer wrong-token") + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 429) + let retryAfter = http.value(forHTTPHeaderField: "Retry-After") + #expect(retryAfter == "30") + let parsed = try parseJsonRpcError(data) + #expect(parsed.code != 0) + } + + @Test("Payload too large returns 413 with JSON-RPC error envelope") + func payloadTooLarge() async throws { + let auth = StubAlwaysAllowAuthenticator() + let limits = MCPHttpServerLimits( + maxRequestBodyBytes: 1_024, + maxHeaderBytes: 16 * 1_024, + connectionTimeout: .seconds(30) + ) + let store = MCPSessionStore() + let config = MCPHttpServerConfiguration.loopback(port: 0, limits: limits) + let transport = MCPHttpServerTransport( + configuration: config, + sessionStore: store, + authenticator: auth + ) + + let stateStream = transport.listenerState + let stateTask = Task { + for await state in stateStream { + if case .running(let port) = state { return port } + if case .failed = state { return nil } + } + return nil + } + try await transport.start() + let port = try #require(await stateTask.value) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let bigBody = Data(repeating: 0x41, count: 2_048) + let request = makePost(port: port, body: bigBody) + let (_, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + #expect(http.statusCode == 413) + } + + @Test("Method not found at unknown path returns 404 with JSON-RPC error envelope") + func unknownPathReturns404() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + guard let url = URL(string: "http://127.0.0.1:\(port)/foo") else { + Issue.record("Failed to construct URL") + return + } + var request = URLRequest(url: url) + request.httpMethod = "GET" + request.setValue("Bearer test", forHTTPHeaderField: "Authorization") + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 404) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.methodNotFound) + } + + @Test("OPTIONS request returns 204 with CORS headers reflecting allowed origin") + func optionsReturnsNoContent() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let request = makeOptions(port: port, origin: "http://localhost") + let (_, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 204) + let allowOrigin = http.value(forHTTPHeaderField: "Access-Control-Allow-Origin") + #expect(allowOrigin == "http://localhost") + let allowHeaders = http.value(forHTTPHeaderField: "Access-Control-Allow-Headers") + #expect(allowHeaders?.contains("Last-Event-ID") == true) + } + + @Test("OPTIONS request from disallowed origin omits CORS headers") + func optionsDisallowedOriginOmitsCors() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let request = makeOptions(port: port, origin: "https://evil.example.com") + let (_, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 204) + #expect(http.value(forHTTPHeaderField: "Access-Control-Allow-Origin") == nil) + } + + @Test("OPTIONS request without Origin header omits CORS headers") + func optionsWithoutOriginOmitsCors() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let request = makeOptions(port: port, origin: nil) + let (_, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 204) + #expect(http.value(forHTTPHeaderField: "Access-Control-Allow-Origin") == nil) + } + + @Test("Initialize with unsupported protocolVersion returns invalid_request error") + func initializeRejectsUnsupportedProtocolVersion() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + let store = MCPSessionStore() + let progressSink = NullProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler()], + sessionStore: store, + progressSink: progressSink + ) + await consumer.start(transport: transport) { exchange in + await dispatcher.dispatch(exchange) + } + defer { Task { await consumer.stop() } } + + let request = JsonRpcRequest( + id: .number(1), + method: "initialize", + params: .object(["protocolVersion": .string("1999-01-01")]) + ) + let body = try JsonRpcCodec.encode(.request(request)) + let httpRequest = makePost(port: port, body: body) + let (data, response) = try await URLSession.shared.data(for: httpRequest) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.invalidRequest) + } + + @Test("Subsequent request with mismatched MCP-Protocol-Version is rejected") + func mismatchedProtocolVersionHeaderRejected() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + let store = MCPSessionStore() + let progressSink = NullProgressSink() + let dispatcher = MCPProtocolDispatcher( + handlers: [InitializeHandler(), PingHandler()], + sessionStore: store, + progressSink: progressSink + ) + await consumer.start(transport: transport) { exchange in + await dispatcher.dispatch(exchange) + } + defer { Task { await consumer.stop() } } + + let initializeRequest = JsonRpcRequest( + id: .number(1), + method: "initialize", + params: .object(["protocolVersion": .string(InitializeHandler.supportedProtocolVersion)]) + ) + let initBody = try JsonRpcCodec.encode(.request(initializeRequest)) + let (_, initResponse) = try await URLSession.shared.data(for: makePost(port: port, body: initBody)) + let initHttp = try #require(initResponse as? HTTPURLResponse) + let sessionId = try #require(initHttp.value(forHTTPHeaderField: "Mcp-Session-Id")) + + let initialized = JsonRpcNotification(method: "notifications/initialized", params: nil) + let initializedBody = try JsonRpcCodec.encode(.notification(initialized)) + var initializedRequest = makePost(port: port, body: initializedBody, sessionId: sessionId) + _ = try await URLSession.shared.data(for: initializedRequest) + _ = initializedRequest + + let pingRequest = JsonRpcRequest(id: .number(2), method: "ping", params: nil) + let pingBody = try JsonRpcCodec.encode(.request(pingRequest)) + guard let url = URL(string: "http://127.0.0.1:\(port)/mcp") else { + Issue.record("Failed to construct URL") + return + } + var mismatched = URLRequest(url: url) + mismatched.httpMethod = "POST" + mismatched.httpBody = pingBody + mismatched.setValue("application/json", forHTTPHeaderField: "Content-Type") + mismatched.setValue("1999-01-01", forHTTPHeaderField: "mcp-protocol-version") + mismatched.setValue(sessionId, forHTTPHeaderField: "Mcp-Session-Id") + mismatched.setValue("Bearer test-token", forHTTPHeaderField: "Authorization") + let (data, response) = try await URLSession.shared.data(for: mismatched) + let http = try #require(response as? HTTPURLResponse) + + #expect(http.statusCode == 400) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.invalidRequest) + } + + @Test("GET /mcp opens an SSE stream that delivers server-sent notifications") + func getMcpStreamsServerNotifications() async throws { + let auth = StubAlwaysAllowAuthenticator() + let (transport, _, port) = try await startedTransport(authenticator: auth) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let initBody = try makeRequestBody(method: "initialize") + let (_, initResponse) = try await URLSession.shared.data(for: makePost(port: port, body: initBody)) + let initHttp = try #require(initResponse as? HTTPURLResponse) + let sessionId = try #require(initHttp.value(forHTTPHeaderField: "Mcp-Session-Id")) + + guard let url = URL(string: "http://127.0.0.1:\(port)/mcp") else { + Issue.record("Failed to construct URL") + return + } + var request = URLRequest(url: url) + request.httpMethod = "GET" + request.setValue("text/event-stream", forHTTPHeaderField: "Accept") + request.setValue(sessionId, forHTTPHeaderField: "Mcp-Session-Id") + request.setValue("Bearer test-token", forHTTPHeaderField: "Authorization") + request.timeoutInterval = 5 + + let session = URLSession(configuration: .ephemeral) + let streamTask = Task<(Int, String), Error> { + let (bytes, response) = try await session.bytes(for: request) + let httpResponse = response as? HTTPURLResponse + var collected = "" + for try await line in bytes.lines { + collected += line + "\n" + if collected.contains("notifications/test") { break } + } + return (httpResponse?.statusCode ?? 0, collected) + } + + try await Task.sleep(for: .milliseconds(200)) + + let notification = JsonRpcNotification( + method: "notifications/test", + params: .object(["progress": .double(0.5)]) + ) + await transport.sendNotification(notification, toSession: MCPSessionId(sessionId)) + + let (status, body) = try await streamTask.value + #expect(status == 200) + #expect(body.contains("notifications/test")) + session.invalidateAndCancel() + } + + @Test("Idle session eviction terminates SSE-tracked sessions") + func idleSessionEviction() async throws { + let clock = MCPTestClock(start: Date(timeIntervalSince1970: 1_000_000)) + let auth = StubAlwaysAllowAuthenticator() + let policy = MCPSessionPolicy( + idleTimeout: .seconds(60), + maxSessions: 16, + cleanupInterval: .seconds(60) + ) + let (transport, store, port) = try await startedTransport( + authenticator: auth, + clock: clock, + sessionPolicy: policy + ) + defer { Task { await transport.stop() } } + + let consumer = StubExchangeConsumer() + await runEchoLoop(transport: transport, consumer: consumer) + defer { Task { await consumer.stop() } } + + let initBody = try makeRequestBody(method: "initialize") + let (_, initResponse) = try await URLSession.shared.data(for: makePost(port: port, body: initBody)) + let initHttp = try #require(initResponse as? HTTPURLResponse) + let sessionId = try #require(initHttp.value(forHTTPHeaderField: "Mcp-Session-Id")) + + await clock.advance(by: .seconds(120)) + await store.runCleanupPass() + + let body = try makeRequestBody(method: "tools/call", id: 9) + let request = makePost(port: port, body: body, sessionId: sessionId) + let (data, response) = try await URLSession.shared.data(for: request) + let http = try #require(response as? HTTPURLResponse) + #expect(http.statusCode == 404) + let parsed = try parseJsonRpcError(data) + #expect(parsed.code == JsonRpcErrorCode.sessionNotFound) + } +} + +private enum TestError: Error { + case serverDidNotStart + case expectedErrorEnvelope +} diff --git a/TableProTests/Core/MCP/Transport/MCPProtocolErrorTests.swift b/TableProTests/Core/MCP/Transport/MCPProtocolErrorTests.swift new file mode 100644 index 000000000..79695a86c --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPProtocolErrorTests.swift @@ -0,0 +1,127 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPProtocolErrorTests: XCTestCase { + func testSessionNotFoundMapping() { + let error = MCPProtocolError.sessionNotFound() + XCTAssertEqual(error.code, JsonRpcErrorCode.sessionNotFound) + XCTAssertEqual(error.httpStatus, .notFound) + } + + func testMissingSessionIdMapping() { + let error = MCPProtocolError.missingSessionId() + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidRequest) + XCTAssertEqual(error.httpStatus, .badRequest) + } + + func testParseErrorMapping() { + let error = MCPProtocolError.parseError(detail: "bad json") + XCTAssertEqual(error.code, JsonRpcErrorCode.parseError) + XCTAssertEqual(error.httpStatus, .badRequest) + XCTAssertTrue(error.message.contains("bad json")) + } + + func testInvalidRequestMapping() { + let error = MCPProtocolError.invalidRequest(detail: "missing method") + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidRequest) + XCTAssertEqual(error.httpStatus, .badRequest) + } + + func testMethodNotFoundIsHttp200() { + let error = MCPProtocolError.methodNotFound(method: "tools/foo") + XCTAssertEqual(error.code, JsonRpcErrorCode.methodNotFound) + XCTAssertEqual(error.httpStatus, .ok) + } + + func testInvalidParamsIsHttp200() { + let error = MCPProtocolError.invalidParams(detail: "expected object") + XCTAssertEqual(error.code, JsonRpcErrorCode.invalidParams) + XCTAssertEqual(error.httpStatus, .ok) + } + + func testInternalErrorMapping() { + let error = MCPProtocolError.internalError(detail: "boom") + XCTAssertEqual(error.code, JsonRpcErrorCode.internalError) + XCTAssertEqual(error.httpStatus, .internalServerError) + } + + func testUnauthenticatedIncludesWwwAuthenticate() { + let error = MCPProtocolError.unauthenticated(challenge: "Bearer realm=\"x\"") + XCTAssertEqual(error.code, JsonRpcErrorCode.unauthenticated) + XCTAssertEqual(error.httpStatus, .unauthorized) + let header = error.extraHeaders.first { $0.0.lowercased() == "www-authenticate" } + XCTAssertNotNil(header) + XCTAssertEqual(header?.1, "Bearer realm=\"x\"") + } + + func testTokenInvalidIncludesWwwAuthenticate() { + let error = MCPProtocolError.tokenInvalid() + XCTAssertEqual(error.httpStatus, .unauthorized) + XCTAssertTrue(error.extraHeaders.contains { $0.0.lowercased() == "www-authenticate" }) + } + + func testTokenExpiredIncludesWwwAuthenticate() { + let error = MCPProtocolError.tokenExpired() + XCTAssertEqual(error.httpStatus, .unauthorized) + XCTAssertTrue(error.extraHeaders.contains { $0.0.lowercased() == "www-authenticate" }) + } + + func testForbiddenMapping() { + let error = MCPProtocolError.forbidden(reason: "policy") + XCTAssertEqual(error.code, JsonRpcErrorCode.forbidden) + XCTAssertEqual(error.httpStatus, .forbidden) + } + + func testRateLimitedMapping() { + let error = MCPProtocolError.rateLimited() + XCTAssertEqual(error.httpStatus, .tooManyRequests) + } + + func testPayloadTooLargeMapping() { + let error = MCPProtocolError.payloadTooLarge() + XCTAssertEqual(error.code, JsonRpcErrorCode.tooLarge) + XCTAssertEqual(error.httpStatus, .payloadTooLarge) + } + + func testNotAcceptableMapping() { + let error = MCPProtocolError.notAcceptable() + XCTAssertEqual(error.httpStatus, .notAcceptable) + } + + func testUnsupportedMediaTypeMapping() { + let error = MCPProtocolError.unsupportedMediaType() + XCTAssertEqual(error.httpStatus, .unsupportedMediaType) + } + + func testServiceUnavailableMapping() { + let error = MCPProtocolError.serviceUnavailable() + XCTAssertEqual(error.httpStatus, .serviceUnavailable) + } + + func testToJsonRpcErrorResponseRoundTrip() { + let protocolError = MCPProtocolError.sessionNotFound() + let response = protocolError.toJsonRpcErrorResponse(id: .number(7)) + XCTAssertEqual(response.id, .number(7)) + XCTAssertEqual(response.error.code, JsonRpcErrorCode.sessionNotFound) + XCTAssertEqual(response.error.message, "Session not found") + } + + func testToJsonRpcErrorResponseWithNilId() { + let protocolError = MCPProtocolError.parseError(detail: "x") + let response = protocolError.toJsonRpcErrorResponse(id: nil) + XCTAssertNil(response.id) + XCTAssertEqual(response.error.code, JsonRpcErrorCode.parseError) + } + + func testEqualityIgnoresHeadersAndStatus() { + let lhs = MCPProtocolError(code: -1, message: "x", httpStatus: .ok) + let rhs = MCPProtocolError( + code: -1, + message: "x", + httpStatus: .badRequest, + extraHeaders: [("X", "Y")] + ) + XCTAssertEqual(lhs, rhs) + } +} diff --git a/TableProTests/Core/MCP/Transport/MCPStdioMessageTransportTests.swift b/TableProTests/Core/MCP/Transport/MCPStdioMessageTransportTests.swift new file mode 100644 index 000000000..dda06aba2 --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPStdioMessageTransportTests.swift @@ -0,0 +1,181 @@ +import Foundation +@testable import TablePro +import XCTest + +final class MCPStdioMessageTransportTests: XCTestCase { + private var stdinPipe: Pipe! + private var stdoutPipe: Pipe! + private var logger: FakeBridgeLogger! + + override func setUp() { + super.setUp() + stdinPipe = Pipe() + stdoutPipe = Pipe() + logger = FakeBridgeLogger() + } + + override func tearDown() { + stdinPipe = nil + stdoutPipe = nil + logger = nil + super.tearDown() + } + + func testReceivesValidLine() async throws { + let transport = makeTransport() + + let message = JsonRpcMessage.request( + JsonRpcRequest(id: .number(1), method: "ping", params: nil) + ) + let line = try JsonRpcCodec.encodeLine(message) + stdinPipe.fileHandleForWriting.write(line) + + let received = try await firstInbound(transport: transport) + XCTAssertEqual(received, message) + await transport.close() + } + + func testSkipsMalformedLineAndContinues() async throws { + let transport = makeTransport() + + stdinPipe.fileHandleForWriting.write(Data("not json at all\n".utf8)) + + let valid = JsonRpcMessage.notification( + JsonRpcNotification(method: "notifications/initialized", params: nil) + ) + try stdinPipe.fileHandleForWriting.write(contentsOf: try JsonRpcCodec.encodeLine(valid)) + + let received = try await firstInbound(transport: transport) + XCTAssertEqual(received, valid) + XCTAssertTrue(logger.entries.contains { $0.level == .warning && $0.message.contains("malformed") }) + await transport.close() + } + + func testHandlesBytesSplitAcrossWrites() async throws { + let transport = makeTransport() + + let message = JsonRpcMessage.request( + JsonRpcRequest(id: .number(42), method: "tools/list", params: nil) + ) + let line = try JsonRpcCodec.encodeLine(message) + let half = line.count / 2 + stdinPipe.fileHandleForWriting.write(Data(line.prefix(half))) + try await Task.sleep(nanoseconds: 50_000_000) + stdinPipe.fileHandleForWriting.write(Data(line.suffix(from: half))) + + let received = try await firstInbound(transport: transport) + XCTAssertEqual(received, message) + await transport.close() + } + + func testSendWritesValidJsonRpcLineToStdout() async throws { + let transport = makeTransport() + + let message = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: .number(3), result: .object(["ok": .bool(true)])) + ) + try await transport.send(message) + + try await Task.sleep(nanoseconds: 50_000_000) + + let written = stdoutPipe.fileHandleForReading.availableData + XCTAssertFalse(written.isEmpty) + XCTAssertEqual(written.last, 0x0A) + let trimmed = written.dropLast() + let decoded = try JsonRpcCodec.decode(trimmed) + XCTAssertEqual(decoded, message) + + await transport.close() + } + + func testInboundFinishesOnEof() async throws { + let transport = makeTransport() + + try stdinPipe.fileHandleForWriting.close() + + var iterator = transport.inbound.makeAsyncIterator() + let value = try await iterator.next() + XCTAssertNil(value) + + await transport.close() + } + + func testCloseIsIdempotent() async { + let transport = makeTransport() + await transport.close() + await transport.close() + } + + func testSendAfterCloseThrows() async { + let transport = makeTransport() + await transport.close() + + let message = JsonRpcMessage.notification( + JsonRpcNotification(method: "ping", params: nil) + ) + do { + try await transport.send(message) + XCTFail("Expected throw") + } catch let error as MCPTransportError { + XCTAssertEqual(error, .closed) + } catch { + XCTFail("Unexpected error \(error)") + } + } + + private func makeTransport() -> MCPStdioMessageTransport { + MCPStdioMessageTransport( + stdin: stdinPipe.fileHandleForReading, + stdout: stdoutPipe.fileHandleForWriting, + errorLogger: logger + ) + } + + private func firstInbound( + transport: MCPStdioMessageTransport, + timeout: TimeInterval = 2.0 + ) async throws -> JsonRpcMessage { + try await withThrowingTaskGroup(of: JsonRpcMessage?.self) { group in + group.addTask { + var iterator = transport.inbound.makeAsyncIterator() + return try await iterator.next() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + return nil + } + guard let result = try await group.next(), let value = result else { + group.cancelAll() + throw TestError.timeout + } + group.cancelAll() + return value + } + } +} + +private enum TestError: Error { + case timeout +} + +private final class FakeBridgeLogger: MCPBridgeLogger, @unchecked Sendable { + struct Entry { + let level: MCPBridgeLogLevel + let message: String + } + + private let lock = NSLock() + private var storage: [Entry] = [] + + var entries: [Entry] { + lock.lock() + defer { lock.unlock() } + return storage + } + + func log(_ level: MCPBridgeLogLevel, _ message: String) { + lock.lock() + defer { lock.unlock() } + storage.append(Entry(level: level, message: message)) + } +} diff --git a/TableProTests/Core/MCP/Transport/MCPStreamableHttpClientTransportTests.swift b/TableProTests/Core/MCP/Transport/MCPStreamableHttpClientTransportTests.swift new file mode 100644 index 000000000..e37991522 --- /dev/null +++ b/TableProTests/Core/MCP/Transport/MCPStreamableHttpClientTransportTests.swift @@ -0,0 +1,528 @@ +import Foundation +import Network +@testable import TablePro +import XCTest + +final class MCPStreamableHttpClientTransportTests: XCTestCase { + private var server: MockHttpServer! + + override func setUp() async throws { + try await super.setUp() + server = MockHttpServer() + try await server.start() + } + + override func tearDown() async throws { + await server.stop() + server = nil + try await super.tearDown() + } + + func testJsonResponseArrivesOnInbound() async throws { + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: .number(1), result: .object(["ok": .bool(true)])) + ) + let body = try JsonRpcCodec.encode(response) + await server.setResponder { _ in + MockHttpResponse(status: 200, headers: [("Content-Type", "application/json")], body: body) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(1), method: "ping", params: nil) + ) + try await transport.send(request) + + let received = try await firstInbound(transport: transport) + XCTAssertEqual(received, response) + await transport.close() + } + + func testSseResponseDeliversFramesIncrementally() async throws { + let frame1 = JsonRpcMessage.notification( + JsonRpcNotification(method: "notifications/progress", params: .object(["progress": .int(50)])) + ) + let frame2 = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: .number(2), result: .object(["done": .bool(true)])) + ) + let payload1 = try JsonRpcCodec.encode(frame1) + let payload2 = try JsonRpcCodec.encode(frame2) + let body1 = "data: \(String(data: payload1, encoding: .utf8) ?? "")\n\n" + let body2 = "data: \(String(data: payload2, encoding: .utf8) ?? "")\n\n" + + await server.setResponder { _ in + MockHttpResponse( + status: 200, + headers: [("Content-Type", "text/event-stream")], + body: Data((body1 + body2).utf8) + ) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(2), method: "tools/run", params: nil) + ) + try await transport.send(request) + + let received = try await collectInbound(transport: transport, count: 2) + XCTAssertEqual(received[0], frame1) + XCTAssertEqual(received[1], frame2) + await transport.close() + } + + func testHttp404SynthesizesSessionNotFoundError() async throws { + await server.setResponder { _ in + MockHttpResponse( + status: 404, + headers: [("Content-Type", "text/plain")], + body: Data("Session not found".utf8) + ) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(7), method: "tools/list", params: nil) + ) + try await transport.send(request) + + let received = try await firstInbound(transport: transport) + guard case .errorResponse(let response) = received else { + XCTFail("Expected errorResponse, got \(received)") + return + } + XCTAssertEqual(response.id, .number(7)) + XCTAssertEqual(response.error.code, JsonRpcErrorCode.sessionNotFound) + await transport.close() + } + + func testHttp401IncludesUnauthenticatedError() async throws { + await server.setResponder { _ in + MockHttpResponse( + status: 401, + headers: [ + ("Content-Type", "text/plain"), + ("WWW-Authenticate", "Bearer realm=\"TablePro\"") + ], + body: Data("Unauthenticated".utf8) + ) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(99), method: "tools/list", params: nil) + ) + try await transport.send(request) + + let received = try await firstInbound(transport: transport) + guard case .errorResponse(let response) = received else { + XCTFail("Expected errorResponse, got \(received)") + return + } + XCTAssertEqual(response.id, .number(99)) + XCTAssertEqual(response.error.code, JsonRpcErrorCode.unauthenticated) + XCTAssertEqual(response.error.message, "Unauthenticated") + await transport.close() + } + + func testHttp500ProducesInternalError() async throws { + await server.setResponder { _ in + MockHttpResponse( + status: 500, + headers: [("Content-Type", "text/plain")], + body: Data("kaboom".utf8) + ) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(5), method: "x", params: nil) + ) + try await transport.send(request) + + let received = try await firstInbound(transport: transport) + guard case .errorResponse(let response) = received else { + XCTFail("Expected errorResponse, got \(received)") + return + } + XCTAssertEqual(response.id, .number(5)) + XCTAssertEqual(response.error.code, JsonRpcErrorCode.internalError) + await transport.close() + } + + func testServerEmittedJsonRpcErrorIsForwarded() async throws { + let serverError = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse( + id: .number(8), + error: JsonRpcError(code: -32_007, message: "policy denied") + ) + ) + let body = try JsonRpcCodec.encode(serverError) + await server.setResponder { _ in + MockHttpResponse( + status: 403, + headers: [("Content-Type", "application/json")], + body: body + ) + } + + let transport = makeTransport() + let request = JsonRpcMessage.request( + JsonRpcRequest(id: .number(8), method: "x", params: nil) + ) + try await transport.send(request) + + let received = try await firstInbound(transport: transport) + guard case .errorResponse(let response) = received else { + XCTFail("Expected errorResponse") + return + } + XCTAssertEqual(response.error.code, -32_007) + XCTAssertEqual(response.error.message, "policy denied") + await transport.close() + } + + func testCapturesSessionIdFromResponse() async throws { + let response = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse(id: .number(1), result: .object(["ok": .bool(true)])) + ) + let body = try JsonRpcCodec.encode(response) + await server.setResponder { _ in + MockHttpResponse( + status: 200, + headers: [ + ("Content-Type", "application/json"), + ("Mcp-Session-Id", "session-xyz") + ], + body: body + ) + } + + let transport = makeTransport() + try await transport.send(JsonRpcMessage.request( + JsonRpcRequest(id: .number(1), method: "initialize", params: nil) + )) + _ = try await firstInbound(transport: transport) + + await server.setResponder { received in + let sessionHeader = received.headers.first { $0.0.lowercased() == "mcp-session-id" }?.1 + let resultBody = try? JsonRpcCodec.encode(.successResponse( + JsonRpcSuccessResponse( + id: .number(2), + result: .object(["session": .string(sessionHeader ?? "")]) + ) + )) + return MockHttpResponse( + status: 200, + headers: [("Content-Type", "application/json")], + body: resultBody ?? Data() + ) + } + + try await transport.send(JsonRpcMessage.request( + JsonRpcRequest(id: .number(2), method: "tools/list", params: nil) + )) + let second = try await firstInbound(transport: transport) + guard case .successResponse(let success) = second else { + XCTFail("Expected successResponse") + return + } + XCTAssertEqual(success.result["session"]?.stringValue, "session-xyz") + + await transport.close() + } + + private func makeTransport() -> MCPStreamableHttpClientTransport { + let url = URL(string: "http://127.0.0.1:\(server.port)/mcp")! + let configuration = MCPStreamableHttpClientConfiguration( + endpoint: url, + bearerToken: "test-token", + tlsCertFingerprint: nil, + requestTimeout: .seconds(5), + serverInitiatedStream: false + ) + return MCPStreamableHttpClientTransport(configuration: configuration, errorLogger: nil) + } + + private func firstInbound( + transport: MCPStreamableHttpClientTransport, + timeout: TimeInterval = 3.0 + ) async throws -> JsonRpcMessage { + try await withThrowingTaskGroup(of: JsonRpcMessage?.self) { group in + group.addTask { + var iterator = transport.inbound.makeAsyncIterator() + return try await iterator.next() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + return nil + } + guard let result = try await group.next(), let value = result else { + group.cancelAll() + throw TransportTestError.timeout + } + group.cancelAll() + return value + } + } + + private func collectInbound( + transport: MCPStreamableHttpClientTransport, + count: Int, + timeout: TimeInterval = 3.0 + ) async throws -> [JsonRpcMessage] { + try await withThrowingTaskGroup(of: [JsonRpcMessage]?.self) { group in + group.addTask { + var iterator = transport.inbound.makeAsyncIterator() + var collected: [JsonRpcMessage] = [] + while collected.count < count { + guard let next = try await iterator.next() else { break } + collected.append(next) + } + return collected + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(timeout * 1_000_000_000)) + return nil + } + guard let result = try await group.next(), let value = result else { + group.cancelAll() + throw TransportTestError.timeout + } + group.cancelAll() + return value + } + } +} + +private enum TransportTestError: Error { + case timeout +} + +private struct MockHttpRequest: Sendable { + let method: String + let path: String + let headers: [(String, String)] + let body: Data +} + +private struct MockHttpResponse: Sendable { + let status: Int + let headers: [(String, String)] + let body: Data +} + +private actor MockServerState { + var responder: (@Sendable (MockHttpRequest) -> MockHttpResponse)? + + func setResponder(_ responder: @escaping @Sendable (MockHttpRequest) -> MockHttpResponse) { + self.responder = responder + } + + func respond(to request: MockHttpRequest) -> MockHttpResponse { + if let responder { + return responder(request) + } + return MockHttpResponse( + status: 500, + headers: [("Content-Type", "text/plain")], + body: Data("no responder".utf8) + ) + } +} + +private final class MockHttpServer: @unchecked Sendable { + private var listener: NWListener? + private let state = MockServerState() + private let lock = NSLock() + private var assignedPort: UInt16 = 0 + private var connections: [NWConnection] = [] + + var port: UInt16 { + lock.lock() + defer { lock.unlock() } + return assignedPort + } + + func setResponder(_ responder: @escaping @Sendable (MockHttpRequest) -> MockHttpResponse) async { + await state.setResponder(responder) + } + + func start() async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + do { + let params = NWParameters.tcp + params.allowLocalEndpointReuse = true + let listener = try NWListener(using: params) + lock.lock() + self.listener = listener + lock.unlock() + + let port = self.port + _ = port + + listener.stateUpdateHandler = { [weak self] state in + guard let self else { return } + switch state { + case .ready: + if let port = listener.port?.rawValue { + self.lock.lock() + self.assignedPort = port + self.lock.unlock() + } + continuation.resume() + case .failed(let error): + continuation.resume(throwing: error) + default: + break + } + } + listener.newConnectionHandler = { [weak self] connection in + self?.handle(connection) + } + listener.start(queue: .global(qos: .userInitiated)) + } catch { + continuation.resume(throwing: error) + } + } + } + + func stop() async { + lock.lock() + let listener = self.listener + let connections = self.connections + self.listener = nil + self.connections = [] + lock.unlock() + listener?.cancel() + for connection in connections { + connection.cancel() + } + } + + private func handle(_ connection: NWConnection) { + lock.lock() + connections.append(connection) + lock.unlock() + + connection.stateUpdateHandler = { [weak self] state in + switch state { + case .ready: + self?.readRequest(connection: connection, accumulated: Data()) + case .failed, .cancelled: + break + default: + break + } + } + connection.start(queue: .global(qos: .userInitiated)) + } + + private func readRequest(connection: NWConnection, accumulated: Data) { + connection.receive(minimumIncompleteLength: 1, maximumLength: 64 * 1024) { [weak self] data, _, isComplete, error in + guard let self else { return } + if let error { + _ = error + connection.cancel() + return + } + var buffer = accumulated + if let data { + buffer.append(data) + } + + if let request = Self.parseRequest(buffer) { + Task { + let response = await self.state.respond(to: request) + let raw = Self.serializeResponse(response) + connection.send(content: raw, completion: .contentProcessed { _ in + connection.cancel() + }) + } + return + } + + if isComplete { + connection.cancel() + return + } + self.readRequest(connection: connection, accumulated: buffer) + } + } + + private static func parseRequest(_ data: Data) -> MockHttpRequest? { + guard let separatorRange = data.range(of: Data("\r\n\r\n".utf8)) else { + return nil + } + let headerData = data[..= 3 else { return nil } + let method = String(parts[0]) + let path = String(parts[1]) + + var headers: [(String, String)] = [] + for line in lines.dropFirst() where !line.isEmpty { + guard let colon = line.firstIndex(of: ":") else { continue } + let key = String(line[line.startIndex.. 0 { + let remaining = data.count - bodyStart + if remaining < contentLength { + return nil + } + body = data.subdata(in: bodyStart..<(bodyStart + contentLength)) + } else { + body = Data() + } + + return MockHttpRequest(method: method, path: path, headers: headers, body: body) + } + + private static func serializeResponse(_ response: MockHttpResponse) -> Data { + var output = "HTTP/1.1 \(response.status) \(reasonPhrase(for: response.status))\r\n" + var headers = response.headers + if !headers.contains(where: { $0.0.lowercased() == "content-length" }) { + headers.append(("Content-Length", "\(response.body.count)")) + } + if !headers.contains(where: { $0.0.lowercased() == "connection" }) { + headers.append(("Connection", "close")) + } + for (key, value) in headers { + output.append("\(key): \(value)\r\n") + } + output.append("\r\n") + var data = Data(output.utf8) + data.append(response.body) + return data + } + + private static func reasonPhrase(for status: Int) -> String { + switch status { + case 200: return "OK" + case 401: return "Unauthorized" + case 403: return "Forbidden" + case 404: return "Not Found" + case 500: return "Internal Server Error" + default: return "Status" + } + } +} diff --git a/TableProTests/Core/MCP/Wire/HttpRequestParserTests.swift b/TableProTests/Core/MCP/Wire/HttpRequestParserTests.swift new file mode 100644 index 000000000..435646df8 --- /dev/null +++ b/TableProTests/Core/MCP/Wire/HttpRequestParserTests.swift @@ -0,0 +1,147 @@ +import Foundation +@testable import TablePro +import XCTest + +final class HttpRequestParserTests: XCTestCase { + func testParsesSimpleGetRequest() throws { + let raw = "GET /index HTTP/1.1\r\nHost: example.com\r\n\r\n" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(let head, let body, let consumed) = result else { + XCTFail("Expected complete, got \(result)") + return + } + XCTAssertEqual(head.method, .get) + XCTAssertEqual(head.path, "/index") + XCTAssertEqual(head.httpVersion, "HTTP/1.1") + XCTAssertEqual(head.headers.value(for: "Host"), "example.com") + XCTAssertEqual(body, Data()) + XCTAssertEqual(consumed, raw.utf8.count) + } + + func testCaseInsensitiveHeaderLookup() throws { + let raw = "GET / HTTP/1.1\r\nContent-Type: text/plain\r\n\r\n" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(let head, _, _) = result else { + XCTFail("Expected complete") + return + } + XCTAssertEqual(head.headers.value(for: "content-type"), "text/plain") + XCTAssertEqual(head.headers.value(for: "CONTENT-TYPE"), "text/plain") + } + + func testMcpSessionIdLookupCaseInsensitive() throws { + let lowercaseRaw = "GET / HTTP/1.1\r\nmcp-session-id: abc-123\r\n\r\n" + let lowercaseResult = try HttpRequestParser.parse(Data(lowercaseRaw.utf8)) + guard case .complete(let lowerHead, _, _) = lowercaseResult else { + XCTFail("Expected complete for lowercase") + return + } + XCTAssertEqual(lowerHead.headers.value(for: "Mcp-Session-Id"), "abc-123") + + let uppercaseRaw = "GET / HTTP/1.1\r\nMCP-SESSION-ID: xyz-789\r\n\r\n" + let uppercaseResult = try HttpRequestParser.parse(Data(uppercaseRaw.utf8)) + guard case .complete(let upperHead, _, _) = uppercaseResult else { + XCTFail("Expected complete for uppercase") + return + } + XCTAssertEqual(upperHead.headers.value(for: "Mcp-Session-Id"), "xyz-789") + } + + func testParsesPostBodyOfExactContentLength() throws { + let body = "{\"x\":1}" + let raw = "POST /rpc HTTP/1.1\r\nHost: x\r\nContent-Length: \(body.utf8.count)\r\n\r\n\(body)" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(let head, let parsedBody, let consumed) = result else { + XCTFail("Expected complete") + return + } + XCTAssertEqual(head.method, .post) + XCTAssertEqual(parsedBody, Data(body.utf8)) + XCTAssertEqual(consumed, raw.utf8.count) + } + + func testReportsExtraBytesAfterBodyViaConsumedBytes() throws { + let body = "abc" + let raw = "POST / HTTP/1.1\r\nHost: x\r\nContent-Length: 3\r\n\r\n\(body)REMAINDER" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(_, let parsedBody, let consumed) = result else { + XCTFail("Expected complete") + return + } + XCTAssertEqual(parsedBody, Data(body.utf8)) + let expectedConsumed = raw.utf8.count - "REMAINDER".utf8.count + XCTAssertEqual(consumed, expectedConsumed) + } + + func testIncompleteWhenHeadersNotFinished() throws { + let raw = "GET / HTTP/1.1\r\nHost: x" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + XCTAssertEqual(result, .incomplete) + } + + func testIncompleteWhenBodyShorterThanContentLength() throws { + let raw = "POST / HTTP/1.1\r\nHost: x\r\nContent-Length: 10\r\n\r\nshort" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + XCTAssertEqual(result, .incomplete) + } + + func testRejectsBareLfAsTerminator() { + let raw = "GET / HTTP/1.1\nHost: x\n\n" + XCTAssertThrowsError(try HttpRequestParser.parse(Data(raw.utf8))) { error in + XCTAssertEqual(error as? HttpRequestParseError, .nonStrictLineEndings) + } + } + + func testRejectsBareLfInHeaderLine() { + let raw = "GET / HTTP/1.1\r\nBad: value\nHost: x\r\n\r\n" + XCTAssertThrowsError(try HttpRequestParser.parse(Data(raw.utf8))) { error in + XCTAssertEqual(error as? HttpRequestParseError, .nonStrictLineEndings) + } + } + + func testRejectsHeaderTooLarge() { + let bigHeaderValue = String(repeating: "a", count: 17 * 1_024) + let raw = "GET / HTTP/1.1\r\nX-Big: \(bigHeaderValue)\r\n\r\n" + XCTAssertThrowsError(try HttpRequestParser.parse(Data(raw.utf8))) { error in + XCTAssertEqual(error as? HttpRequestParseError, .headerTooLarge) + } + } + + func testRejectsHeaderTooLargeWithoutTerminator() { + let huge = String(repeating: "X-Pad: pad\r\n", count: 2_000) + let raw = "GET / HTTP/1.1\r\n\(huge)" + XCTAssertThrowsError(try HttpRequestParser.parse(Data(raw.utf8))) { error in + XCTAssertEqual(error as? HttpRequestParseError, .headerTooLarge) + } + } + + func testUnknownMethodMappedToOther() throws { + let raw = "PROPFIND / HTTP/1.1\r\nHost: x\r\n\r\n" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(let head, _, _) = result else { + XCTFail("Expected complete") + return + } + XCTAssertEqual(head.method, .other("PROPFIND")) + } + + func testRejectsBodyOverLimit() { + let raw = "POST / HTTP/1.1\r\nHost: x\r\nContent-Length: 99999999\r\n\r\n" + XCTAssertThrowsError(try HttpRequestParser.parse(Data(raw.utf8))) { error in + guard case HttpRequestParseError.bodyTooLarge = error else { + XCTFail("Expected bodyTooLarge") + return + } + } + } + + func testPathPreservedVerbatim() throws { + let raw = "GET /path%20with%20spaces?x=1 HTTP/1.1\r\nHost: x\r\n\r\n" + let result = try HttpRequestParser.parse(Data(raw.utf8)) + guard case .complete(let head, _, _) = result else { + XCTFail("Expected complete") + return + } + XCTAssertEqual(head.path, "/path%20with%20spaces?x=1") + } +} diff --git a/TableProTests/Core/MCP/Wire/JsonRpcIdTests.swift b/TableProTests/Core/MCP/Wire/JsonRpcIdTests.swift new file mode 100644 index 000000000..0e24d42a1 --- /dev/null +++ b/TableProTests/Core/MCP/Wire/JsonRpcIdTests.swift @@ -0,0 +1,60 @@ +import Foundation +@testable import TablePro +import XCTest + +final class JsonRpcIdTests: XCTestCase { + func testNullRoundTrip() throws { + let id: JsonRpcId = .null + let data = try JSONEncoder().encode(id) + let decoded = try JSONDecoder().decode(JsonRpcId.self, from: data) + XCTAssertEqual(decoded, .null) + } + + func testNullEncodesAsJsonNull() throws { + let id: JsonRpcId = .null + let data = try JSONEncoder().encode(id) + XCTAssertEqual(String(data: data, encoding: .utf8), "null") + } + + func testStringRoundTrip() throws { + let id: JsonRpcId = .string("abc-123") + let data = try JSONEncoder().encode(id) + let decoded = try JSONDecoder().decode(JsonRpcId.self, from: data) + XCTAssertEqual(decoded, .string("abc-123")) + } + + func testNumberRoundTrip() throws { + let id: JsonRpcId = .number(42) + let data = try JSONEncoder().encode(id) + let decoded = try JSONDecoder().decode(JsonRpcId.self, from: data) + XCTAssertEqual(decoded, .number(42)) + } + + func testLargeNumberRoundTrip() throws { + let id: JsonRpcId = .number(Int64.max) + let data = try JSONEncoder().encode(id) + let decoded = try JSONDecoder().decode(JsonRpcId.self, from: data) + XCTAssertEqual(decoded, .number(Int64.max)) + } + + func testDecodeJsonNullProducesNullCase() throws { + let raw = Data("null".utf8) + let decoded = try JSONDecoder().decode(JsonRpcId.self, from: raw) + XCTAssertEqual(decoded, .null) + } + + func testDecodeBoolThrows() { + let raw = Data("true".utf8) + XCTAssertThrowsError(try JSONDecoder().decode(JsonRpcId.self, from: raw)) + } + + func testDecodeArrayThrows() { + let raw = Data("[1,2]".utf8) + XCTAssertThrowsError(try JSONDecoder().decode(JsonRpcId.self, from: raw)) + } + + func testDecodeObjectThrows() { + let raw = Data("{}".utf8) + XCTAssertThrowsError(try JSONDecoder().decode(JsonRpcId.self, from: raw)) + } +} diff --git a/TableProTests/Core/MCP/Wire/JsonRpcMessageTests.swift b/TableProTests/Core/MCP/Wire/JsonRpcMessageTests.swift new file mode 100644 index 000000000..02bc4e900 --- /dev/null +++ b/TableProTests/Core/MCP/Wire/JsonRpcMessageTests.swift @@ -0,0 +1,206 @@ +import Foundation +@testable import TablePro +import XCTest + +final class JsonRpcMessageTests: XCTestCase { + func testRequestRoundTrip() throws { + let message = JsonRpcMessage.request( + JsonRpcRequest( + id: .number(1), + method: "tools/list", + params: .object(["cursor": .string("abc")]) + ) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testRequestWithoutParamsRoundTrip() throws { + let message = JsonRpcMessage.request( + JsonRpcRequest(id: .string("req-1"), method: "ping", params: nil) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + + let json = try XCTUnwrap(String(data: data, encoding: .utf8)) + XCTAssertFalse(json.contains("\"params\"")) + } + + func testNotificationRoundTrip() throws { + let message = JsonRpcMessage.notification( + JsonRpcNotification(method: "notifications/initialized", params: nil) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + + let json = try XCTUnwrap(String(data: data, encoding: .utf8)) + XCTAssertFalse(json.contains("\"id\"")) + XCTAssertFalse(json.contains("\"params\"")) + } + + func testNotificationWithParamsRoundTrip() throws { + let message = JsonRpcMessage.notification( + JsonRpcNotification( + method: "notifications/progress", + params: .object(["progress": .int(50)]) + ) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testSuccessResponseRoundTrip() throws { + let message = JsonRpcMessage.successResponse( + JsonRpcSuccessResponse( + id: .number(7), + result: .object(["tools": .array([])]) + ) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testErrorResponseRoundTrip() throws { + let message = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse( + id: .number(8), + error: JsonRpcError.methodNotFound(message: "not here") + ) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testErrorResponseWithNullIdEncodesAsJsonNull() throws { + let message = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse(id: nil, error: JsonRpcError.parseError()) + ) + let data = try JsonRpcCodec.encode(message) + let json = try XCTUnwrap(String(data: data, encoding: .utf8)) + XCTAssertTrue(json.contains("\"id\":null")) + } + + func testErrorResponseWithExplicitNullIdRoundTrips() throws { + let message = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse(id: .null, error: JsonRpcError.serverError()) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + if case .errorResponse(let response) = decoded { + XCTAssertEqual(response.id, .null) + } else { + XCTFail("Expected errorResponse") + } + } + + func testErrorResponseDataRoundTrip() throws { + let message = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse( + id: .number(9), + error: JsonRpcError( + code: JsonRpcErrorCode.forbidden, + message: "no access", + data: .object(["reason": .string("policy")]) + ) + ) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testErrorResponseWithoutDataOmitsField() throws { + let message = JsonRpcMessage.errorResponse( + JsonRpcErrorResponse( + id: .number(10), + error: JsonRpcError.methodNotFound() + ) + ) + let data = try JsonRpcCodec.encode(message) + let json = try XCTUnwrap(String(data: data, encoding: .utf8)) + XCTAssertFalse(json.contains("\"data\"")) + } + + func testRejectsNon20JsonRpcVersion() { + let raw = Data(#"{"jsonrpc":"1.0","id":1,"method":"ping"}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + guard case JsonRpcDecodingError.invalidJsonRpcVersion(let value) = error else { + XCTFail("Expected invalidJsonRpcVersion, got \(error)") + return + } + XCTAssertEqual(value, "1.0") + } + } + + func testRejectsMissingJsonRpcVersion() { + let raw = Data(#"{"id":1,"method":"ping"}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .missingJsonRpcVersion) + } + } + + func testRejectsBatchArray() { + let raw = Data(#"[{"jsonrpc":"2.0","id":1,"method":"ping"}]"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .batchUnsupported) + } + } + + func testRejectsBatchArrayWithLeadingWhitespace() { + let raw = Data(" \n[{\"jsonrpc\":\"2.0\"}]".utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .batchUnsupported) + } + } + + func testEncodeLineAppendsNewline() throws { + let message = JsonRpcMessage.notification( + JsonRpcNotification(method: "ping", params: nil) + ) + let data = try JsonRpcCodec.encodeLine(message) + XCTAssertEqual(data.last, 0x0A) + } + + func testNullIdInRequestRoundTrips() throws { + let message = JsonRpcMessage.request( + JsonRpcRequest(id: .null, method: "test", params: nil) + ) + let data = try JsonRpcCodec.encode(message) + let decoded = try JsonRpcCodec.decode(data) + XCTAssertEqual(decoded, message) + } + + func testRejectsAmbiguousMessageWithMethodAndResult() { + let raw = Data(#"{"jsonrpc":"2.0","id":1,"method":"foo","result":1}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .ambiguousMessage) + } + } + + func testRejectsResultAndError() { + let raw = Data(#"{"jsonrpc":"2.0","id":1,"result":1,"error":{"code":-32000,"message":"x"}}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .ambiguousMessage) + } + } + + func testRejectsEmptyEnvelope() { + let raw = Data(#"{"jsonrpc":"2.0","id":1}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .missingResultOrError) + } + } + + func testRejectsEnvelopeWithoutMethodOrIdEvenWithVersion() { + let raw = Data(#"{"jsonrpc":"2.0"}"#.utf8) + XCTAssertThrowsError(try JsonRpcCodec.decode(raw)) { error in + XCTAssertEqual(error as? JsonRpcDecodingError, .missingMethod) + } + } +} diff --git a/TableProTests/Core/MCP/Wire/SseEncoderDecoderTests.swift b/TableProTests/Core/MCP/Wire/SseEncoderDecoderTests.swift new file mode 100644 index 000000000..e4c08d3a0 --- /dev/null +++ b/TableProTests/Core/MCP/Wire/SseEncoderDecoderTests.swift @@ -0,0 +1,104 @@ +import Foundation +@testable import TablePro +import XCTest + +final class SseEncoderDecoderTests: XCTestCase { + func testRoundTripSingleLineFrame() async throws { + let frame = SseFrame(event: "message", id: "1", data: "hello", retry: nil) + let encoded = SseEncoder.encode(frame) + let decoder = SseDecoder() + let frames = await decoder.feed(encoded) + XCTAssertEqual(frames.count, 1) + XCTAssertEqual(frames.first?.event, "message") + XCTAssertEqual(frames.first?.id, "1") + XCTAssertEqual(frames.first?.data, "hello") + } + + func testEncodeMultiLineDataProducesMultipleDataLines() { + let frame = SseFrame(data: "line1\nline2\nline3") + let encoded = SseEncoder.encode(frame) + let text = String(data: encoded, encoding: .utf8) ?? "" + XCTAssertTrue(text.contains("data: line1\n")) + XCTAssertTrue(text.contains("data: line2\n")) + XCTAssertTrue(text.contains("data: line3\n")) + XCTAssertTrue(text.hasSuffix("\n\n")) + } + + func testRoundTripMultiLineData() async throws { + let frame = SseFrame(data: "alpha\nbeta\ngamma") + let encoded = SseEncoder.encode(frame) + let decoder = SseDecoder() + let frames = await decoder.feed(encoded) + XCTAssertEqual(frames.count, 1) + XCTAssertEqual(frames.first?.data, "alpha\nbeta\ngamma") + } + + func testDecodesMultipleFramesInOneChunk() async throws { + let frameA = SseEncoder.encode(SseFrame(event: "a", data: "first")) + let frameB = SseEncoder.encode(SseFrame(event: "b", data: "second")) + var combined = Data() + combined.append(frameA) + combined.append(frameB) + + let decoder = SseDecoder() + let frames = await decoder.feed(combined) + XCTAssertEqual(frames.count, 2) + XCTAssertEqual(frames[0].data, "first") + XCTAssertEqual(frames[1].data, "second") + } + + func testBuffersPartialFramesAcrossChunks() async throws { + let frame = SseFrame(event: "ping", data: "hello world") + let encoded = SseEncoder.encode(frame) + + let split = encoded.count / 2 + let firstPart = encoded.prefix(split) + let secondPart = encoded.suffix(from: split) + + let decoder = SseDecoder() + let firstFrames = await decoder.feed(Data(firstPart)) + XCTAssertTrue(firstFrames.isEmpty) + let secondFrames = await decoder.feed(Data(secondPart)) + XCTAssertEqual(secondFrames.count, 1) + XCTAssertEqual(secondFrames.first?.data, "hello world") + } + + func testDecoderToleratesCrlfFieldSeparators() async throws { + let raw = "event: x\r\nid: 7\r\ndata: hi\r\n\r\n" + let decoder = SseDecoder() + let frames = await decoder.feed(Data(raw.utf8)) + XCTAssertEqual(frames.count, 1) + XCTAssertEqual(frames.first?.event, "x") + XCTAssertEqual(frames.first?.id, "7") + XCTAssertEqual(frames.first?.data, "hi") + } + + func testDecoderJoinsMultipleDataFieldsWithNewline() async throws { + let raw = "data: a\ndata: b\ndata: c\n\n" + let decoder = SseDecoder() + let frames = await decoder.feed(Data(raw.utf8)) + XCTAssertEqual(frames.count, 1) + XCTAssertEqual(frames.first?.data, "a\nb\nc") + } + + func testDecoderIgnoresCommentLines() async throws { + let raw = ": this is a comment\ndata: payload\n\n" + let decoder = SseDecoder() + let frames = await decoder.feed(Data(raw.utf8)) + XCTAssertEqual(frames.count, 1) + XCTAssertEqual(frames.first?.data, "payload") + } + + func testEncoderIncludesRetry() { + let frame = SseFrame(data: "ping", retry: 5_000) + let encoded = SseEncoder.encode(frame) + let text = String(data: encoded, encoding: .utf8) ?? "" + XCTAssertTrue(text.contains("retry: 5000\n")) + } + + func testEncoderEndsWithDoubleNewline() { + let frame = SseFrame(data: "x") + let encoded = SseEncoder.encode(frame) + XCTAssertEqual(encoded.suffix(2), Data([0x0A, 0x0A])) + } +} diff --git a/TableProTests/Core/Services/ConnectionSharingTests.swift b/TableProTests/Core/Services/ConnectionSharingTests.swift index 1286eeb11..a331a8253 100644 --- a/TableProTests/Core/Services/ConnectionSharingTests.swift +++ b/TableProTests/Core/Services/ConnectionSharingTests.swift @@ -311,7 +311,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -342,7 +342,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -373,7 +373,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -398,7 +398,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -419,7 +419,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -435,7 +435,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -453,7 +453,7 @@ struct ConnectionSharingTests { ) let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } @@ -510,7 +510,7 @@ struct ConnectionSharingTests { let link = ConnectionExportService.buildImportDeeplink(for: original)! let url = URL(string: link)! - guard case .importConnection(let parsed) = DeeplinkHandler.parse(url) else { + guard case .success(.importConnection(let parsed)) = DeeplinkParser.parse(url) else { Issue.record("Failed to parse round-trip link") return } diff --git a/TableProTests/Core/Services/DeeplinkHandlerTests.swift b/TableProTests/Core/Services/DeeplinkHandlerTests.swift deleted file mode 100644 index c336353d2..000000000 --- a/TableProTests/Core/Services/DeeplinkHandlerTests.swift +++ /dev/null @@ -1,664 +0,0 @@ -// -// DeeplinkHandlerTests.swift -// TableProTests -// - -import Foundation -import Testing -@testable import TablePro - -@Suite("Deeplink Handler") -@MainActor -struct DeeplinkHandlerTests { - - // MARK: - Connect Actions - - private static let sampleId = UUID(uuidString: "11111111-2222-3333-4444-555555555555")! - - @Test("Connect action with UUID") - func testConnectByUUID() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)")! - let action = DeeplinkHandler.parse(url) - if case .connect(let connectionId) = action { - #expect(connectionId == Self.sampleId) - } else { - Issue.record("Expected .connect, got \(String(describing: action))") - } - } - - @Test("Connect action with non-UUID first segment returns nil") - func testConnectNonUUIDReturnsNil() { - let url = URL(string: "tablepro://connect/Production")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Connect action with empty path returns nil") - func testConnectEmptyPathReturnsNil() { - let url = URL(string: "tablepro://connect/")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Connect action accepts lowercase UUID") - func testConnectLowercaseUUID() { - let id = UUID() - let url = URL(string: "tablepro://connect/\(id.uuidString.lowercased())")! - if case .connect(let parsed) = DeeplinkHandler.parse(url) { - #expect(parsed == id) - } else { - Issue.record("Expected .connect for lowercase UUID") - } - } - - @Test("Open table without database") - func testOpenTableWithoutDatabase() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)/table/users")! - let action = DeeplinkHandler.parse(url) - if case .openTable(let connectionId, let tableName, let databaseName) = action { - #expect(connectionId == Self.sampleId) - #expect(tableName == "users") - #expect(databaseName == nil) - } else { - Issue.record("Expected .openTable, got \(String(describing: action))") - } - } - - @Test("Open table with database") - func testOpenTableWithDatabase() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)/database/analytics/table/events")! - let action = DeeplinkHandler.parse(url) - if case .openTable(let connectionId, let tableName, let databaseName) = action { - #expect(connectionId == Self.sampleId) - #expect(tableName == "events") - #expect(databaseName == "analytics") - } else { - Issue.record("Expected .openTable, got \(String(describing: action))") - } - } - - @Test("Open query with decoded SQL") - func testOpenQueryDecodedSQL() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)/query?sql=SELECT%20*%20FROM%20users")! - let action = DeeplinkHandler.parse(url) - if case .openQuery(let connectionId, let sql) = action { - #expect(connectionId == Self.sampleId) - #expect(sql == "SELECT * FROM users") - } else { - Issue.record("Expected .openQuery, got \(String(describing: action))") - } - } - - @Test("Open query with empty SQL returns nil") - func testOpenQueryEmptySQLReturnsNil() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)/query?sql=")! - let action = DeeplinkHandler.parse(url) - #expect(action == nil) - } - - @Test("Unrecognized path returns nil") - func testUnrecognizedPathReturnsNil() { - let url = URL(string: "tablepro://connect/\(Self.sampleId.uuidString)/unknown/path")! - let action = DeeplinkHandler.parse(url) - #expect(action == nil) - } - - @Test("Unknown host returns nil") - func testUnknownHostReturnsNil() { - let url = URL(string: "tablepro://unknown-host")! - let action = DeeplinkHandler.parse(url) - #expect(action == nil) - } - - @Test("Wrong scheme returns nil") - func testWrongSchemeReturnsNil() { - let url = URL(string: "https://example.com")! - let action = DeeplinkHandler.parse(url) - #expect(action == nil) - } - - @Test("Malformed UUID with extra characters returns nil") - func testMalformedUUIDReturnsNil() { - let url = URL(string: "tablepro://connect/not-a-real-uuid-1234")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - // MARK: - Integrations Actions - - @Test("Pair action parses required params") - func testPairAction() { - let url = URL(string: "tablepro://integrations/pair?client=Raycast&challenge=abc123&redirect=raycast://callback&scopes=readOnly")! - if case .pairIntegration(let request) = DeeplinkHandler.parse(url) { - #expect(request.clientName == "Raycast") - #expect(request.challenge == "abc123") - #expect(request.redirectURL.absoluteString == "raycast://callback") - #expect(request.requestedScopes == "readOnly") - #expect(request.requestedConnectionIds == nil) - } else { - Issue.record("Expected .pairIntegration") - } - } - - @Test("Pair action parses connection-ids CSV") - func testPairActionConnectionIds() { - let id1 = UUID() - let id2 = UUID() - let csv = "\(id1.uuidString),\(id2.uuidString)" - let url = URL(string: "tablepro://integrations/pair?client=Raycast&challenge=abc&redirect=raycast://cb&connection-ids=\(csv)")! - if case .pairIntegration(let request) = DeeplinkHandler.parse(url) { - #expect(request.requestedConnectionIds == Set([id1, id2])) - } else { - Issue.record("Expected .pairIntegration with parsed UUIDs") - } - } - - @Test("Pair missing client returns nil") - func testPairMissingClientReturnsNil() { - let url = URL(string: "tablepro://integrations/pair?challenge=abc&redirect=raycast://cb")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Pair missing challenge returns nil") - func testPairMissingChallengeReturnsNil() { - let url = URL(string: "tablepro://integrations/pair?client=Raycast&redirect=raycast://cb")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Exchange action parses code and verifier") - func testExchangeAction() { - let url = URL(string: "tablepro://integrations/exchange?code=abc-123&verifier=xyz-456")! - if case .exchangePairing(let exchange) = DeeplinkHandler.parse(url) { - #expect(exchange.code == "abc-123") - #expect(exchange.verifier == "xyz-456") - } else { - Issue.record("Expected .exchangePairing") - } - } - - @Test("Exchange missing verifier returns nil") - func testExchangeMissingVerifierReturnsNil() { - let url = URL(string: "tablepro://integrations/exchange?code=abc")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Start MCP action parses without params") - func testStartMCPAction() { - let url = URL(string: "tablepro://integrations/start-mcp")! - if case .startMCP = DeeplinkHandler.parse(url) { - // matched - } else { - Issue.record("Expected .startMCP") - } - } - - @Test("Unknown integrations action returns nil") - func testUnknownIntegrationsAction() { - let url = URL(string: "tablepro://integrations/unknown")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - // MARK: - Import — Basic Fields - - @Test("Import with all basic params") - func testImportBasicParams() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&port=3306&username=root&database=mydb")! - let action = DeeplinkHandler.parse(url) - guard case .importConnection(let conn) = action else { - Issue.record("Expected .importConnection, got \(String(describing: action))") - return - } - #expect(conn.name == "Dev") - #expect(conn.host == "localhost") - #expect(conn.port == 3306) - #expect(conn.type == "MySQL") - #expect(conn.username == "root") - #expect(conn.database == "mydb") - } - - @Test("Import with minimal required params") - func testImportMinimalParams() { - let url = URL(string: "tablepro://import?name=Test&host=db.example.com&type=postgresql")! - let action = DeeplinkHandler.parse(url) - guard case .importConnection(let conn) = action else { - Issue.record("Expected .importConnection, got \(String(describing: action))") - return - } - #expect(conn.name == "Test") - #expect(conn.host == "db.example.com") - #expect(conn.type == "PostgreSQL") - #expect(conn.username == "") - #expect(conn.database == "") - #expect(conn.sshConfig == nil) - #expect(conn.sslConfig == nil) - #expect(conn.color == nil) - #expect(conn.tagName == nil) - #expect(conn.groupName == nil) - #expect(conn.additionalFields == nil) - } - - @Test("Import uses default port when not specified") - func testImportDefaultPort() { - let url = URL(string: "tablepro://import?name=PG&host=localhost&type=postgresql")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.port == 5432) - } - - @Test("Import with case-insensitive type") - func testImportCaseInsensitiveType() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=PostgreSQL")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.type == "PostgreSQL") - } - - @Test("Import missing name returns nil") - func testImportMissingNameReturnsNil() { - let url = URL(string: "tablepro://import?host=localhost&type=mysql")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Import missing host returns nil") - func testImportMissingHostReturnsNil() { - let url = URL(string: "tablepro://import?name=Dev&type=mysql")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Import missing type returns nil") - func testImportMissingTypeReturnsNil() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Import with empty name returns nil") - func testImportEmptyNameReturnsNil() { - let url = URL(string: "tablepro://import?name=&host=localhost&type=mysql")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Import with empty host returns nil") - func testImportEmptyHostReturnsNil() { - let url = URL(string: "tablepro://import?name=Dev&host=&type=mysql")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - // MARK: - Import — SSH Config - - @Test("Import with SSH config") - func testImportWithSSH() { - let url = URL(string: "tablepro://import?name=Prod&host=db.internal&type=postgresql&ssh=1&sshHost=bastion.example.com&sshPort=2222&sshUsername=deploy&sshAuthMethod=privateKey&sshPrivateKeyPath=~/.ssh/id_ed25519")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig != nil) - #expect(conn.sshConfig?.enabled == true) - #expect(conn.sshConfig?.host == "bastion.example.com") - #expect(conn.sshConfig?.port == 2222) - #expect(conn.sshConfig?.username == "deploy") - #expect(conn.sshConfig?.authMethod == "privateKey") - #expect(conn.sshConfig?.privateKeyPath == "~/.ssh/id_ed25519") - } - - @Test("Import without ssh=1 has no SSH config") - func testImportNoSSHFlag() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&sshHost=bastion.com")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig == nil) - } - - @Test("Import with SSH defaults") - func testImportSSHDefaults() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.port == 22) - #expect(conn.sshConfig?.username == "") - #expect(conn.sshConfig?.authMethod == "password") - #expect(conn.sshConfig?.privateKeyPath == "") - #expect(conn.sshConfig?.useSSHConfig == false) - #expect(conn.sshConfig?.agentSocketPath == "") - } - - @Test("Import with SSH agent") - func testImportSSHAgent() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com&sshAuthMethod=sshAgent&sshAgentSocketPath=/tmp/agent.sock")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.authMethod == "sshAgent") - #expect(conn.sshConfig?.agentSocketPath == "/tmp/agent.sock") - } - - @Test("Import with SSH use config flag") - func testImportSSHUseConfig() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com&sshUseSSHConfig=1")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.useSSHConfig == true) - } - - @Test("Import with SSH TOTP config") - func testImportSSHTOTP() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com&sshTotpMode=autoGenerate&sshTotpAlgorithm=sha256&sshTotpDigits=8&sshTotpPeriod=60")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.totpMode == "autoGenerate") - #expect(conn.sshConfig?.totpAlgorithm == "sha256") - #expect(conn.sshConfig?.totpDigits == 8) - #expect(conn.sshConfig?.totpPeriod == 60) - } - - @Test("Import with SSH jump hosts") - func testImportSSHJumpHosts() { - let jumpJson = #"[{"host":"jump1.com","port":22,"username":"admin","authMethod":"password","privateKeyPath":""}]"# - let encoded = jumpJson.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed)! - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com&sshJumpHosts=\(encoded)")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.jumpHosts?.count == 1) - #expect(conn.sshConfig?.jumpHosts?.first?.host == "jump1.com") - #expect(conn.sshConfig?.jumpHosts?.first?.username == "admin") - } - - @Test("Import with invalid jump hosts JSON ignores gracefully") - func testImportInvalidJumpHostsJSON() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com&sshJumpHosts=not-json")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshConfig?.jumpHosts == nil) - } - - // MARK: - Import — SSL Config - - @Test("Import with SSL config") - func testImportWithSSL() { - let url = URL(string: "tablepro://import?name=Prod&host=db.com&type=postgresql&sslMode=require&sslCaCertPath=~/certs/ca.pem&sslClientCertPath=~/certs/client.pem&sslClientKeyPath=~/certs/client.key")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sslConfig != nil) - #expect(conn.sslConfig?.mode == "require") - #expect(conn.sslConfig?.caCertificatePath == "~/certs/ca.pem") - #expect(conn.sslConfig?.clientCertificatePath == "~/certs/client.pem") - #expect(conn.sslConfig?.clientKeyPath == "~/certs/client.key") - } - - @Test("Import without sslMode has no SSL config") - func testImportNoSSLMode() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&sslCaCertPath=~/ca.pem")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sslConfig == nil) - } - - @Test("Import with SSL mode only, no cert paths") - func testImportSSLModeOnly() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&sslMode=preferred")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sslConfig?.mode == "preferred") - #expect(conn.sslConfig?.caCertificatePath == nil) - } - - // MARK: - Import — Metadata - - @Test("Import with color") - func testImportWithColor() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&color=red")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.color == "red") - } - - @Test("Import with tag and group names") - func testImportWithTagAndGroup() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&tagName=production&groupName=Backend%20Services")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.tagName == "production") - #expect(conn.groupName == "Backend Services") - } - - @Test("Import with safe mode level") - func testImportWithSafeModeLevel() { - let url = URL(string: "tablepro://import?name=Prod&host=db.com&type=postgresql&safeModeLevel=readOnly")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.safeModeLevel == "readOnly") - } - - @Test("Import with AI policy") - func testImportWithAIPolicy() { - let url = URL(string: "tablepro://import?name=Prod&host=db.com&type=postgresql&aiPolicy=never")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.aiPolicy == "never") - } - - // MARK: - Import — Other Fields - - @Test("Import with Redis database") - func testImportWithRedisDatabase() { - let url = URL(string: "tablepro://import?name=Cache&host=localhost&type=redis&redisDatabase=3")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.redisDatabase == 3) - } - - @Test("Import with startup commands") - func testImportWithStartupCommands() { - let commands = "SET search_path TO myschema;" - let encoded = commands.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed)! - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=postgresql&startupCommands=\(encoded)")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.startupCommands == commands) - } - - @Test("Import with localOnly flag") - func testImportWithLocalOnly() { - let url = URL(string: "tablepro://import?name=Local&host=localhost&type=sqlite&localOnly=1")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.localOnly == true) - } - - @Test("Import without localOnly defaults to nil") - func testImportLocalOnlyDefault() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.localOnly == nil) - } - - // MARK: - Import — Additional Fields (Plugin) - - @Test("Import with additional fields using af_ prefix") - func testImportAdditionalFields() { - let url = URL(string: "tablepro://import?name=Mongo&host=cluster.mongodb.net&type=mongodb&af_authSource=admin&af_replicaSet=rs0&af_useSrv=true")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.additionalFields?["authSource"] == "admin") - #expect(conn.additionalFields?["replicaSet"] == "rs0") - #expect(conn.additionalFields?["useSrv"] == "true") - } - - @Test("Import with af_ prefix but no value is ignored") - func testImportAdditionalFieldsNoValue() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&af_emptyField=")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.additionalFields == nil) - } - - @Test("Import with af_ prefix but empty key is ignored") - func testImportAdditionalFieldsEmptyKey() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&af_=someValue")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.additionalFields == nil) - } - - // MARK: - Import — Combined Full Config - - @Test("Import with all fields combined") - func testImportFullConfig() { - var components = URLComponents() - components.scheme = "tablepro" - components.host = "import" - components.queryItems = [ - URLQueryItem(name: "name", value: "Production DB"), - URLQueryItem(name: "host", value: "db.prod.internal"), - URLQueryItem(name: "port", value: "5433"), - URLQueryItem(name: "type", value: "postgresql"), - URLQueryItem(name: "username", value: "app_user"), - URLQueryItem(name: "database", value: "main"), - URLQueryItem(name: "ssh", value: "1"), - URLQueryItem(name: "sshHost", value: "bastion.prod.com"), - URLQueryItem(name: "sshPort", value: "2222"), - URLQueryItem(name: "sshUsername", value: "deploy"), - URLQueryItem(name: "sshAuthMethod", value: "privateKey"), - URLQueryItem(name: "sshPrivateKeyPath", value: "~/.ssh/prod_key"), - URLQueryItem(name: "sslMode", value: "verify-ca"), - URLQueryItem(name: "sslCaCertPath", value: "~/certs/ca.pem"), - URLQueryItem(name: "color", value: "red"), - URLQueryItem(name: "tagName", value: "production"), - URLQueryItem(name: "groupName", value: "Backend"), - URLQueryItem(name: "safeModeLevel", value: "readOnly"), - URLQueryItem(name: "aiPolicy", value: "never"), - URLQueryItem(name: "startupCommands", value: "SET statement_timeout = 30000;"), - URLQueryItem(name: "af_schema", value: "public"), - ] - let url = components.url! - - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - - #expect(conn.name == "Production DB") - #expect(conn.host == "db.prod.internal") - #expect(conn.port == 5433) - #expect(conn.type == "PostgreSQL") - #expect(conn.username == "app_user") - #expect(conn.database == "main") - - #expect(conn.sshConfig?.enabled == true) - #expect(conn.sshConfig?.host == "bastion.prod.com") - #expect(conn.sshConfig?.port == 2222) - #expect(conn.sshConfig?.username == "deploy") - #expect(conn.sshConfig?.authMethod == "privateKey") - #expect(conn.sshConfig?.privateKeyPath == "~/.ssh/prod_key") - - #expect(conn.sslConfig?.mode == "verify-ca") - #expect(conn.sslConfig?.caCertificatePath == "~/certs/ca.pem") - - #expect(conn.color == "red") - #expect(conn.tagName == "production") - #expect(conn.groupName == "Backend") - #expect(conn.safeModeLevel == "readOnly") - #expect(conn.aiPolicy == "never") - #expect(conn.startupCommands == "SET statement_timeout = 30000;") - #expect(conn.additionalFields?["schema"] == "public") - } - - // MARK: - Import — Security - - @Test("Import link never contains password field") - func testImportNoPasswordField() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&password=secret123")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.name == "Dev") - } - - // MARK: - Import — Edge Cases - - @Test("Import with percent-encoded special characters in name") - func testImportSpecialCharsInName() { - let url = URL(string: "tablepro://import?name=Dev%20%26%20Staging&host=localhost&type=mysql")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.name == "Dev & Staging") - } - - @Test("Import with IPv6 host") - func testImportIPv6Host() { - let url = URL(string: "tablepro://import?name=IPv6&host=::1&type=postgresql")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.host == "::1") - } - - @Test("Import with no query params returns nil") - func testImportNoQueryParams() { - let url = URL(string: "tablepro://import")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("Import with unknown type returns nil") - func testImportUnknownType() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=nonexistent_db")! - #expect(DeeplinkHandler.parse(url) == nil) - } - - @Test("sshProfileId is always nil in deep links") - func testImportSSHProfileIdAlwaysNil() { - let url = URL(string: "tablepro://import?name=Dev&host=localhost&type=mysql&ssh=1&sshHost=bastion.com")! - guard case .importConnection(let conn) = DeeplinkHandler.parse(url) else { - Issue.record("Expected .importConnection") - return - } - #expect(conn.sshProfileId == nil) - } -} diff --git a/TableProTests/Core/Services/TabPersistenceCoordinatorTests.swift b/TableProTests/Core/Services/TabPersistenceCoordinatorTests.swift index 7d267b654..47b290847 100644 --- a/TableProTests/Core/Services/TabPersistenceCoordinatorTests.swift +++ b/TableProTests/Core/Services/TabPersistenceCoordinatorTests.swift @@ -103,30 +103,6 @@ struct TabPersistenceCoordinatorTests { await sleep() } - @Test("saveNow with pre-converted PersistedTab array round-trips") - func saveNowWithPersistedTabsRoundTrips() async { - let coordinator = makeCoordinator() - let persistedTabs = [ - PersistedTab(id: UUID(), title: "P1", query: "SELECT 1", tabType: .query, tableName: nil), - PersistedTab(id: UUID(), title: "P2", query: "SELECT 2", tabType: .table, tableName: "users") - ] - let selectedId = persistedTabs[0].id - - coordinator.saveNow(persistedTabs: persistedTabs, selectedTabId: selectedId) - await sleep() - - let result = await coordinator.restoreFromDisk() - - #expect(result.tabs.count == 2) - #expect(result.selectedTabId == selectedId) - #expect(result.tabs[0].title == "P1") - #expect(result.tabs[1].tableContext.tableName == "users") - #expect(result.source == .disk) - - coordinator.clearSavedState() - await sleep() - } - @Test("Large query over 500KB is truncated to empty string in persisted tab") func largeQueryIsTruncated() async { let coordinator = makeCoordinator() diff --git a/TableProTests/Core/Services/TableQueryBuilderMSSQLTests.swift b/TableProTests/Core/Services/TableQueryBuilderMSSQLTests.swift index 9d369d289..ff6699b2b 100644 --- a/TableProTests/Core/Services/TableQueryBuilderMSSQLTests.swift +++ b/TableProTests/Core/Services/TableQueryBuilderMSSQLTests.swift @@ -17,11 +17,12 @@ struct TableQueryBuilderMSSQLTests { init() { FakeMSSQLPluginRegistration.registerIfNeeded() let dialect = PluginManager.shared.sqlDialect(for: .mssql) + let dialectQuote = dialect.map(quoteIdentifierFromDialect) self.builder = TableQueryBuilder( databaseType: .mssql, pluginDriver: PluginManager.shared.queryBuildingDriver(for: .mssql), dialect: dialect, - dialectQuote: quoteIdentifierFromDialect(dialect) + dialectQuote: dialectQuote ) } diff --git a/TableProTests/Core/Services/WelcomeWindowSuppressionTests.swift b/TableProTests/Core/Services/WelcomeWindowSuppressionTests.swift deleted file mode 100644 index 513cc8a7f..000000000 --- a/TableProTests/Core/Services/WelcomeWindowSuppressionTests.swift +++ /dev/null @@ -1,286 +0,0 @@ -// -// WelcomeWindowSuppressionTests.swift -// TableProTests -// -// Regression tests for the welcome window suppression logic in AppDelegate+WindowConfig. -// Covers the fix where double-clicking .duckdb files from Finder caused the app to freeze -// because suppression gave up too early and welcome was closed instead of hidden. -// - -import AppKit -import Foundation -import Testing -@testable import TablePro - -@Suite("Welcome Window Suppression") -@MainActor -struct WelcomeWindowSuppressionTests { - /// Create a fresh AppDelegate for each test — avoids relying on NSApp.delegate - /// which may not be our AppDelegate in parallel test runner processes. - private func makeAppDelegate() -> AppDelegate { - AppDelegate() - } - - private func makeWindow(identifier: String) -> NSWindow { - let window = NSWindow() - window.identifier = NSUserInterfaceItemIdentifier(identifier) - return window - } - - // MARK: - Window Identification - - @Test("isMainWindow — exact identifier 'main'") - func isMainWindowExact() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "main") - #expect(delegate.isMainWindow(window)) - } - - @Test("isMainWindow — prefixed identifier 'main-123'") - func isMainWindowPrefixed() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "main-123") - #expect(delegate.isMainWindow(window)) - } - - @Test("isMainWindow — returns false for nil identifier") - func isMainWindowNilIdentifier() { - let delegate = makeAppDelegate() - let window = NSWindow() - window.identifier = nil - #expect(!delegate.isMainWindow(window)) - } - - @Test("isMainWindow — returns false for 'welcome'") - func isMainWindowUnrelated() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "welcome") - #expect(!delegate.isMainWindow(window)) - } - - @Test("isMainWindow — returns false for 'mainExtra' (no dash separator)") - func isMainWindowNoDash() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "mainExtra") - #expect(!delegate.isMainWindow(window)) - } - - @Test("isWelcomeWindow — exact identifier 'welcome'") - func isWelcomeWindowExact() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "welcome") - #expect(delegate.isWelcomeWindow(window)) - } - - @Test("isWelcomeWindow — prefixed identifier 'welcome-abc'") - func isWelcomeWindowPrefixed() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "welcome-abc") - #expect(delegate.isWelcomeWindow(window)) - } - - @Test("isWelcomeWindow — returns false for nil identifier") - func isWelcomeWindowNilIdentifier() { - let delegate = makeAppDelegate() - let window = NSWindow() - window.identifier = nil - #expect(!delegate.isWelcomeWindow(window)) - } - - @Test("isWelcomeWindow — returns false for 'main'") - func isWelcomeWindowNotMain() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "main") - #expect(!delegate.isWelcomeWindow(window)) - } - - @Test("isWelcomeWindow — returns false for 'welcomeExtra' (no dash separator)") - func isWelcomeWindowNoDash() { - let delegate = makeAppDelegate() - let window = makeWindow(identifier: "welcomeExtra") - #expect(!delegate.isWelcomeWindow(window)) - } - - // MARK: - suppressWelcomeWindow State - - @Test("suppressWelcomeWindow — sets isHandlingFileOpen to true") - func suppressSetsFlag() { - let delegate = makeAppDelegate() - delegate.suppressWelcomeWindow() - #expect(delegate.isHandlingFileOpen == true) - } - - @Test("suppressWelcomeWindow — increments fileOpenSuppressionCount") - func suppressIncrementsCount() { - let delegate = makeAppDelegate() - delegate.suppressWelcomeWindow() - #expect(delegate.fileOpenSuppressionCount == 1) - - delegate.suppressWelcomeWindow() - #expect(delegate.fileOpenSuppressionCount == 2) - } - - @Test("suppressWelcomeWindow — hides existing welcome windows via orderOut") - func suppressHidesWelcomeWindows() { - let delegate = makeAppDelegate() - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - #expect(welcome.isVisible) - - delegate.suppressWelcomeWindow() - - #expect(!welcome.isVisible) - } - - // MARK: - windowDidBecomeKey Suppression Behavior - - @Test("windowDidBecomeKey — welcome hides (orderOut) when file open and no main window") - func windowDidBecomeKeyHidesWelcomeWhenNoMain() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - #expect(welcome.isVisible) - - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: welcome) - delegate.windowDidBecomeKey(notification) - - // Key regression fix: welcome should be hidden (not closed) so it can reappear - // when the main window is ready — prevents "no visible windows" freeze - #expect(!welcome.isVisible) - } - - @Test("windowDidBecomeKey — welcome closes when file open and main window exists") - func windowDidBecomeKeyClosesWelcomeWhenMainExists() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - - let mainWin = makeWindow(identifier: "main") - mainWin.orderFront(nil) - defer { mainWin.close() } - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: welcome) - delegate.windowDidBecomeKey(notification) - - #expect(!welcome.isVisible) - } - - @Test("windowDidBecomeKey — welcome not suppressed when isHandlingFileOpen is false") - func windowDidBecomeKeyNoSuppressionWhenNotHandlingFile() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = false - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: welcome) - delegate.windowDidBecomeKey(notification) - - #expect(welcome.isVisible) - } - - @Test("windowDidBecomeKey — non-welcome window is not affected by suppression") - func windowDidBecomeKeyIgnoresNonWelcome() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - - let other = makeWindow(identifier: "settings") - other.orderFront(nil) - defer { other.close() } - - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: other) - delegate.windowDidBecomeKey(notification) - - #expect(other.isVisible) - } - - // MARK: - Suppression Count State - - @Test("Multiple suppress calls — count increments independently") - func multipleSuppressionCountsStack() { - let delegate = makeAppDelegate() - delegate.suppressWelcomeWindow() - delegate.suppressWelcomeWindow() - delegate.suppressWelcomeWindow() - - #expect(delegate.fileOpenSuppressionCount == 3) - #expect(delegate.isHandlingFileOpen == true) - } - - @Test("endFileOpenSuppression — decrement to zero resets isHandlingFileOpen") - func endSuppressionResetsFlag() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - delegate.fileOpenSuppressionCount = 1 - - delegate.endFileOpenSuppression() - - #expect(delegate.fileOpenSuppressionCount == 0) - #expect(delegate.isHandlingFileOpen == false) - } - - @Test("endFileOpenSuppression — keeps flag true while count > 0") - func endSuppressionKeepsFlagWhilePositive() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - delegate.fileOpenSuppressionCount = 2 - - delegate.endFileOpenSuppression() - - #expect(delegate.fileOpenSuppressionCount == 1) - #expect(delegate.isHandlingFileOpen == true) - } - - // MARK: - Main Window Becomes Key - - @Test("windowDidBecomeKey — main window appearing closes welcome during file open") - func windowDidBecomeKeyMainWindowClosesWelcome() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = true - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - let mainWin = makeWindow(identifier: "main") - mainWin.orderFront(nil) - defer { mainWin.close() } - - // Simulate main window becoming key — should close welcome - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: mainWin) - delegate.windowDidBecomeKey(notification) - - #expect(!welcome.isVisible) - } - - @Test("windowDidBecomeKey — main window does not close welcome when not handling file open") - func windowDidBecomeKeyMainWindowNoEffectWhenNotHandling() { - let delegate = makeAppDelegate() - delegate.isHandlingFileOpen = false - - let welcome = makeWindow(identifier: "welcome") - welcome.orderFront(nil) - defer { welcome.close() } - - let mainWin = makeWindow(identifier: "main") - mainWin.orderFront(nil) - defer { mainWin.close() } - - let notification = Notification(name: NSWindow.didBecomeKeyNotification, object: mainWin) - delegate.windowDidBecomeKey(notification) - - // Welcome should remain visible — no suppression active - #expect(welcome.isVisible) - } -} diff --git a/TableProTests/ViewModels/SidebarViewModelTests.swift b/TableProTests/ViewModels/SidebarViewModelTests.swift index 90b0e9e61..f153a07c0 100644 --- a/TableProTests/ViewModels/SidebarViewModelTests.swift +++ b/TableProTests/ViewModels/SidebarViewModelTests.swift @@ -42,7 +42,6 @@ private func makeSUT( let optionsBinding = Binding(get: { optionsState }, set: { optionsState = $0 }) let vm = SidebarViewModel( - tables: tablesBinding, selectedTables: selectedBinding, pendingTruncates: truncatesBinding, pendingDeletes: deletesBinding, diff --git a/TableProTests/Views/Main/CoordinatorReloadSidebarTests.swift b/TableProTests/Views/Main/CoordinatorReloadSidebarTests.swift deleted file mode 100644 index 41c296498..000000000 --- a/TableProTests/Views/Main/CoordinatorReloadSidebarTests.swift +++ /dev/null @@ -1,57 +0,0 @@ -// -// CoordinatorRefreshTablesTests.swift -// TableProTests -// -// Tests for MainContentCoordinator.refreshTables() — -// verifies it updates sidebarLoadingState and populates session tables. -// - -import SwiftUI -import Testing - -@testable import TablePro - -@Suite("CoordinatorRefreshTables") -struct CoordinatorRefreshTablesTests { - @Test("refreshTables sets loading state to error when no driver") - @MainActor - func setsErrorWhenNoDriver() async { - let connection = TestFixtures.makeConnection(database: "db_a") - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - #expect(coordinator.sidebarLoadingState == .idle) - - await coordinator.refreshTables() - - #expect(coordinator.sidebarLoadingState == .error("Not connected")) - } - - @Test("sidebarLoadingState defaults to idle") - @MainActor - func defaultsToIdle() { - let connection = TestFixtures.makeConnection(database: "db_a") - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - #expect(coordinator.sidebarLoadingState == .idle) - } -} diff --git a/TableProTests/Views/Main/EvictionTests.swift b/TableProTests/Views/Main/EvictionTests.swift index e96528b76..c9429e1d9 100644 --- a/TableProTests/Views/Main/EvictionTests.swift +++ b/TableProTests/Views/Main/EvictionTests.swift @@ -30,7 +30,7 @@ struct EvictionTests { to coordinator: MainContentCoordinator, tabManager: QueryTabManager, tableName: String = "users" - ) { + ) throws { try tabManager.addTableTab(tableName: tableName) guard let index = tabManager.selectedTabIndex else { return } let rows = TestFixtures.makeRows(count: 10) @@ -43,12 +43,11 @@ struct EvictionTests { } @Test("evictInactiveRowData evicts background tabs without pending changes") - func evictsLoadedTabs() { + func evictsLoadedTabs() throws { let (coordinator, tabManager) = makeCoordinator() - addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") + try addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") let backgroundTabId = tabManager.tabs[0].id - // Add a second tab so the first becomes background (eviction skips the selected tab) - addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "orders") + try addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "orders") #expect(coordinator.tabSessionRegistry.tableRows(for: backgroundTabId).rows.count == 10) #expect(coordinator.tabSessionRegistry.isEvicted(backgroundTabId) == false) @@ -60,9 +59,9 @@ struct EvictionTests { } @Test("evictInactiveRowData skips tabs with pending changes") - func skipsTabsWithPendingChanges() { + func skipsTabsWithPendingChanges() throws { let (coordinator, tabManager) = makeCoordinator() - addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") + try addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") tabManager.tabs[0].pendingChanges.deletedRowIndices = [0] @@ -74,11 +73,11 @@ struct EvictionTests { } @Test("evictInactiveRowData preserves column metadata after eviction") - func preservesMetadataAfterEviction() { + func preservesMetadataAfterEviction() throws { let (coordinator, tabManager) = makeCoordinator() - addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") + try addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "users") let backgroundTabId = tabManager.tabs[0].id - addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "orders") + try addLoadedTab(to: coordinator, tabManager: tabManager, tableName: "orders") coordinator.evictInactiveRowData() diff --git a/TableProTests/Views/Main/MainStatusBarLayoutTests.swift b/TableProTests/Views/Main/MainStatusBarLayoutTests.swift index 27e07802a..1c1f5f6ef 100644 --- a/TableProTests/Views/Main/MainStatusBarLayoutTests.swift +++ b/TableProTests/Views/Main/MainStatusBarLayoutTests.swift @@ -15,6 +15,7 @@ struct MainStatusBarLayoutTests { func instantiateWithEmptySnapshot() { let view = MainStatusBarView( snapshot: StatusBarSnapshot(tab: nil, tableRows: nil), + filterState: TabFilterState(), hiddenColumns: [], allColumns: [], selectedRowIndices: [], @@ -28,7 +29,8 @@ struct MainStatusBarLayoutTests { onPaginationGo: {}, onToggleColumn: { _ in }, onShowAllColumns: {}, - onHideAllColumns: { _ in } + onHideAllColumns: { _ in }, + onToggleFilters: {} ) #expect(type(of: view.body) != Never.self) } diff --git a/TableProTests/Views/Main/MultiConnectionNavigationTests.swift b/TableProTests/Views/Main/MultiConnectionNavigationTests.swift index c73cee17c..6b82f30d2 100644 --- a/TableProTests/Views/Main/MultiConnectionNavigationTests.swift +++ b/TableProTests/Views/Main/MultiConnectionNavigationTests.swift @@ -95,8 +95,6 @@ struct MultiConnectionNavigationTests { #expect(tab.tableContext.databaseName == "primary_db") } - // Note: sidebarLoadingState guard test lives in SwitchDatabaseTests.swift - // MARK: - openTableTab: different database types create correct tab @Test("openTableTab with postgresql connection adds tab") diff --git a/TableProTests/Views/Main/OpenTableTabTests.swift b/TableProTests/Views/Main/OpenTableTabTests.swift index 6dde4a479..1544e12e4 100644 --- a/TableProTests/Views/Main/OpenTableTabTests.swift +++ b/TableProTests/Views/Main/OpenTableTabTests.swift @@ -1,14 +1,3 @@ -// -// OpenTableTabTests.swift -// TableProTests -// -// Tests for openTableTab logic — verifies skip/open behavior -// based on current tab state and database context. -// -// Note: sidebarLoadingState guard and same-table fast path tests -// live in SwitchDatabaseTests.swift to avoid duplication. -// - import Foundation import Testing diff --git a/TableProTests/Views/Main/RowOperationsDispatchTests.swift b/TableProTests/Views/Main/RowOperationsDispatchTests.swift index c0778484e..f46939d08 100644 --- a/TableProTests/Views/Main/RowOperationsDispatchTests.swift +++ b/TableProTests/Views/Main/RowOperationsDispatchTests.swift @@ -49,7 +49,7 @@ struct RowOperationsDispatchTests { let tabId: UUID } - private func makeFixture(rowCount: Int = 5) -> Fixture { + private func makeFixture(rowCount: Int = 5) throws -> Fixture { let tabManager = QueryTabManager() let coordinator = MainContentCoordinator( connection: TestFixtures.makeConnection(), @@ -85,8 +85,8 @@ struct RowOperationsDispatchTests { } @Test("Soft-delete of existing rows dispatches invalidateCachesForUndoRedo") - func softDeleteDispatchesInvalidate() { - let f = makeFixture(rowCount: 5) + func softDeleteDispatchesInvalidate() throws { + let f = try makeFixture(rowCount: 5) let beforeInvalidate = f.fake.invalidateCount f.coordinator.deleteSelectedRows(indices: [0, 1]) @@ -96,8 +96,8 @@ struct RowOperationsDispatchTests { } @Test("Physical delete of inserted rows dispatches applyDelta, not invalidate") - func physicalDeleteDispatchesDelta() { - let f = makeFixture(rowCount: 3) + func physicalDeleteDispatchesDelta() throws { + let f = try makeFixture(rowCount: 3) f.coordinator.addNewRow() let insertedIndex = f.coordinator.tabSessionRegistry.tableRows(for: f.tabId).count - 1 let beforeInvalidate = f.fake.invalidateCount diff --git a/TableProTests/Views/Main/SortCacheInvalidationTests.swift b/TableProTests/Views/Main/SortCacheInvalidationTests.swift index 982748204..cabeddcc8 100644 --- a/TableProTests/Views/Main/SortCacheInvalidationTests.swift +++ b/TableProTests/Views/Main/SortCacheInvalidationTests.swift @@ -16,7 +16,7 @@ import Testing @Suite("querySortCache invalidation on row mutations") @MainActor struct SortCacheInvalidationTests { - private func makeCoordinator() -> (MainContentCoordinator, QueryTabManager, UUID) { + private func makeCoordinator() throws -> (MainContentCoordinator, QueryTabManager, UUID) { let tabManager = QueryTabManager() let coordinator = MainContentCoordinator( connection: TestFixtures.makeConnection(), @@ -49,8 +49,8 @@ struct SortCacheInvalidationTests { } @Test("addNewRow clears querySortCache for the tab") - func addNewRowInvalidatesCache() { - let (coordinator, _, tabId) = makeCoordinator() + func addNewRowInvalidatesCache() throws { + let (coordinator, _, tabId) = try makeCoordinator() seedRows(coordinator, for: tabId, count: 3) seedCache(coordinator, for: tabId) @@ -60,8 +60,8 @@ struct SortCacheInvalidationTests { } @Test("deleteSelectedRows clears querySortCache when physically removing inserted rows") - func physicalDeleteInvalidatesCache() { - let (coordinator, _, tabId) = makeCoordinator() + func physicalDeleteInvalidatesCache() throws { + let (coordinator, _, tabId) = try makeCoordinator() seedRows(coordinator, for: tabId, count: 3) coordinator.addNewRow() let insertedIndex = coordinator.tabSessionRegistry.tableRows(for: tabId).count - 1 @@ -73,8 +73,8 @@ struct SortCacheInvalidationTests { } @Test("deleteSelectedRows preserves querySortCache on soft delete of existing rows") - func softDeletePreservesCache() { - let (coordinator, _, tabId) = makeCoordinator() + func softDeletePreservesCache() throws { + let (coordinator, _, tabId) = try makeCoordinator() seedRows(coordinator, for: tabId, count: 5) seedCache(coordinator, for: tabId) @@ -84,8 +84,8 @@ struct SortCacheInvalidationTests { } @Test("duplicateSelectedRow clears querySortCache for the tab") - func duplicateRowInvalidatesCache() { - let (coordinator, _, tabId) = makeCoordinator() + func duplicateRowInvalidatesCache() throws { + let (coordinator, _, tabId) = try makeCoordinator() seedRows(coordinator, for: tabId, count: 3) seedCache(coordinator, for: tabId) diff --git a/TableProTests/Views/Main/TableRowsMutationTests.swift b/TableProTests/Views/Main/TableRowsMutationTests.swift index a7073a9d4..cef137597 100644 --- a/TableProTests/Views/Main/TableRowsMutationTests.swift +++ b/TableProTests/Views/Main/TableRowsMutationTests.swift @@ -77,7 +77,7 @@ struct TableRowsMutationTests { @Test("setActiveTableRows on the active tab dispatches applyFullReplace") func dispatchesOnActiveTab() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let activeTabId = f.tabManager.tabs[0].id f.coordinator.setActiveTableRows(makeTableRows(rowCount: 3), for: activeTabId) @@ -88,9 +88,9 @@ struct TableRowsMutationTests { @Test("setActiveTableRows on a background tab does not dispatch") func skipsOnBackgroundTab() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let backgroundTabId = f.tabManager.tabs[0].id - f.try tabManager.addTableTab(tableName: "orders") + try f.tabManager.addTableTab(tableName: "orders") f.coordinator.setActiveTableRows(makeTableRows(rowCount: 5), for: backgroundTabId) @@ -100,7 +100,7 @@ struct TableRowsMutationTests { @Test("repeated setActiveTableRows dispatches once per call") func dispatchesOncePerCall() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let activeTabId = f.tabManager.tabs[0].id f.coordinator.setActiveTableRows(TableRows(), for: activeTabId) @@ -112,7 +112,7 @@ struct TableRowsMutationTests { @Test("setActiveTableRows dispatches scrollToTop when pendingScrollToTopAfterReplace contains tabId") func scrollToTopFiresOnPendingFlag() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let activeTabId = f.tabManager.tabs[0].id f.coordinator.pendingScrollToTopAfterReplace.insert(activeTabId) @@ -125,9 +125,9 @@ struct TableRowsMutationTests { @Test("scrollToTop pending flag for tab A does not fire when tab B is replaced") func scrollToTopFlagIsScopedPerTab() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let firstTabId = f.tabManager.tabs[0].id - f.try tabManager.addTableTab(tableName: "orders") + try f.tabManager.addTableTab(tableName: "orders") let secondTabId = f.tabManager.tabs[1].id f.coordinator.pendingScrollToTopAfterReplace.insert(firstTabId) @@ -140,7 +140,7 @@ struct TableRowsMutationTests { @Test("setActiveTableRows without pending flag does not scroll to top") func scrollToTopSkippedWhenFlagAbsent() throws { let f = makeFixture() - f.try tabManager.addTableTab(tableName: "users") + try f.tabManager.addTableTab(tableName: "users") let activeTabId = f.tabManager.tabs[0].id f.coordinator.setActiveTableRows(makeTableRows(rowCount: 3), for: activeTabId) diff --git a/TableProTests/Views/SwitchDatabaseTests.swift b/TableProTests/Views/SwitchDatabaseTests.swift index c04c9ecf6..de1d95543 100644 --- a/TableProTests/Views/SwitchDatabaseTests.swift +++ b/TableProTests/Views/SwitchDatabaseTests.swift @@ -27,83 +27,6 @@ private func simulateDatabaseSwitch( @Suite("SwitchDatabase") struct SwitchDatabaseTests { - // MARK: - sidebarLoadingState - - @Test("sidebarLoadingState defaults to idle") - @MainActor - func loadingStateDefaultsToIdle() { - let connection = TestFixtures.makeConnection() - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - #expect(coordinator.sidebarLoadingState == .idle) - } - - // MARK: - openTableTab behavior during database switch - - @Test("openTableTab skips new window when sidebar is loading with existing tabs") - @MainActor - func openTableTabSkipsNewWindowDuringSwitch() throws { - let connection = TestFixtures.makeConnection(database: "db_a") - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - try tabManager.addTableTab(tableName: "users", databaseType: .mysql, databaseName: "db_a") - let tabCountBefore = tabManager.tabs.count - - coordinator.sidebarLoadingState = .loading - - coordinator.openTableTab("orders") - - #expect(tabManager.tabs.count == tabCountBefore) - } - - @Test("openTableTab adds tab in-place when sidebar is loading with empty tabs") - @MainActor - func openTableTabAddsInPlaceWhenSwitchingWithEmptyTabs() { - let connection = TestFixtures.makeConnection(database: "db_a") - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - #expect(tabManager.tabs.isEmpty) - - coordinator.sidebarLoadingState = .loading - - coordinator.openTableTab("users") - - #expect(tabManager.tabs.count == 1) - #expect(tabManager.tabs.first?.tableContext.tableName == "users") - } - - // MARK: - openTableTab fast path (same table + same database) - @Test("openTableTab skips when table is already active tab in same database") @MainActor func openTableTabSkipsForSameTableSameDatabase() throws { @@ -170,31 +93,4 @@ struct SwitchDatabaseTests { #expect(tabManager.tabs.isEmpty) #expect(tabManager.selectedTabId == nil) } - - // MARK: - sidebarLoadingState during database switch - - @Test("switchDatabase sets sidebarLoadingState to loading then error when no driver") - @MainActor - func switchDatabaseSetsLoadingState() async { - let connection = TestFixtures.makeConnection(database: "db_a") - let tabManager = QueryTabManager() - let changeManager = DataChangeManager() - let toolbarState = ConnectionToolbarState() - - let coordinator = MainContentCoordinator( - connection: connection, - tabManager: tabManager, - changeManager: changeManager, - toolbarState: toolbarState - ) - defer { coordinator.teardown() } - - #expect(coordinator.sidebarLoadingState == .idle) - - await coordinator.switchDatabase(to: "db_b") - - // Without a driver, switchDatabase sets loading then returns early - // refreshTables will set error state since there's no driver - #expect(coordinator.sidebarLoadingState == .error("Not connected")) - } } diff --git a/docs/README.md b/docs/README.md index 2080ab608..db275c528 100644 --- a/docs/README.md +++ b/docs/README.md @@ -13,8 +13,8 @@ docs/ ├── databases/ # Database connection guides ├── features/ # Feature documentation ├── customization/ # Settings and customization -├── development/ # Developer documentation -└── vi/ # Vietnamese translation (full parity) +├── external-api/ # URL scheme, MCP, pairing +└── development/ # Developer documentation ``` ## Local Development diff --git a/docs/customization/settings.mdx b/docs/customization/settings.mdx index 96339fc5d..14047429d 100644 --- a/docs/customization/settings.mdx +++ b/docs/customization/settings.mdx @@ -76,7 +76,7 @@ A tab is "clean" when it's a table tab (not query/create), unpinned, no unsaved ## MCP -The **MCP** tab covers the [External API](/external-api) surface. The server lazy-starts on first use and exposes the following sections: +The **MCP** tab covers the [External API](/external-api/index) surface. The server lazy-starts on first use and exposes the following sections: - **MCP Server**: enable toggle and live status. The server runs on a free port in the `51000-52000` range. See [MCP Server](/features/mcp). - **MCP Configuration**: row limit defaults, query timeout, "log MCP queries in history". diff --git a/docs/databases/overview.mdx b/docs/databases/overview.mdx index bf0a7a8f0..547e9190f 100644 --- a/docs/databases/overview.mdx +++ b/docs/databases/overview.mdx @@ -250,7 +250,7 @@ Open the **Advanced** tab on the connection form for the following settings: | Field | Description | |-------|-------------| | **Startup Commands** | SQL statements that run automatically on every connection. See [Startup Commands](#startup-commands). | -| **External Access** | Controls how external clients (Raycast, Cursor, Claude Desktop) reach this connection: `blocked`, `readOnly` (default), or `readWrite`. Tokens cannot exceed this level. See [External API](/external-api). | +| **External Access** | Controls how external clients (Raycast, Cursor, Claude Desktop) reach this connection: `blocked`, `readOnly` (default), or `readWrite`. Tokens cannot exceed this level. See [External API](/external-api/index). | | **Local only** | Excludes this connection from iCloud Sync. See [iCloud Sync](/features/icloud-sync). | | **Plugin fields** | Driver-specific options (for example, MongoDB `replicaSet`, ClickHouse `Secure`). | diff --git a/docs/development/architecture.mdx b/docs/development/architecture.mdx index 745c434b4..559cc4b7f 100644 --- a/docs/development/architecture.mdx +++ b/docs/development/architecture.mdx @@ -118,6 +118,36 @@ flowchart LR - **SQLContextAnalyzer**: parses cursor position context (table ref, column ref, keyword) - **SQLSchemaProvider**: actor that caches and serves schema data +### MCP Layer + +The MCP server lives under `Core/MCP/` and is split into five horizontal layers. Each layer talks only to the layer below it. + +```mermaid +flowchart TD + Wire["Wire (Codable values)
JsonRpcMessage, JsonRpcCodec, HttpRequestParser, SseDecoder"] + Transport["Transport (NWListener, URLSession, FileHandle)
MCPHttpServerTransport, MCPStdioMessageTransport, MCPStreamableHttpClientTransport"] + Session["Session / Auth / RateLimit (actors)
MCPSessionStore, MCPBearerTokenAuthenticator, MCPRateLimiter"] + Protocol["Protocol (dispatcher + handlers)
MCPProtocolDispatcher, 19 tools, MCPProgressEmitter"] + Bridge["Bridge (tablepro-mcp CLI)
BridgeProxy: stdio <-> HTTP"] + + Bridge --> Wire + Transport --> Wire + Session --> Transport + Protocol --> Session +``` + +**Wire**: pure Codable types, no I/O. JSON-RPC 2.0, strict-CRLF HTTP, SSE encoder/decoder. + +**Transport**: HTTP server uses `NWListener` and binds to `127.0.0.1:` by default. The stream endpoints (`exchanges`, `listenerState`) are bounded `AsyncStream`s consumed by `MCPServerManager`. The bridge's client-side transport uses `URLSession.bytes(for:)` for incremental SSE. + +**Session**: `MCPSessionStore` is an actor that owns session lifecycle. Idle timeout is 15 minutes. Token revocation marks sessions with `.tokenRevoked` and the SSE stream emits a typed terminate comment so clients can distinguish revoke from network blip. + +**Protocol**: `MCPProtocolDispatcher` spawns a child `Task` per inbound exchange, so two concurrent tool calls run in parallel instead of queueing on the dispatcher actor. Per-request cancellation flows through `MCPInflightRegistry`. Long-running tools emit `notifications/progress` to clients that pass `_meta.progressToken`. + +**Bridge**: `tablepro-mcp` is a 50-line composition root. `MCPStdioMessageTransport` host-side, `MCPStreamableHttpClientTransport` upstream. Errors land in os_log and stderr. The host-facing transport writes only validated `JsonRpcMessage` bytes to stdout. + +The server accepts protocol versions `2025-03-26`, `2025-06-18`, and `2025-11-25`. See [Versioning](/external-api/versioning) for the negotiation rules and the additive-within-major-version stability policy. + ## Data Flow ### Connection diff --git a/docs/external-api/mcp-clients.mdx b/docs/external-api/mcp-clients.mdx index 9e6464c73..c0bffa520 100644 --- a/docs/external-api/mcp-clients.mdx +++ b/docs/external-api/mcp-clients.mdx @@ -44,10 +44,10 @@ Restart Claude Desktop. Open a new chat, click the connectors icon below the inp Use the `claude mcp add` CLI: ```bash -claude mcp add --transport stdio tablepro -- /Applications/TablePro.app/Contents/MacOS/tablepro-mcp +claude mcp add tablepro -- /Applications/TablePro.app/Contents/MacOS/tablepro-mcp ``` -The double dash separates Claude Code's flags from the command it runs. Verify with `claude mcp list`. +The double dash separates Claude Code's flags from the command it runs. stdio is the default transport, so no `--transport` flag is needed. Verify with `claude mcp list`. ## Cursor @@ -200,8 +200,10 @@ After configuring a client, the fastest check is to ask it to list TablePro tool If the call fails, the response code tells you which layer rejected it: - **stdio process exits immediately**: TablePro is not running, or you are on a build older than 0.37. Open TablePro and re-launch the client. -- **`401 Unauthorized`**: the bridge token is stale. Quit and reopen TablePro to regenerate the handshake. +- **`401 Unauthorized`** (`WWW-Authenticate: Bearer ...`): the bridge token is stale. Quit and reopen TablePro to regenerate the handshake. - **`403 Forbidden`**: the connection's `externalAccess` is `blocked` or `readOnly`, or the token's allowlist excludes it. Open the connection editor in TablePro and adjust under **External Access**. +- **`404 Session not found`** (JSON-RPC code `-32001`): the session expired (idle timeout is 15 minutes) or the server restarted. Per the MCP spec, drop the cached `Mcp-Session-Id` and start a new `initialize` handshake. Compliant clients (Claude Desktop 0.7+, Cursor, Cline) do this automatically. +- **`429 Too Many Requests`**: 5 failed auth attempts within 60 seconds against the same `(client_address, principal)` pair triggered a 5-minute lockout. Wait it out or restart TablePro to clear the bucket. ## Troubleshooting diff --git a/docs/external-api/mcp-resources.mdx b/docs/external-api/mcp-resources.mdx index 13b28545e..fa3797f85 100644 --- a/docs/external-api/mcp-resources.mdx +++ b/docs/external-api/mcp-resources.mdx @@ -9,6 +9,31 @@ Resources are read-only views of TablePro state. AI clients use them to discover URIs use the `tablepro://` scheme inside the MCP transport. Do not confuse them with shell-level [URL scheme deep links](/external-api/url-scheme). +## Discovery + +Two MCP methods enumerate resources: + +- `resources/list` returns the static `tablepro://connections` resource plus a schema and history entry for each currently connected database. +- `resources/templates/list` returns the URI templates for `tablepro://connections/{id}/schema` and `tablepro://connections/{id}/history`, so clients can construct a URL for any connection without waiting for it to be open. + +## Response envelope + +`resources/read` wraps the resource payload in the MCP standard envelope: + +```json +{ + "contents": [ + { + "uri": "tablepro://connections", + "mimeType": "application/json", + "text": "{...JSON payload below as a string...}" + } + ] +} +``` + +The shapes documented below are what you get after parsing `text` as JSON. + ## `tablepro://connections` All saved connections with their current session state. @@ -107,7 +132,9 @@ Recent query history for a connection. ## Errors -| Code | Meaning | -|------|---------| -| `403` | Token allowlist rejects the connection, or `externalAccess` is `blocked`. | -| `404` | Connection not found. | +| JSON-RPC code | HTTP status | Meaning | +|---------------|-------------|---------| +| `-32602` | 200 | Invalid params (malformed URI, missing `uri`, bad UUID, connection not active). | +| `-32601` | 404 | Unknown resource URI (e.g. `tablepro://connections/{id}/foo`). | +| `-32004` | 404 | Resource not found in the data layer. | +| `-32007` | 403 | Token allowlist rejects the connection, or `externalAccess` is `blocked`. | diff --git a/docs/external-api/mcp-tools.mdx b/docs/external-api/mcp-tools.mdx index 988072906..86dd34b76 100644 --- a/docs/external-api/mcp-tools.mdx +++ b/docs/external-api/mcp-tools.mdx @@ -11,10 +11,18 @@ The MCP server exposes tools and resources over JSON-RPC. The tools are grouped The same tool catalog is available over two transports: -- **HTTP**: `http://127.0.0.1:/mcp` (port from the handshake file). Bearer token in `Authorization` header. +- **HTTP**: MCP Streamable HTTP at `http://127.0.0.1:/mcp` (port from the handshake file). POST for JSON-RPC requests, GET for the SSE stream that carries server-initiated notifications. Bearer token in `Authorization` header. - **stdio**: bundled `tablepro-mcp` CLI bridges stdio JSON-RPC to localhost HTTP. No token needed because the bridge reuses the in-app handshake. -See [MCP Clients](/external-api/mcp-clients) for stdio config snippets. +The server accepts `2025-03-26`, `2025-06-18`, and `2025-11-25`. On `initialize` it echoes whichever version the client requested. If the client asks for something else, the server returns `2025-11-25`. See [Versioning](/external-api/versioning). + +## What 2025-11-25 adds + +Clients on the latest spec see three things that older clients don't: + +- **Structured tool output**. Every tool that returns data fills `structuredContent` next to `content[]`. Older clients keep parsing the JSON text in `content[0].text`. Newer clients can read the typed object directly. Applies to `list_*`, `describe_table`, `get_table_ddl`, `get_connection_status`, `list_recent_tabs`, `search_query_history`, `execute_query`, and `confirm_destructive_operation`. +- **Tool annotations**. `tools/list` returns `title`, `readOnlyHint`, `destructiveHint`, `idempotentHint`, and `openWorldHint` per tool. Read tools advertise `readOnlyHint=true`. `confirm_destructive_operation` advertises `destructiveHint=true`. `execute_query` and `export_data` advertise `openWorldHint=true`. +- **Streaming progress**. Long-running tool calls emit `notifications/progress` events when the client passes a `_meta.progressToken` in the request. Today this fires on `execute_query` at four stages: Connecting, Executing, Formatting result, Done. ## Scope and access matrix @@ -90,9 +98,9 @@ Close a connection. **Input**: `{ "connection_id": "..." }` -**Output**: empty object on success. +**Output**: `{ "status": "disconnected" }` on success. -**Scope**: `readOnly`. +**Scope**: `readWrite`. ### `get_connection_status` @@ -177,7 +185,7 @@ Columns, indexes, foreign keys, primary key, DDL. } ``` -`schema` is optional. The current database is used unless the connection was first switched with `switch_database`. +`schema` is optional. The connection's current schema is used when omitted. To target a different database, call `switch_database` first. **Output**: @@ -229,7 +237,7 @@ Just the `CREATE TABLE` statement. ### `execute_query` -Run a SQL query. +Execute a SQL query. All queries are subject to the connection's safe mode policy. DROP, TRUNCATE, and ALTER...DROP must use `confirm_destructive_operation`. **Input**: @@ -244,7 +252,7 @@ Run a SQL query. } ``` -`max_rows` defaults to 500, max 10,000. `timeout_seconds` defaults to 30, max 300. Single-statement queries only. Query size cap is 100 KB. +Defaults for `max_rows` and `timeout_seconds` come from **Settings > Integrations > MCP Configuration** (default row limit, query timeout). `max_rows` is clamped to the configured maximum (default 10,000). `timeout_seconds` is clamped to 1-300. Single-statement queries only. Query size cap is 100 KB. `database` and `schema` are optional; when present, the tool calls `switch_database` and/or `switch_schema` before executing. **Output**: @@ -269,6 +277,8 @@ Run a SQL query. Safe Mode rules apply on top. A connection in Safe Mode `readOnly` returns `403` for any write SQL. +**Streaming progress**: pass `_meta.progressToken` in the request and the server sends `notifications/progress` events on the SSE channel as the query moves through "Connecting", "Executing", "Formatting result", and "Done". Clients that don't include a token get the final response only. + ### `confirm_destructive_operation` Run a DROP, TRUNCATE, or ALTER...DROP after a typed confirmation. @@ -287,7 +297,7 @@ The confirmation phrase is fixed: `I understand this is irreversible`. Anything **Output**: same shape as `execute_query`. -**Scope**: `fullAccess`. +**Scope**: `readWrite` or `fullAccess` (both grant the `tools:write` MCP scope). The connection's external access must also permit writes; a `readOnly` connection rejects destructive operations even with a matching token. ### `export_data` @@ -304,9 +314,9 @@ Export query or table data as CSV, JSON, or SQL. } ``` -`format` is one of `csv`, `json`, `sql`. `max_rows` defaults to 50,000, max 100,000. Provide either `tables` or `query`. Pass `output_path` to write to disk instead of returning data inline. +`format` is one of `csv`, `json`, `sql`. `max_rows` defaults to 50,000, max 100,000. Provide either `tables` or `query`. Table names accept letters, digits, underscore, and `.` for schema-qualified names. Pass `output_path` to write to disk instead of returning data inline; the path must resolve inside the user's `~/Downloads` directory or the request is rejected with `400`. -**Output**: an envelope with one entry per query/table exported. Each entry has the export label and either inline data or the file path. Provide `output_path` in the request to receive a file-path response. +**Output**: when `output_path` is set, returns `{ "path": "...", "rows_exported": N }`. Otherwise returns the export inline. A single export returns `{ "label": "...", "format": "csv", "row_count": N, "data": "..." }`. Multiple exports (multi-table requests) return `{ "exports": [ { "label": "...", "format": "csv", "row_count": N, "data": "..." }, ... ] }`. **Scope**: `readOnly`. @@ -316,15 +326,15 @@ Export query or table data as CSV, JSON, or SQL. **Output**: `{ "status": "switched", "current_database": "analytics" }` or `{ "status": "switched", "current_schema": "reporting" }` -**Scope**: `readOnly`. +**Scope**: `readWrite` (mutates session state). ## Navigation tools -These mutate UI state in the running TablePro app: opening tabs, focusing windows. They require `readWrite` scope because the user sees the result. +These open or focus tabs and windows in the running TablePro app. They require `readOnly` scope and respect the connection allowlist; tabs from `externalAccess: blocked` connections are filtered out. ### `open_connection_window` -Open a connection in TablePro and bring its window to front. +Open a connection in TablePro and bring its window to front. If the connection is already open, the existing window is focused. **Input**: `{ "connection_id": "..." }` @@ -338,7 +348,7 @@ Open a connection in TablePro and bring its window to front. } ``` -**Scope**: `readWrite`. +**Scope**: `readOnly`. ### `open_table_tab` @@ -368,11 +378,11 @@ Open a table tab. } ``` -**Scope**: `readWrite`. +**Scope**: `readOnly`. ### `focus_query_tab` -Bring an existing tab to front. +Bring an existing tab to front. The `tab_id` comes from `list_recent_tabs`. **Input**: `{ "tab_id": "..." }` @@ -387,11 +397,13 @@ Bring an existing tab to front. } ``` -**Scope**: `readWrite`. +If the tab is no longer open, the call returns `-32602 invalid params` with detail `tab not found`. + +**Scope**: `readOnly`. ### `list_recent_tabs` -Read the cross-window tab registry. +Read the cross-window tab registry. Tabs from connections with `externalAccess: blocked` are filtered out. **Input**: `{ "limit": 20 }` (optional, 1-500, default 20). @@ -465,25 +477,37 @@ Full-text search over the query history database. ## Errors -All tools return JSON-RPC errors with these codes: - -| Code | Meaning | -|------|---------| -| `400` | Invalid input | -| `401` | Missing or invalid bearer token | -| `403` | Token scope or `externalAccess` rejects the request | -| `404` | Connection, table, or tab not found | -| `408` | Query timeout | -| `429` | Rate limit | -| `500` | Server error | - -Error responses include a `message` field. Example: +Tool failures come back as JSON-RPC error envelopes. Codes follow the JSON-RPC spec plus TablePro's reserved range: + +| JSON-RPC code | HTTP status | Meaning | +|---------------|-------------|---------| +| `-32700` | 400 | Parse error (malformed JSON body) | +| `-32600` | 400 | Invalid request (bad envelope, missing `Mcp-Session-Id`) | +| `-32601` | 200 / 404 | Method or resource URI not found | +| `-32602` | 200 | Invalid params (bad input, unknown tab id, unknown connection) | +| `-32603` | 500 | Internal error | +| `-32001` | 404 / 401 | Session not found, or unauthenticated | +| `-32002` | 200 | Request cancelled | +| `-32003` | 200 | Request timeout (e.g. query timeout) | +| `-32004` | 404 | Resource not found | +| `-32005` | 413 | Payload too large | +| `-32007` | 403 | Forbidden (token scope, allowlist, or `externalAccess` rejects) | +| `-32008` | 401 | Token expired | +| `-32000` | 429 / 503 | Server error (rate limited, service unavailable) | + +Error responses include a `message`. Example: ```json { + "jsonrpc": "2.0", + "id": 7, "error": { - "code": 403, - "message": "Connection is read-only for external clients" + "code": -32007, + "message": "Forbidden: Connection is read-only for external clients" } } ``` + +A `404` from `GET/POST/DELETE /mcp` with a stale `Mcp-Session-Id` returns the JSON-RPC envelope with `code: -32001, message: "Session not found"`. Per the [MCP spec](https://modelcontextprotocol.io), clients MUST treat that response as a signal to start a new `initialize` handshake before retrying. + +`401` responses include a `WWW-Authenticate: Bearer realm="TablePro MCP"` header. When the token has expired, the challenge adds `error="invalid_token", error_description="token_expired"`. diff --git a/docs/external-api/pairing.mdx b/docs/external-api/pairing.mdx index de1181bb0..e2f710a81 100644 --- a/docs/external-api/pairing.mdx +++ b/docs/external-api/pairing.mdx @@ -154,6 +154,17 @@ For preferences-backed storage, use `updateCommandMetadata` or write to the pass A failed exchange is recorded in the activity log under the `auth` category with outcome `denied`. +### Denied approvals + +If the user clicks **Deny** on the approval sheet, TablePro opens the `redirect` URL with two extra parameters so the extension can show a clear error and stop spinning: + +- `error=denied` +- `error_description=user_denied` + +For `raycast://...` redirects these are wrapped inside the standard `context` JSON payload (`{"error":"denied","error_description":"user_denied"}`); for any other scheme they are appended as flat query parameters. + +Extensions should treat the presence of an `error` parameter on the callback as terminal and surface the description to the user. + ## Implementing pairing in another extension The flow is not Raycast-specific. Cursor, Claude Desktop, or any custom client can use it. Requirements: diff --git a/docs/external-api/tokens.mdx b/docs/external-api/tokens.mdx index e9014a5e7..ad6933fa3 100644 --- a/docs/external-api/tokens.mdx +++ b/docs/external-api/tokens.mdx @@ -28,15 +28,25 @@ The `prefix` is shown in the token list so the user can identify a token without ## Scopes -| Scope | Read schema | SELECT | INSERT/UPDATE/DELETE | DROP/TRUNCATE | UI mutation | -|-------|:-----------:|:------:|:--------------------:|:-------------:|:-----------:| -| `readOnly` | yes | yes | no | no | no | -| `readWrite` | yes | yes | yes | no | yes | -| `fullAccess` | yes | yes | yes | yes (with phrase) | yes | +A token's `permissions` value maps to the MCP scopes the server enforces: -UI mutation covers `open_connection_window`, `open_table_tab`, `focus_query_tab`. These open windows and tabs in the running app. +| Token permission | MCP scopes granted | +|------------------|--------------------| +| `readOnly` | `tools:read`, `resources:read` | +| `readWrite` | `tools:read`, `tools:write`, `resources:read` | +| `fullAccess` | `tools:read`, `tools:write`, `resources:read`, `admin` | -DROP and TRUNCATE always require an explicit confirmation phrase via `confirm_destructive_operation`, even with `fullAccess`. There is no token scope that bypasses the phrase. +What each token can do: + +| Permission | Read schema | SELECT | INSERT/UPDATE/DELETE | DROP/TRUNCATE | switch_database/switch_schema | open / focus tabs | +|------------|:-----------:|:------:|:--------------------:|:-------------:|:----------------------------:|:-----------------:| +| `readOnly` | yes | yes | no | no | no | yes | +| `readWrite` | yes | yes | yes | yes (with phrase) | yes | yes | +| `fullAccess` | yes | yes | yes | yes (with phrase) | yes | yes | + +Navigation tools (`open_connection_window`, `open_table_tab`, `focus_query_tab`, `list_recent_tabs`) need only `tools:read`. They surface UI but never bypass the connection allowlist or `externalAccess: blocked`. + +DROP and TRUNCATE always require an explicit confirmation phrase via `confirm_destructive_operation`, plus a token with `tools:write` (i.e. `readWrite` or `fullAccess`). There is no token permission that bypasses the phrase. ## Connection allowlist @@ -56,11 +66,11 @@ The effective permission is `MIN(token.scope, connection.externalAccess)`. | `readOnly` | `readWrite` | `readOnly` | | `readWrite` | `readOnly` | `readOnly` | | `fullAccess` | `readOnly` | `readOnly` | -| `fullAccess` | `readWrite` | `readWrite` (no destructive) | +| `fullAccess` | `readWrite` | `readWrite` | | `fullAccess` | `blocked` | denied | | any | `blocked` | denied | -A `fullAccess` token cannot mutate data on a `readOnly` connection. A token's reach is bounded by both itself and the connection. +A `fullAccess` or `readWrite` token cannot mutate data on a `readOnly` connection. A token's reach is bounded by both itself and the connection's `externalAccess`. ## Creation @@ -110,16 +120,15 @@ Entries are kept for 90 days, auto-pruned on app launch. ## Rate limits -Per-IP, on failed auth: +The MCP authenticator throttles failed token attempts. The bucket key is `(client_address, principal_fingerprint)`, so a misbehaving bridge cannot lock out other principals on the same loopback address. -| Failures | Lockout | -|----------|---------| -| 2 | 1 second | -| 3 | 5 seconds | -| 4 | 30 seconds | -| 5+ | 5 minutes | +| Setting | Value | +|---------|-------| +| Failure window | 60 seconds | +| Max failures in window | 5 | +| Lockout after threshold | 5 minutes | -A successful auth resets the counter. During lockout the server returns `429 Too Many Requests`. +A successful auth clears the bucket. During lockout the server returns HTTP `429 Too Many Requests` with JSON-RPC `code: -32000, message: "Rate limited"`. ## What tokens cannot do diff --git a/docs/external-api/url-scheme.mdx b/docs/external-api/url-scheme.mdx index 4f9b43ff1..4bec5423f 100644 --- a/docs/external-api/url-scheme.mdx +++ b/docs/external-api/url-scheme.mdx @@ -67,25 +67,19 @@ open "tablepro://connect/9f1f0c3e-2e3d-4b14-9c3a-1d2f4ad1f6f1/database/app/schem ``` tablepro://connect//query?sql= -tablepro://connect//query?sql=&token= ``` -Opens a new query tab with the SQL pre-filled. Without a `token`, TablePro shows a confirmation dialog with the SQL before opening, so the user can verify the query is safe. +Opens a new query tab with the SQL pre-filled. TablePro always shows a confirmation dialog with a preview of the SQL before opening, so the user can verify the query is safe. The query does not auto-execute; the user runs it from the editor. The SQL has a 51,200-character cap. -If a valid `token` is provided and the token has `query.write` scope (i.e. `readWrite` or `fullAccess`), the confirmation is skipped. The token is matched against the active connection's `externalAccess` level. A read-only connection rejects any write SQL regardless of token scope. +To run SQL from a script and read rows back, use the MCP [`execute_query`](/external-api/mcp-tools) tool instead. The URL scheme is for handing SQL into the GUI, not for headless execution. ```bash -# With confirmation open "tablepro://connect/9f1f0c3e-2e3d-4b14-9c3a-1d2f4ad1f6f1/query?sql=SELECT%20*%20FROM%20users%20LIMIT%2010" - -# With token, no confirmation -open "tablepro://connect/9f1f0c3e-2e3d-4b14-9c3a-1d2f4ad1f6f1/query?sql=SELECT%20*%20FROM%20users%20LIMIT%2010&token=tp_abc123..." ``` | Parameter | Required | Description | |-----------|----------|-------------| -| `sql` | yes | Percent-encoded SQL. | -| `token` | no | Bearer token. Skips the confirmation dialog when present and valid. | +| `sql` | yes | Percent-encoded SQL. Max 51,200 characters. | ## Start pairing diff --git a/docs/external-api/versioning.mdx b/docs/external-api/versioning.mdx index 905fab298..268631c3c 100644 --- a/docs/external-api/versioning.mdx +++ b/docs/external-api/versioning.mdx @@ -7,6 +7,26 @@ description: Stability policy for the URL scheme, MCP tools, and resource catalo The External API follows TablePro's semver. The contract is the URL scheme, the MCP tool catalog, the resource list, and the pairing flow. +The MCP server accepts three versions from the [MCP spec](https://modelcontextprotocol.io): `2025-03-26`, `2025-06-18`, and `2025-11-25`. On `initialize` the server echoes the version the client asked for. If the client asks for something else, the server returns `2025-11-25` and the client decides whether to use it. + +### Capabilities + +The server reports these capabilities. Anything not listed here is not supported. + +- `tools.listChanged: false`. The tool list does not change during a session. +- `resources.listChanged: false`, `resources.subscribe: false`. Resources are static. Clients that need fresh data should call `resources/read` again. +- `prompts.listChanged: false`. No prompts yet. +- `logging`. Accepts `logging/setLevel`. +- `completions`. Accepts `completion/complete`. Returns an empty list today. + +There is no `elicitation` capability. The server does not ask clients for input. + +### What changed in 2025-11-25 + +- `tools/call` results now include `structuredContent` next to `content[]`. Older clients keep reading the text content. Newer clients can read the typed object directly. Tools that return data use both: `list_*`, `describe_table`, `get_table_ddl`, `get_connection_status`, `list_recent_tabs`, `search_query_history`, `execute_query`, `confirm_destructive_operation`. +- `tools/list` returns annotations per tool. `readOnlyHint` and `idempotentHint` mark read tools. `destructiveHint` marks `confirm_destructive_operation`. `openWorldHint` marks `execute_query` and `export_data`. +- `serverInfo` includes `title: "TablePro"`. + ## Stability rules Within a major version, the External API is **additive only**: diff --git a/docs/features/ai-assistant.mdx b/docs/features/ai-assistant.mdx index fb773b2d8..cbe358c83 100644 --- a/docs/features/ai-assistant.mdx +++ b/docs/features/ai-assistant.mdx @@ -194,7 +194,7 @@ Set a per-connection AI policy in the connection form: **Use Default**, **Always ### External AI clients -External clients (Raycast, Cursor, Claude Desktop, and other MCP clients) call the same AI tools through the [External API](/external-api). Two per-connection settings gate them: +External clients (Raycast, Cursor, Claude Desktop, and other MCP clients) call the same AI tools through the [External API](/external-api/index). Two per-connection settings gate them: - **AI policy** decides whether the connection is reachable by AI clients at all. `Never` blocks every external AI tool call against this connection. - **External Access** caps the level: `blocked`, `readOnly` (default), or `readWrite`. A token's effective permission is `MIN(token.scope, connection.externalAccess)`. Set this in the connection form's **Advanced** tab. diff --git a/docs/features/mcp.mdx b/docs/features/mcp.mdx index 16f43f627..e69c7c406 100644 --- a/docs/features/mcp.mdx +++ b/docs/features/mcp.mdx @@ -5,7 +5,7 @@ description: Built-in Model Context Protocol server that exposes TablePro to AI # MCP Server -TablePro includes a built-in [Model Context Protocol](https://modelcontextprotocol.io) (MCP) server that lets AI clients query your databases through TablePro's saved connections. The MCP server is one of three pillars of the [External API](/external-api), alongside the URL scheme and the pairing flow. +TablePro includes a built-in [Model Context Protocol](https://modelcontextprotocol.io) (MCP) server that lets AI clients query your databases through TablePro's saved connections. The MCP server is one of three pillars of the [External API](/external-api/index), alongside the URL scheme and the pairing flow. This page covers the in-app **Settings > Integrations** UI. For protocol details, see the External API section. @@ -157,4 +157,4 @@ The reachable surface is the [tool catalog](/external-api/mcp-tools) and the [UR **Certificate trust error**: export the PEM from **Settings > Integrations > Network** and add it to your client's trust store, or use fingerprint pinning. -**`429 Too Many Requests`**: too many failed auth attempts. The lockout escalates to 5 minutes and resets on the next successful auth. +**`429 Too Many Requests`**: 5 failed auth attempts within 60 seconds against the same `(client_address, principal)` pair triggers a 5-minute lockout. A successful auth clears the bucket. diff --git a/docs/features/safe-mode.mdx b/docs/features/safe-mode.mdx index 28db4e788..f2da132ee 100644 --- a/docs/features/safe-mode.mdx +++ b/docs/features/safe-mode.mdx @@ -81,5 +81,5 @@ A write request from an external client clears three locks in this order: 2. **Token scope** (per-integration, `readOnly` / `readWrite` / `fullAccess`). Issued by the [pairing flow](/external-api/pairing) and bounded by External Access: effective permission is `MIN(token.scope, connection.externalAccess)`. 3. **Safe Mode** (per-query). The same rules on this page apply once the request has been routed to the connection. Touch ID prompts and confirmation dialogs still appear, even for queries originating from an external client. -DROP and TRUNCATE always need an explicit confirmation phrase via the `confirm_destructive_operation` tool, regardless of token scope. See [External API security model](/external-api#security-model). +DROP and TRUNCATE always need an explicit confirmation phrase via the `confirm_destructive_operation` tool, regardless of token scope. See [External API security model](/external-api/index#security-model).