From 53a1a07ee79a2729bc02d653da06d5cd02a585e3 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sat, 14 Mar 2026 13:38:29 +0800 Subject: [PATCH 01/11] add file management to database --- .github/workflows/ci.yml | 2 +- CMakeLists.txt | 6 +- src/database/Database.cpp | 213 ++++++++++++++++++++++++++++++++++- src/database/Database.h | 22 +++- src/schema/IServer.h | 231 ++++++++++++++++++++++++++++++++++++++ src/schema/src/types | 2 +- src/tui-register.cpp | 3 +- src/tui-server.cpp | 14 ++- test/CMakeLists.txt | 9 +- test/TestDatabase.cpp | 94 +++++++++++++++- test/TestService.cpp | 15 +-- test/TestTuiServer.cpp | 17 ++- 12 files changed, 603 insertions(+), 25 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c08ffc4..62385bc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ env: TestChaCha20Poly1305 TestCounter TestCryptoKdfHkdfSha256 - TestDatabase /tmp/tui-test.db + TestDatabase /tmp/tui-test.db /tmp/files TestEcdhePsk TestEd25519 TestFakeCredentialGenerator diff --git a/CMakeLists.txt b/CMakeLists.txt index 73ee1f8..64e231e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,8 @@ target_link_libraries(tui-server sqlite3 uuid sodium - zstd) + zstd + xxhash) add_executable(tui-register ${CMAKE_CURRENT_SOURCE_DIR}/src/tui-register.cpp @@ -66,7 +67,8 @@ target_include_directories(tui-register target_link_libraries(tui-register PRIVATE sqlite3 - uuid) + uuid + xxhash) install(TARGETS tui-server RUNTIME DESTINATION bin) install(TARGETS tui-register RUNTIME DESTINATION bin) diff --git a/src/database/Database.cpp b/src/database/Database.cpp index 5702a50..0dc11e4 100644 --- a/src/database/Database.cpp +++ b/src/database/Database.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include "Database.h" #include "common/Timestamp.h" @@ -7,9 +9,26 @@ using namespace TUI::Common; using namespace TUI::Database; using namespace TUI::Schema; -JS::Promise> Database::CreateAsync(Tev& tev, const std::filesystem::path& dbPath) +JS::Promise> Database::CreateAsync( + Tev& tev, + const std::filesystem::path& dbPath, + const std::filesystem::path& fileDirectory) { auto db = std::shared_ptr(new Database()); + + db->_fileDirectory = fileDirectory; + if (!std::filesystem::exists(fileDirectory)) + { + std::filesystem::create_directories(fileDirectory); + } + else + { + if (!std::filesystem::is_directory(fileDirectory)) + { + throw std::runtime_error("File directory path exists but is not a directory"); + } + } + db->_db = co_await Sqlite::CreateAsync(tev, dbPath); /** Create tables */ co_await db->_db->ExecAsync( @@ -47,6 +66,13 @@ JS::Promise> Database::CreateAsync(Tev& tev, const std "message TEXT, " "timestamp INTEGER, " "PRIMARY KEY (user_id, chat_id, id));"); + co_await db->_db->ExecAsync( + "CREATE TABLE IF NOT EXISTS file_meta (" + "user_id TEXT, " + "id TEXT, " + "content_id TEXT, " + "metadata TEXT, " + "PRIMARY KEY (user_id, id));"); co_return db; } @@ -148,6 +174,14 @@ JS::Promise Database::DeleteUserAsync(const Uuid& id) co_await _db->ExecAsync( "DELETE FROM chat_content WHERE user_id = ?;", static_cast(id)); + co_await _db->ExecAsync( + "DELETE FROM file_meta WHERE user_id = ?;", + static_cast(id)); + auto userDirectory = _fileDirectory / static_cast(id); + if (std::filesystem::exists(userDirectory)) + { + std::filesystem::remove_all(userDirectory); + } } std::list Database::ListUser() @@ -612,3 +646,180 @@ std::string Database::GetStringFromChat( } } +JS::Promise Database::SaveFileAsync( + const Common::Uuid& userId, std::string metadata, std::vector content) +{ + Uuid fileId{}; + auto fileHash = XXH3_128bits(content.data(), content.size()); + std::string contentId = std::format("{:016x}{:016x}", fileHash.high64, fileHash.low64); + std::string userIdStr = static_cast(userId); + auto userDirectory = _fileDirectory / userIdStr; + if (!std::filesystem::exists(userDirectory)) + { + std::filesystem::create_directories(userDirectory); + } + auto filePath = userDirectory / contentId; + if (!std::filesystem::exists(filePath)) + { + /** @todo: Make this async */ + std::ofstream ofs(filePath, std::ios::binary); + ofs.write(reinterpret_cast(content.data()), content.size()); + if (!ofs) + { + throw std::runtime_error("Failed to write file"); + } + } + auto sql = "INSERT INTO file_meta (user_id, id, content_id, metadata) VALUES (?, ?, ?, ?);"; + co_await _db->ExecAsync( + sql, + userIdStr, + static_cast(fileId), + contentId, + metadata); + co_return FileMeta{fileId, contentId, std::move(metadata)}; +} + +JS::Promise Database::DeleteFileAsync( + const Common::Uuid& userId, const Common::Uuid& fileId) +{ + auto sql = "SELECT content_id FROM file_meta WHERE user_id = ? AND id = ?;"; + auto result = _db->Exec( + sql, static_cast(userId), static_cast(fileId)); + if (result.empty()) + { + co_return; + } + auto& row = result.front(); + auto contentIdItem = row.find("content_id"); + if (contentIdItem == row.end() || !std::holds_alternative(contentIdItem->second)) + { + throw std::runtime_error("content_id not found or invalid"); + } + std::string contentId = std::get(contentIdItem->second); + sql = "DELETE FROM file_meta WHERE user_id = ? AND id = ?;"; + co_await _db->ExecAsync( + sql, + static_cast(userId), + static_cast(fileId)); + sql = "SELECT COUNT(*) AS count FROM file_meta WHERE content_id = ?;"; + result = _db->Exec(sql, contentId); + if (result.empty()) + { + throw std::runtime_error("Failed to count file_meta with content_id"); + } + auto& countRow = result.front(); + auto countItem = countRow.find("count"); + if (countItem == countRow.end() || !std::holds_alternative(countItem->second)) + { + throw std::runtime_error("count not found or invalid"); + } + int64_t count = std::get(countItem->second); + if (count != 0) + { + co_return; + } + auto filePath = _fileDirectory / static_cast(userId) / contentId; + if (std::filesystem::exists(filePath)) + { + std::filesystem::remove(filePath); + } +} + +Database::FileMeta Database::GetFileMeta( + const Common::Uuid& userId, const Common::Uuid& fileId) +{ + auto sql = "SELECT content_id, metadata FROM file_meta WHERE user_id = ? AND id = ?;"; + auto result = _db->Exec( + sql, static_cast(userId), static_cast(fileId)); + if (result.empty()) + { + throw std::runtime_error("File not found"); + } + auto& row = result.front(); + auto contentIdItem = row.find("content_id"); + if (contentIdItem == row.end() || !std::holds_alternative(contentIdItem->second)) + { + throw std::runtime_error("content_id not found or invalid"); + } + std::string contentId = std::get(contentIdItem->second); + auto metadataItem = row.find("metadata"); + std::string metadata{}; + if (metadataItem != row.end()) + { + if (std::holds_alternative(metadataItem->second)) + { + metadata = std::get(metadataItem->second); + } + else if (!std::holds_alternative(metadataItem->second)) + { + throw std::runtime_error("Invalid metadata type"); + } + } + return FileMeta{fileId, contentId, std::move(metadata)}; +} + +std::vector Database::GetFileContent( + const Common::Uuid& userId, const std::string& contentId) +{ + auto filePath = _fileDirectory / static_cast(userId) / contentId; + if (!std::filesystem::exists(filePath)) + { + throw std::runtime_error("File content not found"); + } + std::ifstream ifs(filePath, std::ios::binary); + if (!ifs) + { + throw std::runtime_error("Failed to open file"); + } + std::vector content((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + return content; +} + +std::list Database::ListFileMeta(const Common::Uuid& userId) +{ + auto sql = "SELECT id, content_id, metadata FROM file_meta WHERE user_id = ?;"; + auto result = _db->Exec(sql, static_cast(userId)); + std::list list{}; + for (auto& row : result) + { + try + { + auto idItem = row.find("id"); + if (idItem == row.end()) + { + continue; + } + Uuid id{std::get(idItem->second)}; + + auto contentIdItem = row.find("content_id"); + if (contentIdItem == row.end() || !std::holds_alternative(contentIdItem->second)) + { + continue; + } + std::string contentId = std::get(contentIdItem->second); + + auto metadataItem = row.find("metadata"); + std::string metadata{}; + if (metadataItem != row.end()) + { + if (std::holds_alternative(metadataItem->second)) + { + metadata = std::get(metadataItem->second); + } + else if (!std::holds_alternative(metadataItem->second)) + { + /** Invalid metadata type, ignore */ + continue; + } + } + + list.emplace_back(id, contentId, std::move(metadata)); + } + catch(...) + { + /** @todo log */ + /** Ignored, avoid corrupted data from corrupting the whole application */ + } + } + return list; +} diff --git a/src/database/Database.h b/src/database/Database.h index d6e3d97..59a33cc 100644 --- a/src/database/Database.h +++ b/src/database/Database.h @@ -19,7 +19,10 @@ namespace TUI::Database class Database { public: - static JS::Promise> CreateAsync(Tev& tev, const std::filesystem::path& dbPath); + static JS::Promise> CreateAsync( + Tev& tev, + const std::filesystem::path& dbPath, + const std::filesystem::path& fileDirectory); Database(const Database&) = delete; Database& operator=(const Database&) = delete; @@ -88,6 +91,22 @@ namespace TUI::Database bool updateParent = true); Schema::IServer::TreeHistory GetChatHistory(const Common::Uuid& userId, const Common::Uuid& chatId); + /** File */ + struct FileMeta + { + Common::Uuid fileId; + std::string contentId; + std::string metadata; + }; + JS::Promise SaveFileAsync( + const Common::Uuid& userId, + std::string metadata, + std::vector content); + JS::Promise DeleteFileAsync(const Common::Uuid& userId, const Common::Uuid& fileId); + FileMeta GetFileMeta(const Common::Uuid& userId, const Common::Uuid& fileId); + std::vector GetFileContent(const Common::Uuid& userId, const std::string& contentId); + std::list ListFileMeta(const Common::Uuid& userId); + private: Database() = default; @@ -103,5 +122,6 @@ namespace TUI::Database const Common::Uuid& userId, const Common::Uuid& id, const std::string& name); std::shared_ptr _db; + std::filesystem::path _fileDirectory; }; } diff --git a/src/schema/IServer.h b/src/schema/IServer.h index 95f37d0..c5ea357 100644 --- a/src/schema/IServer.h +++ b/src/schema/IServer.h @@ -29,6 +29,14 @@ // SetUserAdminSettingsParams data = nlohmann::json::parse(jsonString); // ProtocolNegotiationRequest data = nlohmann::json::parse(jsonString); // ProtocolNegotiationResponse data = nlohmann::json::parse(jsonString); +// PutFileParams data = nlohmann::json::parse(jsonString); +// PutFileResult data = nlohmann::json::parse(jsonString); +// GetFileMetaParams data = nlohmann::json::parse(jsonString); +// GetFileMetaResult data = nlohmann::json::parse(jsonString); +// GetFileContentParams data = nlohmann::json::parse(jsonString); +// GetFileContentResult data = nlohmann::json::parse(jsonString); +// DeleteFileParams data = nlohmann::json::parse(jsonString); +// ListFileResult data = nlohmann::json::parse(jsonString); #pragma once @@ -641,11 +649,152 @@ namespace IServer { void set_was_under_attack(const bool & value) { this->was_under_attack = value; } }; + class PutFileParams { + public: + PutFileParams() = default; + virtual ~PutFileParams() = default; + + private: + std::string content_base64; + nlohmann::json file_metadata; + + public: + const std::string & get_content_base64() const { return content_base64; } + std::string & get_mutable_content_base64() { return content_base64; } + void set_content_base64(const std::string & value) { this->content_base64 = value; } + + const nlohmann::json & get_file_metadata() const { return file_metadata; } + nlohmann::json & get_mutable_file_metadata() { return file_metadata; } + void set_file_metadata(const nlohmann::json & value) { this->file_metadata = value; } + }; + + class PutFileResult { + public: + PutFileResult() = default; + virtual ~PutFileResult() = default; + + private: + std::string content_id; + std::string file_id; + + public: + /** + * Multiple file can point to the same content. + */ + const std::string & get_content_id() const { return content_id; } + std::string & get_mutable_content_id() { return content_id; } + void set_content_id(const std::string & value) { this->content_id = value; } + + const std::string & get_file_id() const { return file_id; } + std::string & get_mutable_file_id() { return file_id; } + void set_file_id(const std::string & value) { this->file_id = value; } + }; + + class GetFileMetaParams { + public: + GetFileMetaParams() = default; + virtual ~GetFileMetaParams() = default; + + private: + std::string file_id; + + public: + const std::string & get_file_id() const { return file_id; } + std::string & get_mutable_file_id() { return file_id; } + void set_file_id(const std::string & value) { this->file_id = value; } + }; + + class GetFileMetaResult { + public: + GetFileMetaResult() = default; + virtual ~GetFileMetaResult() = default; + + private: + std::string content_id; + nlohmann::json file_metadata; + + public: + const std::string & get_content_id() const { return content_id; } + std::string & get_mutable_content_id() { return content_id; } + void set_content_id(const std::string & value) { this->content_id = value; } + + const nlohmann::json & get_file_metadata() const { return file_metadata; } + nlohmann::json & get_mutable_file_metadata() { return file_metadata; } + void set_file_metadata(const nlohmann::json & value) { this->file_metadata = value; } + }; + + class GetFileContentParams { + public: + GetFileContentParams() = default; + virtual ~GetFileContentParams() = default; + + private: + std::string content_id; + + public: + const std::string & get_content_id() const { return content_id; } + std::string & get_mutable_content_id() { return content_id; } + void set_content_id(const std::string & value) { this->content_id = value; } + }; + + class GetFileContentResult { + public: + GetFileContentResult() = default; + virtual ~GetFileContentResult() = default; + + private: + std::string content_base64; + + public: + const std::string & get_content_base64() const { return content_base64; } + std::string & get_mutable_content_base64() { return content_base64; } + void set_content_base64(const std::string & value) { this->content_base64 = value; } + }; + + class DeleteFileParams { + public: + DeleteFileParams() = default; + virtual ~DeleteFileParams() = default; + + private: + std::string file_id; + + public: + const std::string & get_file_id() const { return file_id; } + std::string & get_mutable_file_id() { return file_id; } + void set_file_id(const std::string & value) { this->file_id = value; } + }; + + class ListFileResultElement { + public: + ListFileResultElement() = default; + virtual ~ListFileResultElement() = default; + + private: + std::string content_id; + std::string file_id; + nlohmann::json file_metadata; + + public: + const std::string & get_content_id() const { return content_id; } + std::string & get_mutable_content_id() { return content_id; } + void set_content_id(const std::string & value) { this->content_id = value; } + + const std::string & get_file_id() const { return file_id; } + std::string & get_mutable_file_id() { return file_id; } + void set_file_id(const std::string & value) { this->file_id = value; } + + const nlohmann::json & get_file_metadata() const { return file_metadata; } + nlohmann::json & get_mutable_file_metadata() { return file_metadata; } + void set_file_metadata(const nlohmann::json & value) { this->file_metadata = value; } + }; + using GetMetadataResult = std::map; using LinearHistory = std::vector; using GetChatListResult = std::vector; using GetModelListResult = std::vector; using GetUserListResult = std::vector; + using ListFileResult = std::vector; } } } @@ -1038,6 +1187,88 @@ namespace IServer { j["wasUnderAttack"] = x.get_was_under_attack(); } + inline void from_json(const json & j, PutFileParams& x) { + x.set_content_base64(j.at("contentBase64").get()); + x.set_file_metadata(get_untyped(j, "fileMetadata")); + } + + inline void to_json(json & j, const PutFileParams & x) { + j = json::object(); + j["contentBase64"] = x.get_content_base64(); + j["fileMetadata"] = x.get_file_metadata(); + } + + inline void from_json(const json & j, PutFileResult& x) { + x.set_content_id(j.at("contentId").get()); + x.set_file_id(j.at("fileId").get()); + } + + inline void to_json(json & j, const PutFileResult & x) { + j = json::object(); + j["contentId"] = x.get_content_id(); + j["fileId"] = x.get_file_id(); + } + + inline void from_json(const json & j, GetFileMetaParams& x) { + x.set_file_id(j.at("fileId").get()); + } + + inline void to_json(json & j, const GetFileMetaParams & x) { + j = json::object(); + j["fileId"] = x.get_file_id(); + } + + inline void from_json(const json & j, GetFileMetaResult& x) { + x.set_content_id(j.at("contentId").get()); + x.set_file_metadata(get_untyped(j, "fileMetadata")); + } + + inline void to_json(json & j, const GetFileMetaResult & x) { + j = json::object(); + j["contentId"] = x.get_content_id(); + j["fileMetadata"] = x.get_file_metadata(); + } + + inline void from_json(const json & j, GetFileContentParams& x) { + x.set_content_id(j.at("contentId").get()); + } + + inline void to_json(json & j, const GetFileContentParams & x) { + j = json::object(); + j["contentId"] = x.get_content_id(); + } + + inline void from_json(const json & j, GetFileContentResult& x) { + x.set_content_base64(j.at("contentBase64").get()); + } + + inline void to_json(json & j, const GetFileContentResult & x) { + j = json::object(); + j["contentBase64"] = x.get_content_base64(); + } + + inline void from_json(const json & j, DeleteFileParams& x) { + x.set_file_id(j.at("fileId").get()); + } + + inline void to_json(json & j, const DeleteFileParams & x) { + j = json::object(); + j["fileId"] = x.get_file_id(); + } + + inline void from_json(const json & j, ListFileResultElement& x) { + x.set_content_id(j.at("contentId").get()); + x.set_file_id(j.at("fileId").get()); + x.set_file_metadata(get_untyped(j, "fileMetadata")); + } + + inline void to_json(json & j, const ListFileResultElement & x) { + j = json::object(); + j["contentId"] = x.get_content_id(); + j["fileId"] = x.get_file_id(); + j["fileMetadata"] = x.get_file_metadata(); + } + inline void from_json(const json & j, Type & x) { if (j == "image_url") x = Type::IMAGE_URL; else if (j == "refusal") x = Type::REFUSAL; diff --git a/src/schema/src/types b/src/schema/src/types index c983c33..2f92040 160000 --- a/src/schema/src/types +++ b/src/schema/src/types @@ -1 +1 @@ -Subproject commit c983c33e7f7a21da0f4def1f7bfd50d02a010ddb +Subproject commit 2f9204085a9b0307c1593d39bf4489ec9314b940 diff --git a/src/tui-register.cpp b/src/tui-register.cpp index c46128a..c9968fb 100644 --- a/src/tui-register.cpp +++ b/src/tui-register.cpp @@ -178,7 +178,8 @@ static JS::Promise MainAsync(Tev& tev, AppParams params) credential.set_w0(Common::Base64::Encode(w0Opt.value())); credential.set_l(Common::Base64::Encode(LOpt.value())); - auto database = co_await Database::Database::CreateAsync(tev, params.dbPath.value()); + /** The file root is not used for tui-register. */ + auto database = co_await Database::Database::CreateAsync(tev, params.dbPath.value(), "/tmp"); /** * Check if: * 1. There is an existing user. If so, this is a password reset. diff --git a/src/tui-server.cpp b/src/tui-server.cpp index 043d7f9..f8b40bc 100644 --- a/src/tui-server.cpp +++ b/src/tui-server.cpp @@ -22,6 +22,7 @@ using namespace TUI; struct AppParams { std::optional dbPath{std::nullopt}; + std::optional fileRoot{std::nullopt}; std::optional unixSocketPath{std::nullopt}; std::optional address{std::nullopt}; std::optional port{std::nullopt}; @@ -30,13 +31,16 @@ struct AppParams { int opt = -1; AppParams params{}; - while ((opt = getopt(argc, const_cast(argv), "d:u:a:p:")) != -1) + while ((opt = getopt(argc, const_cast(argv), "d:u:a:p:f:")) != -1) { switch (opt) { case 'd': params.dbPath = std::filesystem::path(optarg); break; + case 'f': + params.fileRoot = std::filesystem::path(optarg); + break; case 'u': params.unixSocketPath = std::string(optarg); break; @@ -59,6 +63,10 @@ struct AppParams { throw std::invalid_argument("Database path is required"); } + if (!fileRoot.has_value()) + { + throw std::invalid_argument("File root path is required"); + } if (!unixSocketPath.has_value() && (!address.has_value() || !port.has_value())) { throw std::invalid_argument("Either unix socket path or address and port must be provided"); @@ -71,6 +79,7 @@ struct AppParams oss << "Usage: " << std::endl << programName << std::endl << " -d " << std::endl + << " -f " << std::endl << " -u | -a
-p " << std::endl; return oss.str(); } @@ -129,7 +138,8 @@ static JS::Promise MainNoexceptAsync(AppParams params) static JS::Promise MainAsync(AppParams params) { - auto database = co_await Database::Database::CreateAsync(gApp.tev, params.dbPath.value()); + auto database = co_await Database::Database::CreateAsync( + gApp.tev, params.dbPath.value(), params.fileRoot.value()); std::shared_ptr> webSocketServer{nullptr}; if (params.unixSocketPath.has_value()) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6b3667a..7968388 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -134,7 +134,8 @@ target_include_directories(TestDatabase target_link_libraries(TestDatabase PRIVATE sqlite3 - uuid) + uuid + xxhash) add_executable(TestRpcServer TestRpcServer.cpp) @@ -374,7 +375,8 @@ target_link_libraries(TestService sqlite3 uuid sodium - zstd) + zstd + xxhash) add_executable(TestTuiServer TestTuiServer.cpp @@ -391,7 +393,8 @@ target_link_libraries(TestTuiServer sqlite3 uuid sodium - zstd) + zstd + xxhash) add_executable(TestZstdMessageCodec TestZstdMessageCodec.cpp diff --git a/test/TestDatabase.cpp b/test/TestDatabase.cpp index 66a14c4..4878ba2 100644 --- a/test/TestDatabase.cpp +++ b/test/TestDatabase.cpp @@ -12,11 +12,12 @@ using namespace TUI::Schema; Tev tev{}; std::string dbPath{}; +std::string fileRoot{}; std::shared_ptr db{nullptr}; JS::Promise TestCreateAsync() { - db = co_await Database::CreateAsync(tev, dbPath); + db = co_await Database::CreateAsync(tev, dbPath, fileRoot); } void TestClose() @@ -262,6 +263,91 @@ JS::Promise TestChatAsync() AssertWithMessage(!chatFound, "Deleted chat found"); } +JS::Promise TestFileAsync() +{ + std::string username = "test-user-file"; + auto userId = co_await db->CreateUserAsync(username, "", ""); + + std::string metadataInput = "test-file-metadata"; + std::vector contentInput = {'H', 'e', 'l', 'l', 'o'}; + + auto fileMeta = co_await db->SaveFileAsync(userId, metadataInput, contentInput); + AssertWithMessage(fileMeta.fileId != nullptr, "File ID should not be empty"); + AssertWithMessage(!fileMeta.contentId.empty(), "Content ID should not be empty"); + AssertWithMessage(fileMeta.metadata == metadataInput, "File metadata should match"); + + auto retrievedMeta = db->GetFileMeta(userId, fileMeta.fileId); + AssertWithMessage(retrievedMeta.fileId == fileMeta.fileId, "Retrieved file ID should match"); + AssertWithMessage(retrievedMeta.contentId == fileMeta.contentId, "Retrieved content ID should match"); + AssertWithMessage(retrievedMeta.metadata == metadataInput, "Retrieved metadata should match"); + + auto retrievedContent = db->GetFileContent(userId, fileMeta.contentId); + AssertWithMessage(retrievedContent == contentInput, "Retrieved file content should match"); + + auto fileList = db->ListFileMeta(userId); + bool fileFound = false; + for (const auto& f : fileList) + { + if (f.fileId == fileMeta.fileId) + { + fileFound = true; + break; + } + } + AssertWithMessage(fileFound, "Saved file should appear in list"); + + /** Save another file with the same content to test content deduplication */ + std::string metadataInput2 = "test-file-metadata-2"; + auto fileMeta2 = co_await db->SaveFileAsync(userId, metadataInput2, contentInput); + AssertWithMessage(fileMeta2.fileId != fileMeta.fileId, "Second file should have a different file ID"); + AssertWithMessage(fileMeta2.contentId == fileMeta.contentId, "Same content should produce the same content ID"); + + fileList = db->ListFileMeta(userId); + size_t fileCount = 0; + for (const auto& f : fileList) + { + if (f.fileId == fileMeta.fileId || f.fileId == fileMeta2.fileId) + { + fileCount++; + } + } + AssertWithMessage(fileCount == 2, "Both files should appear in list"); + + /** Delete first file; content should still exist because second file references it */ + co_await db->DeleteFileAsync(userId, fileMeta.fileId); + fileList = db->ListFileMeta(userId); + fileFound = false; + for (const auto& f : fileList) + { + if (f.fileId == fileMeta.fileId) + { + fileFound = true; + break; + } + } + AssertWithMessage(!fileFound, "Deleted file should not appear in list"); + + auto contentAfterFirstDelete = db->GetFileContent(userId, fileMeta2.contentId); + AssertWithMessage(contentAfterFirstDelete == contentInput, "Content should still exist after deleting first reference"); + + /** Delete second file; content should now be removed */ + co_await db->DeleteFileAsync(userId, fileMeta2.fileId); + fileList = db->ListFileMeta(userId); + AssertWithMessage(fileList.empty(), "File list should be empty after deleting all files"); + + /** Save a file with different content */ + std::vector contentInput3 = {'W', 'o', 'r', 'l', 'd'}; + auto fileMeta3 = co_await db->SaveFileAsync(userId, "metadata-3", contentInput3); + AssertWithMessage(fileMeta3.contentId != fileMeta.contentId, "Different content should produce a different content ID"); + auto retrievedContent3 = db->GetFileContent(userId, fileMeta3.contentId); + AssertWithMessage(retrievedContent3 == contentInput3, "Retrieved content should match for third file"); + + /** Files should be deleted with the user */ + co_await db->DeleteUserAsync(userId); + fileList = db->ListFileMeta(userId); + AssertWithMessage(fileList.empty(), "Files should be deleted with the user"); +} + JS::Promise TestAsync() { /** Always run this first */ @@ -270,17 +356,19 @@ JS::Promise TestAsync() RunAsyncTest(TestModelAsync()); RunAsyncTest(TestUserAsync()); RunAsyncTest(TestChatAsync()); + RunAsyncTest(TestFileAsync()); RunTest(TestClose()); } int main(int argc, char const *argv[]) { - if (argc < 2) + if (argc < 3) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + std::cerr << "Usage: " << argv[0] << " " << " " << std::endl; return 1; } dbPath = argv[1]; + fileRoot = argv[2]; /** Delete the old database */ if (std::filesystem::exists(dbPath)) { diff --git a/test/TestService.cpp b/test/TestService.cpp index 03048a3..2ec3e53 100644 --- a/test/TestService.cpp +++ b/test/TestService.cpp @@ -106,7 +106,7 @@ static void SignalHandler(int sig) gSignalQueue->Inject(sig); } -JS::Promise MainAsync(Tev& tev, std::string dbPath) +JS::Promise MainAsync(Tev& tev, std::string dbPath, std::string fileRoot) { /** Delete the old database if any */ if (std::filesystem::exists(dbPath)) @@ -122,7 +122,7 @@ JS::Promise MainAsync(Tev& tev, std::string dbPath) std::filesystem::remove(dbPath + "-shm"); } - auto database = co_await Database::Database::CreateAsync(tev, dbPath); + auto database = co_await Database::Database::CreateAsync(tev, dbPath, fileRoot); Common::Uuid userId{}; { /** Create a default user in the database. THIS IS ONLY FOR TESTING */ @@ -157,23 +157,24 @@ JS::Promise MainAsync(Tev& tev, std::string dbPath) signal(SIGTERM, SignalHandler); } -JS::Promise TestAsync(Tev& tev, std::string dbPath) +JS::Promise TestAsync(Tev& tev, std::string dbPath, std::string fileRoot) { - RunAsyncTest(MainAsync(tev, dbPath)); + RunAsyncTest(MainAsync(tev, std::move(dbPath), std::move(fileRoot))); } int main(int argc, char const *argv[]) { - if (argc < 2) + if (argc < 3) { - std::cerr << "Usage: " << argv[0] << " " << std::endl; + std::cerr << "Usage: " << argv[0] << " " << std::endl; return 1; } std::string dbPath = argv[1]; + std::string fileRoot = argv[2]; Tev tev{}; - TestAsync(tev, dbPath); + TestAsync(tev, std::move(dbPath), std::move(fileRoot)); tev.MainLoop(); diff --git a/test/TestTuiServer.cpp b/test/TestTuiServer.cpp index 43f8241..c8a77a6 100644 --- a/test/TestTuiServer.cpp +++ b/test/TestTuiServer.cpp @@ -23,13 +23,14 @@ using namespace TUI; struct AppParams { std::optional dbPath{std::nullopt}; + std::optional fileRoot{std::nullopt}; std::optional configPath{std::nullopt}; static AppParams Parse(int argc, char const *argv[]) { int opt = -1; AppParams params{}; - while ((opt = getopt(argc, const_cast(argv), "d:c:")) != -1) + while ((opt = getopt(argc, const_cast(argv), "d:c:f:")) != -1) { switch (opt) { @@ -39,6 +40,9 @@ struct AppParams case 'c': params.configPath = std::filesystem::path(optarg); break; + case 'f': + params.fileRoot = std::filesystem::path(optarg); + break; default: break; } @@ -52,6 +56,10 @@ struct AppParams { throw std::invalid_argument("Database path is required"); } + if (!fileRoot.has_value()) + { + throw std::invalid_argument("File root path is required"); + } } std::string getHelp(const std::string& programName) const @@ -60,6 +68,7 @@ struct AppParams oss << "Usage: " << std::endl << programName << std::endl << " -d " << std::endl + << " -f " << std::endl << " [-c ]" << std::endl; return oss.str(); } @@ -137,7 +146,8 @@ static JS::Promise PrepareTestDatabaseAsync(AppParams params) std::filesystem::remove(params.dbPath->string() + "-shm"); } - auto database = co_await Database::Database::CreateAsync(gApp.tev, params.dbPath.value()); + auto database = co_await Database::Database::CreateAsync( + gApp.tev, params.dbPath.value(), params.fileRoot.value()); /** Register a test user */ auto registration = Cipher::Spake2p::Register(std::string(testUsername), std::string(testPassword)); Schema::IServer::UserCredential userCredential; @@ -173,7 +183,8 @@ static JS::Promise PrepareTestDatabaseAsync(AppParams params) static JS::Promise MainAsync(AppParams params) { - auto database = co_await Database::Database::CreateAsync(gApp.tev, params.dbPath.value()); + auto database = co_await Database::Database::CreateAsync( + gApp.tev, params.dbPath.value(), params.fileRoot.value()); std::shared_ptr> webSocketServer{nullptr}; webSocketServer = Network::WebSocket::Server::Create( gApp.tev, std::string(serverAddress), serverPort); From 74850b8f36b9339c920abc51d4d0c73c0d219c40 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sat, 14 Mar 2026 14:13:07 +0800 Subject: [PATCH 02/11] add file apis --- src/application/Service.cpp | 82 ++++++++++++++++++++++++++++++++++++- src/application/Service.h | 5 +++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/application/Service.cpp b/src/application/Service.cpp index d0d3a57..5689f3f 100644 --- a/src/application/Service.cpp +++ b/src/application/Service.cpp @@ -3,6 +3,7 @@ #include "network/HttpStreamResponseParser.h" #include "common/Timestamp.h" #include "common/StreamBatcher.h" +#include "common/Base64.h" using namespace TUI; using namespace TUI::Application; @@ -38,7 +39,12 @@ Service::Service( {"deleteUser", std::bind(&Service::OnDeleteUserAsync, this, std::placeholders::_1, std::placeholders::_2)}, {"getUserAdminSettings", std::bind(&Service::OnGetUserAdminSettingsAsync, this, std::placeholders::_1, std::placeholders::_2)}, {"setUserAdminSettings", std::bind(&Service::OnSetUserAdminSettingsAsync, this, std::placeholders::_1, std::placeholders::_2)}, - {"setUserCredential", std::bind(&Service::OnSetUserCredentialAsync, this, std::placeholders::_1, std::placeholders::_2)} + {"setUserCredential", std::bind(&Service::OnSetUserCredentialAsync, this, std::placeholders::_1, std::placeholders::_2)}, + {"putFile", std::bind(&Service::OnPutFileAsync, this, std::placeholders::_1, std::placeholders::_2)}, + {"getFileMeta", std::bind(&Service::OnGetFileMetaAsync, this, std::placeholders::_1, std::placeholders::_2)}, + {"getFileContent", std::bind(&Service::OnGetFileContentAsync, this, std::placeholders::_1, std::placeholders::_2)}, + {"deleteFile", std::bind(&Service::OnDeleteFileAsync, this, std::placeholders::_1, std::placeholders::_2)}, + {"listFile", std::bind(&Service::OnListFileAsync, this, std::placeholders::_1, std::placeholders::_2)} }, std::unordered_map::StreamRequestHandler>{ {"chatCompletion", std::bind(&Service::OnChatCompletionAsync, this, std::placeholders::_1, std::placeholders::_2)}, @@ -970,6 +976,80 @@ JS::Promise Service::OnSetUserCredentialAsync(CallerId callerId, co_return nlohmann::json{}; } +JS::Promise Service::OnPutFileAsync(CallerId callerId, nlohmann::json paramsJson) +{ + auto params = ParseParams(paramsJson); + auto fileListLock = _resourceVersionManager->GetWriteLock( + {"fileList", static_cast(callerId.userId)}, callerId); + /** + * File metadata and content are not managed by version manager. As they are constants. + * It's up to the client to cache them properly. + */ + auto fileContent = Common::Base64::Decode(params.get_content_base64()); + auto metadataStr = (static_cast(params.get_file_metadata())).dump(); + auto fileMeta = co_await _database->SaveFileAsync( + callerId.userId, + std::move(metadataStr), + std::move(fileContent)); + Schema::IServer::PutFileResult result{}; + result.set_file_id(static_cast(fileMeta.fileId)); + result.set_content_id(fileMeta.contentId); + co_return static_cast(result); +} + +JS::Promise Service::OnGetFileMetaAsync(CallerId callerId, nlohmann::json paramsJson) +{ + auto params = ParseParams(paramsJson); + Common::Uuid fileId{params.get_file_id()}; + auto fileMeta = _database->GetFileMeta(callerId.userId, fileId); + nlohmann::json metadata = nlohmann::json::parse(fileMeta.metadata); + Schema::IServer::GetFileMetaResult result{}; + result.set_content_id(fileMeta.contentId); + result.set_file_metadata(std::move(metadata)); + co_return static_cast(result); +} + +JS::Promise Service::OnGetFileContentAsync(CallerId callerId, nlohmann::json paramsJson) +{ + auto params = ParseParams(paramsJson); + auto content = _database->GetFileContent(callerId.userId, params.get_content_id()); + auto contentBase64 = Common::Base64::Encode(content); + Schema::IServer::GetFileContentResult result{}; + result.set_content_base64(std::move(contentBase64)); + co_return static_cast(result); +} + +JS::Promise Service::OnDeleteFileAsync(CallerId callerId, nlohmann::json paramsJson) +{ + auto params = ParseParams(paramsJson); + auto fileListLock = _resourceVersionManager->GetWriteLock( + {"fileList", static_cast(callerId.userId)}, callerId); + Common::Uuid fileId{params.get_file_id()}; + co_await _database->DeleteFileAsync(callerId.userId, fileId); + /** Return null */ + co_return nlohmann::json{}; +} + +JS::Promise Service::OnListFileAsync(CallerId callerId, nlohmann::json paramsJson) +{ + (void)paramsJson; + auto lock = _resourceVersionManager->GetReadLock( + {"fileList", static_cast(callerId.userId)}, callerId); + auto list = _database->ListFileMeta(callerId.userId); + Schema::IServer::ListFileResult result{}; + result.reserve(list.size()); + for (const auto& item : list) + { + using EntryType = decltype(result)::value_type; + EntryType entry{}; + entry.set_file_id(static_cast(item.fileId)); + nlohmann::json metadata = nlohmann::json::parse(item.metadata); + entry.set_file_metadata(std::move(metadata)); + result.push_back(std::move(entry)); + } + co_return static_cast(result); +} + void Service::OnNewConnection(CallerId callerId) { /** diff --git a/src/application/Service.h b/src/application/Service.h index 9091470..d06d86f 100644 --- a/src/application/Service.h +++ b/src/application/Service.h @@ -64,6 +64,11 @@ namespace TUI::Application JS::Promise OnGetUserAdminSettingsAsync(CallerId callerId, nlohmann::json params); JS::Promise OnSetUserAdminSettingsAsync(CallerId callerId, nlohmann::json params); JS::Promise OnSetUserCredentialAsync(CallerId callerId, nlohmann::json params); + JS::Promise OnPutFileAsync(CallerId callerId, nlohmann::json params); + JS::Promise OnGetFileMetaAsync(CallerId callerId, nlohmann::json params); + JS::Promise OnGetFileContentAsync(CallerId callerId, nlohmann::json params); + JS::Promise OnDeleteFileAsync(CallerId callerId, nlohmann::json params); + JS::Promise OnListFileAsync(CallerId callerId, nlohmann::json params); /** Connection handlers */ void OnNewConnection(CallerId callerId); From e35fc8fe2c0d09c981dceccfee2a6167c2c3162e Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 12:08:04 +0800 Subject: [PATCH 03/11] manual clean up generated type --- src/schema/IServer.h | 628 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 499 insertions(+), 129 deletions(-) diff --git a/src/schema/IServer.h b/src/schema/IServer.h index c5ea357..e7e0640 100644 --- a/src/schema/IServer.h +++ b/src/schema/IServer.h @@ -8,6 +8,10 @@ // GetMetadataParams data = nlohmann::json::parse(jsonString); // GetMetadataResult data = nlohmann::json::parse(jsonString); // DeleteMetadataParams data = nlohmann::json::parse(jsonString); +// MessageContent data = nlohmann::json::parse(jsonString); +// ChatMessage data = nlohmann::json::parse(jsonString); +// FunctionCallMessage data = nlohmann::json::parse(jsonString); +// FunctionCallOutputMessage data = nlohmann::json::parse(jsonString); // Message data = nlohmann::json::parse(jsonString); // MessageNode data = nlohmann::json::parse(jsonString); // LinearHistory data = nlohmann::json::parse(jsonString); @@ -15,6 +19,7 @@ // GetChatListParams data = nlohmann::json::parse(jsonString); // GetChatListResult data = nlohmann::json::parse(jsonString); // ChatCompletionParams data = nlohmann::json::parse(jsonString); +// ChatCompletionSegment data = nlohmann::json::parse(jsonString); // ChatCompletionInfo data = nlohmann::json::parse(jsonString); // ExecuteGenerationTaskParams data = nlohmann::json::parse(jsonString); // GetModelListParams data = nlohmann::json::parse(jsonString); @@ -41,6 +46,7 @@ #pragma once #include +#include #include #include @@ -192,7 +198,7 @@ namespace IServer { void set_path(const std::vector & value) { this->path = value; } }; - enum class Type : int { IMAGE_URL, REFUSAL, TEXT }; + enum class MessageContentType : int { IMAGE_URL, REFUSAL, TEXT }; class MessageContent { public: @@ -201,37 +207,159 @@ namespace IServer { private: std::string data; - Type type; + MessageContentType type; public: const std::string & get_data() const { return data; } std::string & get_mutable_data() { return data; } void set_data(const std::string & value) { this->data = value; } - const Type & get_type() const { return type; } - Type & get_mutable_type() { return type; } - void set_type(const Type & value) { this->type = value; } + const MessageContentType & get_type() const { return type; } + MessageContentType & get_mutable_type() { return type; } + void set_type(const MessageContentType & value) { this->type = value; } }; - enum class MessageRole : int { ASSISTANT, DEVELOPER, USER }; + enum class ChatMessageRole : int { ASSISTANT, DEVELOPER, USER }; - class Message { + class ChatMessage { public: - Message() = default; - virtual ~Message() = default; + ChatMessage() = default; + virtual ~ChatMessage() = default; private: std::vector content; - MessageRole role; + ChatMessageRole role; public: const std::vector & get_content() const { return content; } std::vector & get_mutable_content() { return content; } void set_content(const std::vector & value) { this->content = value; } - const MessageRole & get_role() const { return role; } - MessageRole & get_mutable_role() { return role; } - void set_role(const MessageRole & value) { this->role = value; } + const ChatMessageRole & get_role() const { return role; } + ChatMessageRole & get_mutable_role() { return role; } + void set_role(const ChatMessageRole & value) { this->role = value; } + }; + + enum class FunctionCallMessageType : int { FUNCTION_CALL }; + + class FunctionCallMessage { + public: + FunctionCallMessage() = default; + virtual ~FunctionCallMessage() = default; + + private: + std::string arguments; + std::string call_id; + nlohmann::json extra; + std::string name; + FunctionCallMessageType type; + + public: + const std::string & get_arguments() const { return arguments; } + std::string & get_mutable_arguments() { return arguments; } + void set_arguments(const std::string & value) { this->arguments = value; } + + const std::string & get_call_id() const { return call_id; } + std::string & get_mutable_call_id() { return call_id; } + void set_call_id(const std::string & value) { this->call_id = value; } + + /** + * Provider specific trash required for the message but not the logic. + */ + const nlohmann::json & get_extra() const { return extra; } + nlohmann::json & get_mutable_extra() { return extra; } + void set_extra(const nlohmann::json & value) { this->extra = value; } + + const std::string & get_name() const { return name; } + std::string & get_mutable_name() { return name; } + void set_name(const std::string & value) { this->name = value; } + + const FunctionCallMessageType & get_type() const { return type; } + FunctionCallMessageType & get_mutable_type() { return type; } + void set_type(const FunctionCallMessageType & value) { this->type = value; } + }; + + enum class FunctionCallOutputMessageType : int { FUNCTION_CALL_OUTPUT }; + + class FunctionCallOutputMessage { + public: + FunctionCallOutputMessage() = default; + virtual ~FunctionCallOutputMessage() = default; + + private: + std::string call_id; + nlohmann::json extra; + std::vector output; + FunctionCallOutputMessageType type; + + public: + const std::string & get_call_id() const { return call_id; } + std::string & get_mutable_call_id() { return call_id; } + void set_call_id(const std::string & value) { this->call_id = value; } + + /** + * Provider specific trash. + */ + const nlohmann::json & get_extra() const { return extra; } + nlohmann::json & get_mutable_extra() { return extra; } + void set_extra(const nlohmann::json & value) { this->extra = value; } + + const std::vector & get_output() const { return output; } + std::vector & get_mutable_output() { return output; } + void set_output(const std::vector & value) { this->output = value; } + + const FunctionCallOutputMessageType & get_type() const { return type; } + FunctionCallOutputMessageType & get_mutable_type() { return type; } + void set_type(const FunctionCallOutputMessageType & value) { this->type = value; } + }; + + enum class MessageType : int { FUNCTION_CALL, FUNCTION_CALL_OUTPUT }; + + class Message { + public: + Message() = default; + virtual ~Message() = default; + + private: + std::optional> content; + std::optional role; + std::optional arguments; + std::optional call_id; + nlohmann::json extra; + std::optional name; + std::optional type; + std::optional> output; + + public: + std::optional> get_content() const { return content; } + void set_content(std::optional> value) { this->content = value; } + + std::optional get_role() const { return role; } + void set_role(std::optional value) { this->role = value; } + + std::optional get_arguments() const { return arguments; } + void set_arguments(std::optional value) { this->arguments = value; } + + std::optional get_call_id() const { return call_id; } + void set_call_id(std::optional value) { this->call_id = value; } + + /** + * Provider specific trash required for the message but not the logic. + * + * Provider specific trash. + */ + const nlohmann::json & get_extra() const { return extra; } + nlohmann::json & get_mutable_extra() { return extra; } + void set_extra(const nlohmann::json & value) { this->extra = value; } + + std::optional get_name() const { return name; } + void set_name(std::optional value) { this->name = value; } + + std::optional get_type() const { return type; } + void set_type(std::optional value) { this->type = value; } + + std::optional> get_output() const { return output; } + void set_output(std::optional> value) { this->output = value; } }; class MessageNode { @@ -332,15 +460,19 @@ namespace IServer { private: std::string id; + std::vector messages; std::string model_id; std::optional parent; - Message user_message; public: const std::string & get_id() const { return id; } std::string & get_mutable_id() { return id; } void set_id(const std::string & value) { this->id = value; } + const std::vector & get_messages() const { return messages; } + std::vector & get_mutable_messages() { return messages; } + void set_messages(const std::vector & value) { this->messages = value; } + const std::string & get_model_id() const { return model_id; } std::string & get_mutable_model_id() { return model_id; } void set_model_id(const std::string & value) { this->model_id = value; } @@ -350,29 +482,45 @@ namespace IServer { */ std::optional get_parent() const { return parent; } void set_parent(std::optional value) { this->parent = value; } + }; + + enum class Event : int { FUNCTION_CALL_END, FUNCTION_CALL_START }; + + class ChatCompletionSegmentClass { + public: + ChatCompletionSegmentClass() = default; + virtual ~ChatCompletionSegmentClass() = default; - const Message & get_user_message() const { return user_message; } - Message & get_mutable_user_message() { return user_message; } - void set_user_message(const Message & value) { this->user_message = value; } + private: + Event event; + std::optional data; + + public: + const Event & get_event() const { return event; } + Event & get_mutable_event() { return event; } + void set_event(const Event & value) { this->event = value; } + + std::optional get_data() const { return data; } + void set_data(std::optional value) { this->data = value; } }; + using ChatCompletionSegment = std::variant; + class ChatCompletionInfo { public: ChatCompletionInfo() = default; virtual ~ChatCompletionInfo() = default; private: - std::string assistant_message_id; - std::string user_message_id; + std::vector message_ids; public: - const std::string & get_assistant_message_id() const { return assistant_message_id; } - std::string & get_mutable_assistant_message_id() { return assistant_message_id; } - void set_assistant_message_id(const std::string & value) { this->assistant_message_id = value; } - - const std::string & get_user_message_id() const { return user_message_id; } - std::string & get_mutable_user_message_id() { return user_message_id; } - void set_user_message_id(const std::string & value) { this->user_message_id = value; } + /** + * Message ids for the new requests and responses. + */ + const std::vector & get_message_ids() const { return message_ids; } + std::vector & get_mutable_message_ids() { return message_ids; } + void set_message_ids(const std::vector & value) { this->message_ids = value; } }; class ExecuteGenerationTaskParams { @@ -381,13 +529,13 @@ namespace IServer { virtual ~ExecuteGenerationTaskParams() = default; private: - Message message; + ChatMessage message; std::string model_id; public: - const Message & get_message() const { return message; } - Message & get_mutable_message() { return message; } - void set_message(const Message & value) { this->message = value; } + const ChatMessage & get_message() const { return message; } + ChatMessage & get_mutable_message() { return message; } + void set_message(const ChatMessage & value) { this->message = value; } const std::string & get_model_id() const { return model_id; } std::string & get_mutable_model_id() { return model_id; } @@ -518,15 +666,15 @@ namespace IServer { virtual ~GetUserListParams() = default; private: - std::optional> public_metadata_keys; std::optional> admin_metadata_keys; + std::optional> public_metadata_keys; public: - std::optional> get_public_metadata_keys() const { return public_metadata_keys; } - void set_public_metadata_keys(std::optional> value) { this->public_metadata_keys = value; } - std::optional> get_admin_metadata_keys() const { return admin_metadata_keys; } void set_admin_metadata_keys(std::optional> value) { this->admin_metadata_keys = value; } + + std::optional> get_public_metadata_keys() const { return public_metadata_keys; } + void set_public_metadata_keys(std::optional> value) { this->public_metadata_keys = value; } }; class GetUserListResultElement { @@ -535,14 +683,17 @@ namespace IServer { virtual ~GetUserListResultElement() = default; private: + std::optional> admin_metadata; UserAdminSettings admin_settings; std::string id; std::optional is_self; std::optional> public_metadata; - std::optional> admin_metadata; std::string user_name; public: + std::optional> get_admin_metadata() const { return admin_metadata; } + void set_admin_metadata(std::optional> value) { this->admin_metadata = value; } + const UserAdminSettings & get_admin_settings() const { return admin_settings; } UserAdminSettings & get_mutable_admin_settings() { return admin_settings; } void set_admin_settings(const UserAdminSettings & value) { this->admin_settings = value; } @@ -557,9 +708,6 @@ namespace IServer { std::optional> get_public_metadata() const { return public_metadata; } void set_public_metadata(std::optional> value) { this->public_metadata = value; } - std::optional> get_admin_metadata() const { return admin_metadata; } - void set_admin_metadata(std::optional> value) { this->admin_metadata = value; } - const std::string & get_user_name() const { return user_name; } std::string & get_mutable_user_name() { return user_name; } void set_user_name(const std::string & value) { this->user_name = value; } @@ -611,6 +759,9 @@ namespace IServer { void set_id(const std::string & value) { this->id = value; } }; + /** + * These two messages are not exposed to the application layer. + */ class ProtocolNegotiationRequest { public: ProtocolNegotiationRequest() = default; @@ -802,87 +953,147 @@ namespace IServer { namespace TUI { namespace Schema { namespace IServer { - void from_json(const json & j, SetMetadataParams & x); - void to_json(json & j, const SetMetadataParams & x); +void from_json(const json & j, SetMetadataParams & x); +void to_json(json & j, const SetMetadataParams & x); + +void from_json(const json & j, GetMetadataParams & x); +void to_json(json & j, const GetMetadataParams & x); + +void from_json(const json & j, DeleteMetadataParams & x); +void to_json(json & j, const DeleteMetadataParams & x); + +void from_json(const json & j, MessageContent & x); +void to_json(json & j, const MessageContent & x); + +void from_json(const json & j, ChatMessage & x); +void to_json(json & j, const ChatMessage & x); + +void from_json(const json & j, FunctionCallMessage & x); +void to_json(json & j, const FunctionCallMessage & x); + +void from_json(const json & j, FunctionCallOutputMessage & x); +void to_json(json & j, const FunctionCallOutputMessage & x); + +void from_json(const json & j, Message & x); +void to_json(json & j, const Message & x); + +void from_json(const json & j, MessageNode & x); +void to_json(json & j, const MessageNode & x); + +void from_json(const json & j, TreeHistory & x); +void to_json(json & j, const TreeHistory & x); + +void from_json(const json & j, GetChatListParams & x); +void to_json(json & j, const GetChatListParams & x); + +void from_json(const json & j, GetChatListResultElement & x); +void to_json(json & j, const GetChatListResultElement & x); + +void from_json(const json & j, ChatCompletionParams & x); +void to_json(json & j, const ChatCompletionParams & x); - void from_json(const json & j, GetMetadataParams & x); - void to_json(json & j, const GetMetadataParams & x); +void from_json(const json & j, ChatCompletionSegmentClass & x); +void to_json(json & j, const ChatCompletionSegmentClass & x); - void from_json(const json & j, DeleteMetadataParams & x); - void to_json(json & j, const DeleteMetadataParams & x); +void from_json(const json & j, ChatCompletionInfo & x); +void to_json(json & j, const ChatCompletionInfo & x); - void from_json(const json & j, MessageContent & x); - void to_json(json & j, const MessageContent & x); +void from_json(const json & j, ExecuteGenerationTaskParams & x); +void to_json(json & j, const ExecuteGenerationTaskParams & x); - void from_json(const json & j, Message & x); - void to_json(json & j, const Message & x); +void from_json(const json & j, GetModelListParams & x); +void to_json(json & j, const GetModelListParams & x); - void from_json(const json & j, MessageNode & x); - void to_json(json & j, const MessageNode & x); +void from_json(const json & j, GetModelListResultElement & x); +void to_json(json & j, const GetModelListResultElement & x); - void from_json(const json & j, TreeHistory & x); - void to_json(json & j, const TreeHistory & x); +void from_json(const json & j, ModelSettings & x); +void to_json(json & j, const ModelSettings & x); - void from_json(const json & j, GetChatListParams & x); - void to_json(json & j, const GetChatListParams & x); +void from_json(const json & j, ModifyModelSettingsParams & x); +void to_json(json & j, const ModifyModelSettingsParams & x); - void from_json(const json & j, GetChatListResultElement & x); - void to_json(json & j, const GetChatListResultElement & x); +void from_json(const json & j, UserAdminSettings & x); +void to_json(json & j, const UserAdminSettings & x); - void from_json(const json & j, ChatCompletionParams & x); - void to_json(json & j, const ChatCompletionParams & x); +void from_json(const json & j, UserCredential & x); +void to_json(json & j, const UserCredential & x); - void from_json(const json & j, ChatCompletionInfo & x); - void to_json(json & j, const ChatCompletionInfo & x); +void from_json(const json & j, GetUserListParams & x); +void to_json(json & j, const GetUserListParams & x); - void from_json(const json & j, ExecuteGenerationTaskParams & x); - void to_json(json & j, const ExecuteGenerationTaskParams & x); +void from_json(const json & j, GetUserListResultElement & x); +void to_json(json & j, const GetUserListResultElement & x); - void from_json(const json & j, GetModelListParams & x); - void to_json(json & j, const GetModelListParams & x); +void from_json(const json & j, NewUserParams & x); +void to_json(json & j, const NewUserParams & x); - void from_json(const json & j, GetModelListResultElement & x); - void to_json(json & j, const GetModelListResultElement & x); +void from_json(const json & j, SetUserAdminSettingsParams & x); +void to_json(json & j, const SetUserAdminSettingsParams & x); - void from_json(const json & j, ModelSettings & x); - void to_json(json & j, const ModelSettings & x); +void from_json(const json & j, ProtocolNegotiationRequest & x); +void to_json(json & j, const ProtocolNegotiationRequest & x); - void from_json(const json & j, ModifyModelSettingsParams & x); - void to_json(json & j, const ModifyModelSettingsParams & x); +void from_json(const json & j, ProtocolNegotiationResponse & x); +void to_json(json & j, const ProtocolNegotiationResponse & x); - void from_json(const json & j, UserAdminSettings & x); - void to_json(json & j, const UserAdminSettings & x); +void from_json(const json & j, PutFileParams & x); +void to_json(json & j, const PutFileParams & x); - void from_json(const json & j, UserCredential & x); - void to_json(json & j, const UserCredential & x); +void from_json(const json & j, PutFileResult & x); +void to_json(json & j, const PutFileResult & x); - void from_json(const json & j, GetUserListParams & x); - void to_json(json & j, const GetUserListParams & x); +void from_json(const json & j, GetFileMetaParams & x); +void to_json(json & j, const GetFileMetaParams & x); - void from_json(const json & j, GetUserListResultElement & x); - void to_json(json & j, const GetUserListResultElement & x); +void from_json(const json & j, GetFileMetaResult & x); +void to_json(json & j, const GetFileMetaResult & x); - void from_json(const json & j, NewUserParams & x); - void to_json(json & j, const NewUserParams & x); +void from_json(const json & j, GetFileContentParams & x); +void to_json(json & j, const GetFileContentParams & x); - void from_json(const json & j, SetUserAdminSettingsParams & x); - void to_json(json & j, const SetUserAdminSettingsParams & x); +void from_json(const json & j, GetFileContentResult & x); +void to_json(json & j, const GetFileContentResult & x); - void from_json(const json & j, ProtocolNegotiationRequest & x); - void to_json(json & j, const ProtocolNegotiationRequest & x); +void from_json(const json & j, DeleteFileParams & x); +void to_json(json & j, const DeleteFileParams & x); - void from_json(const json & j, ProtocolNegotiationResponse & x); - void to_json(json & j, const ProtocolNegotiationResponse & x); +void from_json(const json & j, ListFileResultElement & x); +void to_json(json & j, const ListFileResultElement & x); - void from_json(const json & j, Type & x); - void to_json(json & j, const Type & x); +void from_json(const json & j, MessageContentType & x); +void to_json(json & j, const MessageContentType & x); - void from_json(const json & j, MessageRole & x); - void to_json(json & j, const MessageRole & x); +void from_json(const json & j, ChatMessageRole & x); +void to_json(json & j, const ChatMessageRole & x); - void from_json(const json & j, UserAdminSettingsRole & x); - void to_json(json & j, const UserAdminSettingsRole & x); +void from_json(const json & j, FunctionCallMessageType & x); +void to_json(json & j, const FunctionCallMessageType & x); +void from_json(const json & j, FunctionCallOutputMessageType & x); +void to_json(json & j, const FunctionCallOutputMessageType & x); + +void from_json(const json & j, MessageType & x); +void to_json(json & j, const MessageType & x); + +void from_json(const json & j, Event & x); +void to_json(json & j, const Event & x); + +void from_json(const json & j, UserAdminSettingsRole & x); +void to_json(json & j, const UserAdminSettingsRole & x); +} +} +} +namespace nlohmann { +template <> +struct adl_serializer> { + static void from_json(const json & j, std::variant & x); + static void to_json(json & j, const std::variant & x); +}; +} +namespace TUI { +namespace Schema { +namespace IServer { inline void from_json(const json & j, SetMetadataParams& x) { x.set_entries(j.at("entries").get>()); x.set_path(j.at("path").get>()); @@ -918,7 +1129,7 @@ namespace IServer { inline void from_json(const json & j, MessageContent& x) { x.set_data(j.at("data").get()); - x.set_type(j.at("type").get()); + x.set_type(j.at("type").get()); } inline void to_json(json & j, const MessageContent & x) { @@ -927,17 +1138,92 @@ namespace IServer { j["type"] = x.get_type(); } - inline void from_json(const json & j, Message& x) { + inline void from_json(const json & j, ChatMessage& x) { x.set_content(j.at("content").get>()); - x.set_role(j.at("role").get()); + x.set_role(j.at("role").get()); } - inline void to_json(json & j, const Message & x) { + inline void to_json(json & j, const ChatMessage & x) { j = json::object(); j["content"] = x.get_content(); j["role"] = x.get_role(); } + inline void from_json(const json & j, FunctionCallMessage& x) { + x.set_arguments(j.at("arguments").get()); + x.set_call_id(j.at("call_id").get()); + x.set_extra(get_untyped(j, "extra")); + x.set_name(j.at("name").get()); + x.set_type(j.at("type").get()); + } + + inline void to_json(json & j, const FunctionCallMessage & x) { + j = json::object(); + j["arguments"] = x.get_arguments(); + j["call_id"] = x.get_call_id(); + if (x.get_extra()) { + j["extra"] = x.get_extra(); + } + j["name"] = x.get_name(); + j["type"] = x.get_type(); + } + + inline void from_json(const json & j, FunctionCallOutputMessage& x) { + x.set_call_id(j.at("call_id").get()); + x.set_extra(get_untyped(j, "extra")); + x.set_output(j.at("output").get>()); + x.set_type(j.at("type").get()); + } + + inline void to_json(json & j, const FunctionCallOutputMessage & x) { + j = json::object(); + j["call_id"] = x.get_call_id(); + if (x.get_extra()) { + j["extra"] = x.get_extra(); + } + j["output"] = x.get_output(); + j["type"] = x.get_type(); + } + + inline void from_json(const json & j, Message& x) { + x.set_content(get_stack_optional>(j, "content")); + x.set_role(get_stack_optional(j, "role")); + x.set_arguments(get_stack_optional(j, "arguments")); + x.set_call_id(get_stack_optional(j, "call_id")); + x.set_extra(get_untyped(j, "extra")); + x.set_name(get_stack_optional(j, "name")); + x.set_type(get_stack_optional(j, "type")); + x.set_output(get_stack_optional>(j, "output")); + } + + inline void to_json(json & j, const Message & x) { + j = json::object(); + if (x.get_content()) { + j["content"] = x.get_content(); + } + if (x.get_role()) { + j["role"] = x.get_role(); + } + if (x.get_arguments()) { + j["arguments"] = x.get_arguments(); + } + if (x.get_call_id()) { + j["call_id"] = x.get_call_id(); + } + if (x.get_extra()) { + j["extra"] = x.get_extra(); + } + if (x.get_name()) { + j["name"] = x.get_name(); + } + if (x.get_type()) { + j["type"] = x.get_type(); + } + if (x.get_output()) { + j["output"] = x.get_output(); + } + } + inline void from_json(const json & j, MessageNode& x) { x.set_children(j.at("children").get>()); x.set_id(j.at("id").get()); @@ -996,34 +1282,45 @@ namespace IServer { inline void from_json(const json & j, ChatCompletionParams& x) { x.set_id(j.at("id").get()); + x.set_messages(j.at("messages").get>()); x.set_model_id(j.at("modelId").get()); x.set_parent(get_stack_optional(j, "parent")); - x.set_user_message(j.at("userMessage").get()); } inline void to_json(json & j, const ChatCompletionParams & x) { j = json::object(); j["id"] = x.get_id(); + j["messages"] = x.get_messages(); j["modelId"] = x.get_model_id(); if (x.get_parent()) { j["parent"] = x.get_parent(); } - j["userMessage"] = x.get_user_message(); + } + + inline void from_json(const json & j, ChatCompletionSegmentClass& x) { + x.set_event(j.at("event").get()); + x.set_data(get_stack_optional(j, "data")); + } + + inline void to_json(json & j, const ChatCompletionSegmentClass & x) { + j = json::object(); + j["event"] = x.get_event(); + if (x.get_data()) { + j["data"] = x.get_data(); + } } inline void from_json(const json & j, ChatCompletionInfo& x) { - x.set_assistant_message_id(j.at("assistantMessageId").get()); - x.set_user_message_id(j.at("userMessageId").get()); + x.set_message_ids(j.at("messageIds").get>()); } inline void to_json(json & j, const ChatCompletionInfo & x) { j = json::object(); - j["assistantMessageId"] = x.get_assistant_message_id(); - j["userMessageId"] = x.get_user_message_id(); + j["messageIds"] = x.get_message_ids(); } inline void from_json(const json & j, ExecuteGenerationTaskParams& x) { - x.set_message(j.at("message").get()); + x.set_message(j.at("message").get()); x.set_model_id(j.at("modelId").get()); } @@ -1102,31 +1399,34 @@ namespace IServer { } inline void from_json(const json & j, GetUserListParams& x) { - x.set_public_metadata_keys(get_stack_optional>(j, "publicMetadataKeys")); x.set_admin_metadata_keys(get_stack_optional>(j, "adminMetadataKeys")); + x.set_public_metadata_keys(get_stack_optional>(j, "publicMetadataKeys")); } inline void to_json(json & j, const GetUserListParams & x) { j = json::object(); - if (x.get_public_metadata_keys()) { - j["publicMetadataKeys"] = x.get_public_metadata_keys(); - } if (x.get_admin_metadata_keys()) { j["adminMetadataKeys"] = x.get_admin_metadata_keys(); } + if (x.get_public_metadata_keys()) { + j["publicMetadataKeys"] = x.get_public_metadata_keys(); + } } inline void from_json(const json & j, GetUserListResultElement& x) { + x.set_admin_metadata(get_stack_optional>(j, "adminMetadata")); x.set_admin_settings(j.at("adminSettings").get()); x.set_id(j.at("id").get()); x.set_is_self(get_stack_optional(j, "isSelf")); x.set_public_metadata(get_stack_optional>(j, "publicMetadata")); - x.set_admin_metadata(get_stack_optional>(j, "adminMetadata")); x.set_user_name(j.at("userName").get()); } inline void to_json(json & j, const GetUserListResultElement & x) { j = json::object(); + if (x.get_admin_metadata()) { + j["adminMetadata"] = x.get_admin_metadata(); + } j["adminSettings"] = x.get_admin_settings(); j["id"] = x.get_id(); if (x.get_is_self()) { @@ -1135,9 +1435,6 @@ namespace IServer { if (x.get_public_metadata()) { j["publicMetadata"] = x.get_public_metadata(); } - if (x.get_admin_metadata()) { - j["adminMetadata"] = x.get_admin_metadata(); - } j["userName"] = x.get_user_name(); } @@ -1269,35 +1566,87 @@ namespace IServer { j["fileMetadata"] = x.get_file_metadata(); } - inline void from_json(const json & j, Type & x) { - if (j == "image_url") x = Type::IMAGE_URL; - else if (j == "refusal") x = Type::REFUSAL; - else if (j == "text") x = Type::TEXT; + inline void from_json(const json & j, MessageContentType & x) { + if (j == "image_url") x = MessageContentType::IMAGE_URL; + else if (j == "refusal") x = MessageContentType::REFUSAL; + else if (j == "text") x = MessageContentType::TEXT; + else { throw std::runtime_error("Input JSON does not conform to schema!"); } + } + + inline void to_json(json & j, const MessageContentType & x) { + switch (x) { + case MessageContentType::IMAGE_URL: j = "image_url"; break; + case MessageContentType::REFUSAL: j = "refusal"; break; + case MessageContentType::TEXT: j = "text"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"MessageContentType\": " + std::to_string(static_cast(x))); + } + } + + inline void from_json(const json & j, ChatMessageRole & x) { + if (j == "assistant") x = ChatMessageRole::ASSISTANT; + else if (j == "developer") x = ChatMessageRole::DEVELOPER; + else if (j == "user") x = ChatMessageRole::USER; + else { throw std::runtime_error("Input JSON does not conform to schema!"); } + } + + inline void to_json(json & j, const ChatMessageRole & x) { + switch (x) { + case ChatMessageRole::ASSISTANT: j = "assistant"; break; + case ChatMessageRole::DEVELOPER: j = "developer"; break; + case ChatMessageRole::USER: j = "user"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"ChatMessageRole\": " + std::to_string(static_cast(x))); + } + } + + inline void from_json(const json & j, FunctionCallMessageType & x) { + if (j == "function_call") x = FunctionCallMessageType::FUNCTION_CALL; + else { throw std::runtime_error("Input JSON does not conform to schema!"); } + } + + inline void to_json(json & j, const FunctionCallMessageType & x) { + switch (x) { + case FunctionCallMessageType::FUNCTION_CALL: j = "function_call"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"FunctionCallMessageType\": " + std::to_string(static_cast(x))); + } + } + + inline void from_json(const json & j, FunctionCallOutputMessageType & x) { + if (j == "function_call_output") x = FunctionCallOutputMessageType::FUNCTION_CALL_OUTPUT; + else { throw std::runtime_error("Input JSON does not conform to schema!"); } + } + + inline void to_json(json & j, const FunctionCallOutputMessageType & x) { + switch (x) { + case FunctionCallOutputMessageType::FUNCTION_CALL_OUTPUT: j = "function_call_output"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"FunctionCallOutputMessageType\": " + std::to_string(static_cast(x))); + } + } + + inline void from_json(const json & j, MessageType & x) { + if (j == "function_call") x = MessageType::FUNCTION_CALL; + else if (j == "function_call_output") x = MessageType::FUNCTION_CALL_OUTPUT; else { throw std::runtime_error("Input JSON does not conform to schema!"); } } - inline void to_json(json & j, const Type & x) { + inline void to_json(json & j, const MessageType & x) { switch (x) { - case Type::IMAGE_URL: j = "image_url"; break; - case Type::REFUSAL: j = "refusal"; break; - case Type::TEXT: j = "text"; break; - default: throw std::runtime_error("Unexpected value in enumeration \"Type\": " + std::to_string(static_cast(x))); + case MessageType::FUNCTION_CALL: j = "function_call"; break; + case MessageType::FUNCTION_CALL_OUTPUT: j = "function_call_output"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"MessageType\": " + std::to_string(static_cast(x))); } } - inline void from_json(const json & j, MessageRole & x) { - if (j == "assistant") x = MessageRole::ASSISTANT; - else if (j == "developer") x = MessageRole::DEVELOPER; - else if (j == "user") x = MessageRole::USER; + inline void from_json(const json & j, Event & x) { + if (j == "function_call_end") x = Event::FUNCTION_CALL_END; + else if (j == "function_call_start") x = Event::FUNCTION_CALL_START; else { throw std::runtime_error("Input JSON does not conform to schema!"); } } - inline void to_json(json & j, const MessageRole & x) { + inline void to_json(json & j, const Event & x) { switch (x) { - case MessageRole::ASSISTANT: j = "assistant"; break; - case MessageRole::DEVELOPER: j = "developer"; break; - case MessageRole::USER: j = "user"; break; - default: throw std::runtime_error("Unexpected value in enumeration \"MessageRole\": " + std::to_string(static_cast(x))); + case Event::FUNCTION_CALL_END: j = "function_call_end"; break; + case Event::FUNCTION_CALL_START: j = "function_call_start"; break; + default: throw std::runtime_error("Unexpected value in enumeration \"Event\": " + std::to_string(static_cast(x))); } } @@ -1317,3 +1666,24 @@ namespace IServer { } } } +namespace nlohmann { + inline void adl_serializer>::from_json(const json & j, std::variant & x) { + if (j.is_string()) + x = j.get(); + else if (j.is_object()) + x = j.get(); + else throw std::runtime_error("Could not deserialise!"); + } + + inline void adl_serializer>::to_json(json & j, const std::variant & x) { + switch (x.index()) { + case 0: + j = std::get(x); + break; + case 1: + j = std::get(x); + break; + default: throw std::runtime_error("Input JSON does not conform to schema!"); + } + } +} From 9c66cf574424fea71e16316af0150afd8e84aa38 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 12:29:16 +0800 Subject: [PATCH 04/11] use variant for the new message --- src/schema/IServer.h | 156 ++++---------- test/CMakeLists.txt | 7 + test/TestIServerSchema.cpp | 414 +++++++++++++++++++++++++++++++++++++ 3 files changed, 460 insertions(+), 117 deletions(-) create mode 100644 test/TestIServerSchema.cpp diff --git a/src/schema/IServer.h b/src/schema/IServer.h index e7e0640..7770e37 100644 --- a/src/schema/IServer.h +++ b/src/schema/IServer.h @@ -250,7 +250,7 @@ namespace IServer { private: std::string arguments; std::string call_id; - nlohmann::json extra; + std::optional extra; std::string name; FunctionCallMessageType type; @@ -266,9 +266,8 @@ namespace IServer { /** * Provider specific trash required for the message but not the logic. */ - const nlohmann::json & get_extra() const { return extra; } - nlohmann::json & get_mutable_extra() { return extra; } - void set_extra(const nlohmann::json & value) { this->extra = value; } + std::optional get_extra() const { return extra; } + void set_extra(std::optional value) { this->extra = value; } const std::string & get_name() const { return name; } std::string & get_mutable_name() { return name; } @@ -288,7 +287,7 @@ namespace IServer { private: std::string call_id; - nlohmann::json extra; + std::optional extra; std::vector output; FunctionCallOutputMessageType type; @@ -300,9 +299,8 @@ namespace IServer { /** * Provider specific trash. */ - const nlohmann::json & get_extra() const { return extra; } - nlohmann::json & get_mutable_extra() { return extra; } - void set_extra(const nlohmann::json & value) { this->extra = value; } + std::optional get_extra() const { return extra; } + void set_extra(std::optional value) { this->extra = value; } const std::vector & get_output() const { return output; } std::vector & get_mutable_output() { return output; } @@ -313,54 +311,7 @@ namespace IServer { void set_type(const FunctionCallOutputMessageType & value) { this->type = value; } }; - enum class MessageType : int { FUNCTION_CALL, FUNCTION_CALL_OUTPUT }; - - class Message { - public: - Message() = default; - virtual ~Message() = default; - - private: - std::optional> content; - std::optional role; - std::optional arguments; - std::optional call_id; - nlohmann::json extra; - std::optional name; - std::optional type; - std::optional> output; - - public: - std::optional> get_content() const { return content; } - void set_content(std::optional> value) { this->content = value; } - - std::optional get_role() const { return role; } - void set_role(std::optional value) { this->role = value; } - - std::optional get_arguments() const { return arguments; } - void set_arguments(std::optional value) { this->arguments = value; } - - std::optional get_call_id() const { return call_id; } - void set_call_id(std::optional value) { this->call_id = value; } - - /** - * Provider specific trash required for the message but not the logic. - * - * Provider specific trash. - */ - const nlohmann::json & get_extra() const { return extra; } - nlohmann::json & get_mutable_extra() { return extra; } - void set_extra(const nlohmann::json & value) { this->extra = value; } - - std::optional get_name() const { return name; } - void set_name(std::optional value) { this->name = value; } - - std::optional get_type() const { return type; } - void set_type(std::optional value) { this->type = value; } - - std::optional> get_output() const { return output; } - void set_output(std::optional> value) { this->output = value; } - }; + using Message = std::variant; class MessageNode { public: @@ -974,9 +925,6 @@ void to_json(json & j, const FunctionCallMessage & x); void from_json(const json & j, FunctionCallOutputMessage & x); void to_json(json & j, const FunctionCallOutputMessage & x); -void from_json(const json & j, Message & x); -void to_json(json & j, const Message & x); - void from_json(const json & j, MessageNode & x); void to_json(json & j, const MessageNode & x); @@ -1073,9 +1021,6 @@ void to_json(json & j, const FunctionCallMessageType & x); void from_json(const json & j, FunctionCallOutputMessageType & x); void to_json(json & j, const FunctionCallOutputMessageType & x); -void from_json(const json & j, MessageType & x); -void to_json(json & j, const MessageType & x); - void from_json(const json & j, Event & x); void to_json(json & j, const Event & x); @@ -1086,6 +1031,11 @@ void to_json(json & j, const UserAdminSettingsRole & x); } namespace nlohmann { template <> +struct adl_serializer> { + static void from_json(const json & j, std::variant & x); + static void to_json(json & j, const std::variant & x); +}; +template <> struct adl_serializer> { static void from_json(const json & j, std::variant & x); static void to_json(json & j, const std::variant & x); @@ -1152,7 +1102,7 @@ namespace IServer { inline void from_json(const json & j, FunctionCallMessage& x) { x.set_arguments(j.at("arguments").get()); x.set_call_id(j.at("call_id").get()); - x.set_extra(get_untyped(j, "extra")); + x.set_extra(get_stack_optional(j, "extra")); x.set_name(j.at("name").get()); x.set_type(j.at("type").get()); } @@ -1170,7 +1120,7 @@ namespace IServer { inline void from_json(const json & j, FunctionCallOutputMessage& x) { x.set_call_id(j.at("call_id").get()); - x.set_extra(get_untyped(j, "extra")); + x.set_extra(get_stack_optional(j, "extra")); x.set_output(j.at("output").get>()); x.set_type(j.at("type").get()); } @@ -1185,45 +1135,6 @@ namespace IServer { j["type"] = x.get_type(); } - inline void from_json(const json & j, Message& x) { - x.set_content(get_stack_optional>(j, "content")); - x.set_role(get_stack_optional(j, "role")); - x.set_arguments(get_stack_optional(j, "arguments")); - x.set_call_id(get_stack_optional(j, "call_id")); - x.set_extra(get_untyped(j, "extra")); - x.set_name(get_stack_optional(j, "name")); - x.set_type(get_stack_optional(j, "type")); - x.set_output(get_stack_optional>(j, "output")); - } - - inline void to_json(json & j, const Message & x) { - j = json::object(); - if (x.get_content()) { - j["content"] = x.get_content(); - } - if (x.get_role()) { - j["role"] = x.get_role(); - } - if (x.get_arguments()) { - j["arguments"] = x.get_arguments(); - } - if (x.get_call_id()) { - j["call_id"] = x.get_call_id(); - } - if (x.get_extra()) { - j["extra"] = x.get_extra(); - } - if (x.get_name()) { - j["name"] = x.get_name(); - } - if (x.get_type()) { - j["type"] = x.get_type(); - } - if (x.get_output()) { - j["output"] = x.get_output(); - } - } - inline void from_json(const json & j, MessageNode& x) { x.set_children(j.at("children").get>()); x.set_id(j.at("id").get()); @@ -1622,20 +1533,6 @@ namespace IServer { } } - inline void from_json(const json & j, MessageType & x) { - if (j == "function_call") x = MessageType::FUNCTION_CALL; - else if (j == "function_call_output") x = MessageType::FUNCTION_CALL_OUTPUT; - else { throw std::runtime_error("Input JSON does not conform to schema!"); } - } - - inline void to_json(json & j, const MessageType & x) { - switch (x) { - case MessageType::FUNCTION_CALL: j = "function_call"; break; - case MessageType::FUNCTION_CALL_OUTPUT: j = "function_call_output"; break; - default: throw std::runtime_error("Unexpected value in enumeration \"MessageType\": " + std::to_string(static_cast(x))); - } - } - inline void from_json(const json & j, Event & x) { if (j == "function_call_end") x = Event::FUNCTION_CALL_END; else if (j == "function_call_start") x = Event::FUNCTION_CALL_START; @@ -1667,6 +1564,31 @@ namespace IServer { } } namespace nlohmann { + inline void adl_serializer>::from_json(const json & j, std::variant & x) { + if (j.contains("role")) + x = j.get(); + else if (j.contains("type") && j.at("type") == "function_call") + x = j.get(); + else if (j.contains("type") && j.at("type") == "function_call_output") + x = j.get(); + else throw std::runtime_error("Could not deserialise Message!"); + } + + inline void adl_serializer>::to_json(json & j, const std::variant & x) { + switch (x.index()) { + case 0: + j = std::get(x); + break; + case 1: + j = std::get(x); + break; + case 2: + j = std::get(x); + break; + default: throw std::runtime_error("Input JSON does not conform to schema!"); + } + } + inline void adl_serializer>::from_json(const json & j, std::variant & x) { if (j.is_string()) x = j.get(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7968388..0ac1a84 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -407,3 +407,10 @@ target_include_directories(TestZstdMessageCodec target_link_libraries(TestZstdMessageCodec PRIVATE zstd) + +add_executable(TestIServerSchema + TestIServerSchema.cpp) + +target_include_directories(TestIServerSchema + PRIVATE + ../src) diff --git a/test/TestIServerSchema.cpp b/test/TestIServerSchema.cpp new file mode 100644 index 0000000..89b7694 --- /dev/null +++ b/test/TestIServerSchema.cpp @@ -0,0 +1,414 @@ +#include +#include +#include +#include +#include "schema/IServer.h" +#include "Utility.h" + +using namespace TUI::Schema::IServer; +using json = nlohmann::json; + +static void TestChatMessageRoundTrip() +{ + json input = R"({ + "role": "user", + "content": [ + {"type": "text", "data": "Hello, world!"} + ] + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be ChatMessage"); + + auto& chat = std::get(msg); + AssertWithMessage(chat.get_role() == ChatMessageRole::USER, "Role should be user"); + AssertWithMessage(chat.get_content().size() == 1, "Should have 1 content item"); + AssertWithMessage(chat.get_content()[0].get_data() == "Hello, world!", "Content data mismatch"); + AssertWithMessage(chat.get_content()[0].get_type() == MessageContentType::TEXT, "Content type should be text"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestChatMessageAssistantRole() +{ + json input = R"({ + "role": "assistant", + "content": [ + {"type": "text", "data": "I can help with that."}, + {"type": "refusal", "data": "I cannot do that."} + ] + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be ChatMessage"); + + auto& chat = std::get(msg); + AssertWithMessage(chat.get_role() == ChatMessageRole::ASSISTANT, "Role should be assistant"); + AssertWithMessage(chat.get_content().size() == 2, "Should have 2 content items"); + AssertWithMessage(chat.get_content()[1].get_type() == MessageContentType::REFUSAL, "Second item should be refusal"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestFunctionCallMessageRoundTrip() +{ + json input = R"({ + "type": "function_call", + "call_id": "call_123", + "name": "get_weather", + "arguments": "{\"city\": \"London\"}" + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be FunctionCallMessage"); + + auto& fc = std::get(msg); + AssertWithMessage(fc.get_call_id() == "call_123", "call_id mismatch"); + AssertWithMessage(fc.get_name() == "get_weather", "name mismatch"); + AssertWithMessage(fc.get_arguments() == "{\"city\": \"London\"}", "arguments mismatch"); + AssertWithMessage(fc.get_type() == FunctionCallMessageType::FUNCTION_CALL, "type mismatch"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestFunctionCallMessageWithExtra() +{ + json input = R"({ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": "{}", + "extra": {"id": "fc_1", "status": "completed"} + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be FunctionCallMessage"); + + auto& fc = std::get(msg); + AssertWithMessage(fc.get_extra().has_value(), "extra should have a value"); + AssertWithMessage(fc.get_extra()->is_object(), "extra should be an object"); + AssertWithMessage((*fc.get_extra())["id"] == "fc_1", "extra id mismatch"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestFunctionCallOutputMessageRoundTrip() +{ + json input = R"({ + "type": "function_call_output", + "call_id": "call_123", + "output": [ + {"type": "text", "data": "Sunny, 22°C"} + ] + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be FunctionCallOutputMessage"); + + auto& fco = std::get(msg); + AssertWithMessage(fco.get_call_id() == "call_123", "call_id mismatch"); + AssertWithMessage(fco.get_output().size() == 1, "Should have 1 output item"); + AssertWithMessage(fco.get_output()[0].get_data() == "Sunny, 22°C", "Output data mismatch"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestFunctionCallOutputMessageWithExtra() +{ + json input = R"({ + "type": "function_call_output", + "call_id": "call_789", + "extra": {"provider_id": "x"}, + "output": [ + {"type": "text", "data": "result"} + ] + })"_json; + + Message msg = input.get(); + AssertWithMessage(std::holds_alternative(msg), "Should be FunctionCallOutputMessage"); + + auto& fco = std::get(msg); + AssertWithMessage(fco.get_extra().has_value(), "extra should have a value"); + AssertWithMessage(fco.get_extra()->is_object(), "extra should be an object"); + AssertWithMessage((*fco.get_extra())["provider_id"] == "x", "extra provider_id mismatch"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestMessageArrayRoundTrip() +{ + json input = R"([ + { + "role": "user", + "content": [{"type": "text", "data": "What is the weather?"}] + }, + { + "type": "function_call", + "call_id": "call_1", + "name": "get_weather", + "arguments": "{}" + }, + { + "type": "function_call_output", + "call_id": "call_1", + "output": [{"type": "text", "data": "Rainy"}] + }, + { + "role": "assistant", + "content": [{"type": "text", "data": "It is rainy."}] + } + ])"_json; + + auto messages = input.get>(); + AssertWithMessage(messages.size() == 4, "Should have 4 messages"); + AssertWithMessage(std::holds_alternative(messages[0]), "messages[0] should be ChatMessage"); + AssertWithMessage(std::holds_alternative(messages[1]), "messages[1] should be FunctionCallMessage"); + AssertWithMessage(std::holds_alternative(messages[2]), "messages[2] should be FunctionCallOutputMessage"); + AssertWithMessage(std::holds_alternative(messages[3]), "messages[3] should be ChatMessage"); + + json output = messages; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestLinearHistoryRoundTrip() +{ + json input = R"([ + { + "role": "developer", + "content": [{"type": "text", "data": "You are a helpful assistant."}] + }, + { + "role": "user", + "content": [{"type": "text", "data": "Hi"}] + } + ])"_json; + + LinearHistory history = input.get(); + AssertWithMessage(history.size() == 2, "Should have 2 messages"); + AssertWithMessage(std::holds_alternative(history[0]), "First should be ChatMessage"); + + auto& dev = std::get(history[0]); + AssertWithMessage(dev.get_role() == ChatMessageRole::DEVELOPER, "Role should be developer"); + + json output = history; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestMessageNodeRoundTrip() +{ + json input = R"({ + "children": ["child_1"], + "id": "node_1", + "message": { + "role": "user", + "content": [{"type": "text", "data": "Hello"}] + }, + "parent": "root", + "timestamp": 1234567890.0 + })"_json; + + MessageNode node = input.get(); + AssertWithMessage(node.get_id() == "node_1", "id mismatch"); + AssertWithMessage(node.get_children().size() == 1, "Should have 1 child"); + AssertWithMessage(node.get_parent().has_value(), "Should have parent"); + AssertWithMessage(node.get_parent().value() == "root", "parent mismatch"); + AssertWithMessage(std::holds_alternative(node.get_message()), "message should be ChatMessage"); + + json output = node; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestMessageNodeWithFunctionCall() +{ + json input = R"({ + "children": [], + "id": "node_2", + "message": { + "type": "function_call", + "call_id": "c1", + "name": "foo", + "arguments": "{}" + }, + "timestamp": 100.0 + })"_json; + + MessageNode node = input.get(); + AssertWithMessage(std::holds_alternative(node.get_message()), "message should be FunctionCallMessage"); + AssertWithMessage(!node.get_parent().has_value(), "Should not have parent"); + + json output = node; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestInvalidMessageThrows() +{ + json invalid = R"({"unknown_field": 123})"_json; + + bool threw = false; + try + { + [[maybe_unused]] auto msg = invalid.get(); + } + catch (const std::exception&) + { + threw = true; + } + AssertWithMessage(threw, "Deserializing invalid JSON should throw"); +} + +static void TestChatCompletionParamsRoundTrip() +{ + json input = R"({ + "id": "chat_1", + "messages": [ + {"role": "user", "content": [{"type": "text", "data": "Hi"}]}, + {"type": "function_call", "call_id": "c1", "name": "greet", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "c1", "output": [{"type": "text", "data": "Hello!"}]} + ], + "modelId": "gpt-4", + "parent": "prev_node" + })"_json; + + ChatCompletionParams params = input.get(); + AssertWithMessage(params.get_id() == "chat_1", "id mismatch"); + AssertWithMessage(params.get_messages().size() == 3, "Should have 3 messages"); + AssertWithMessage(params.get_model_id() == "gpt-4", "model_id mismatch"); + AssertWithMessage(params.get_parent().has_value(), "Should have parent"); + + json output = params; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestImageUrlContentType() +{ + json input = R"({ + "role": "user", + "content": [ + {"type": "image_url", "data": "https://example.com/img.png"}, + {"type": "text", "data": "What is in this image?"} + ] + })"_json; + + Message msg = input.get(); + auto& chat = std::get(msg); + AssertWithMessage(chat.get_content()[0].get_type() == MessageContentType::IMAGE_URL, "First content should be image_url"); + + json output = msg; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestTreeHistoryRoundTrip() +{ + json input = R"({ + "nodes": { + "n1": { + "children": ["n2"], + "id": "n1", + "message": {"role": "user", "content": [{"type": "text", "data": "Hi"}]}, + "timestamp": 1.0 + }, + "n2": { + "children": [], + "id": "n2", + "message": {"role": "assistant", "content": [{"type": "text", "data": "Hello!"}]}, + "parent": "n1", + "timestamp": 2.0 + } + } + })"_json; + + TreeHistory history = input.get(); + AssertWithMessage(history.get_nodes().size() == 2, "Should have 2 nodes"); + AssertWithMessage(history.get_nodes().at("n1").get_children().size() == 1, "n1 should have 1 child"); + AssertWithMessage(!history.get_nodes().at("n1").get_parent().has_value(), "n1 should not have parent"); + AssertWithMessage(history.get_nodes().at("n2").get_parent().value() == "n1", "n2 parent should be n1"); + + json output = history; + AssertWithMessage(output == input, "Round-trip should produce identical JSON"); +} + +static void TestChatCompletionSegmentStringVariant() +{ + json strInput = R"("hello stream chunk")"_json; + + ChatCompletionSegment seg = strInput.get(); + AssertWithMessage(std::holds_alternative(seg), "Should be string variant"); + AssertWithMessage(std::get(seg) == "hello stream chunk", "String value mismatch"); + + json output = seg; + AssertWithMessage(output == strInput, "Round-trip should produce identical JSON"); +} + +static void TestChatCompletionSegmentObjectVariant() +{ + json objInput = R"({ + "event": "function_call_start", + "data": { + "type": "function_call", + "call_id": "c1", + "name": "test_fn", + "arguments": "{}", + "extra": {"id": "x", "status": "in_progress"} + } + })"_json; + + ChatCompletionSegment seg = objInput.get(); + AssertWithMessage(std::holds_alternative(seg), "Should be ChatCompletionSegmentClass variant"); + + auto& cls = std::get(seg); + AssertWithMessage(cls.get_event() == Event::FUNCTION_CALL_START, "Event should be function_call_start"); + AssertWithMessage(cls.get_data().has_value(), "Should have data"); + AssertWithMessage(cls.get_data()->get_name() == "test_fn", "data name mismatch"); + + json output = seg; + AssertWithMessage(output == objInput, "Round-trip should produce identical JSON"); +} + +int main() +{ + struct { const char* name; void (*fn)(); } tests[] = { + {"ChatMessageRoundTrip", TestChatMessageRoundTrip}, + {"ChatMessageAssistantRole", TestChatMessageAssistantRole}, + {"FunctionCallMessageRoundTrip", TestFunctionCallMessageRoundTrip}, + {"FunctionCallMessageWithExtra", TestFunctionCallMessageWithExtra}, + {"FunctionCallOutputMessageRoundTrip", TestFunctionCallOutputMessageRoundTrip}, + {"FunctionCallOutputMessageWithExtra", TestFunctionCallOutputMessageWithExtra}, + {"MessageArrayRoundTrip", TestMessageArrayRoundTrip}, + {"LinearHistoryRoundTrip", TestLinearHistoryRoundTrip}, + {"MessageNodeRoundTrip", TestMessageNodeRoundTrip}, + {"MessageNodeWithFunctionCall", TestMessageNodeWithFunctionCall}, + {"InvalidMessageThrows", TestInvalidMessageThrows}, + {"ChatCompletionParamsRoundTrip", TestChatCompletionParamsRoundTrip}, + {"ImageUrlContentType", TestImageUrlContentType}, + {"TreeHistoryRoundTrip", TestTreeHistoryRoundTrip}, + {"ChatCompletionSegmentStringVariant", TestChatCompletionSegmentStringVariant}, + {"ChatCompletionSegmentObjectVariant", TestChatCompletionSegmentObjectVariant}, + }; + + int passed = 0; + int failed = 0; + for (auto& [name, fn] : tests) + { + try + { + std::cout << "Running: " << name << "... "; + fn(); + std::cout << "PASS" << std::endl; + passed++; + } + catch (const std::exception& e) + { + std::cout << "FAIL: " << e.what() << std::endl; + failed++; + } + } + + std::cout << "\n" << passed << " passed, " << failed << " failed, " << (passed + failed) << " total." << std::endl; + return failed > 0 ? 1 : 0; +} From a4566c8eb2c5e23dd1f0b6dcd9736ec8deaa6a4a Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 14:24:27 +0800 Subject: [PATCH 05/11] update bulk generation capability --- src/schema/IServer.h | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/src/schema/IServer.h b/src/schema/IServer.h index 7770e37..086a301 100644 --- a/src/schema/IServer.h +++ b/src/schema/IServer.h @@ -22,6 +22,7 @@ // ChatCompletionSegment data = nlohmann::json::parse(jsonString); // ChatCompletionInfo data = nlohmann::json::parse(jsonString); // ExecuteGenerationTaskParams data = nlohmann::json::parse(jsonString); +// ExecuteGenerationTaskResult data = nlohmann::json::parse(jsonString); // GetModelListParams data = nlohmann::json::parse(jsonString); // GetModelListResult data = nlohmann::json::parse(jsonString); // ModelSettings data = nlohmann::json::parse(jsonString); @@ -480,19 +481,33 @@ namespace IServer { virtual ~ExecuteGenerationTaskParams() = default; private: - ChatMessage message; + std::vector messages; std::string model_id; public: - const ChatMessage & get_message() const { return message; } - ChatMessage & get_mutable_message() { return message; } - void set_message(const ChatMessage & value) { this->message = value; } + const std::vector & get_messages() const { return messages; } + std::vector & get_mutable_messages() { return messages; } + void set_messages(const std::vector & value) { this->messages = value; } const std::string & get_model_id() const { return model_id; } std::string & get_mutable_model_id() { return model_id; } void set_model_id(const std::string & value) { this->model_id = value; } }; + class ExecuteGenerationTaskResult { + public: + ExecuteGenerationTaskResult() = default; + virtual ~ExecuteGenerationTaskResult() = default; + + private: + std::vector messages; + + public: + const std::vector & get_messages() const { return messages; } + std::vector & get_mutable_messages() { return messages; } + void set_messages(const std::vector & value) { this->messages = value; } + }; + class GetModelListParams { public: GetModelListParams() = default; @@ -949,6 +964,9 @@ void to_json(json & j, const ChatCompletionInfo & x); void from_json(const json & j, ExecuteGenerationTaskParams & x); void to_json(json & j, const ExecuteGenerationTaskParams & x); +void from_json(const json & j, ExecuteGenerationTaskResult & x); +void to_json(json & j, const ExecuteGenerationTaskResult & x); + void from_json(const json & j, GetModelListParams & x); void to_json(json & j, const GetModelListParams & x); @@ -1231,16 +1249,25 @@ namespace IServer { } inline void from_json(const json & j, ExecuteGenerationTaskParams& x) { - x.set_message(j.at("message").get()); + x.set_messages(j.at("messages").get>()); x.set_model_id(j.at("modelId").get()); } inline void to_json(json & j, const ExecuteGenerationTaskParams & x) { j = json::object(); - j["message"] = x.get_message(); + j["messages"] = x.get_messages(); j["modelId"] = x.get_model_id(); } + inline void from_json(const json & j, ExecuteGenerationTaskResult& x) { + x.set_messages(j.at("messages").get>()); + } + + inline void to_json(json & j, const ExecuteGenerationTaskResult & x) { + j = json::object(); + j["messages"] = x.get_messages(); + } + inline void from_json(const json & j, GetModelListParams& x) { x.set_metadata_keys(get_stack_optional>(j, "metadataKeys")); } From 930b5215a893809f252833787d4f0ab97b61166b Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 14:51:18 +0800 Subject: [PATCH 06/11] support tool calling at provider level --- src/apiProvider/AzureOpenAI.cpp | 97 ++++++------ src/apiProvider/AzureOpenAI.h | 9 +- src/apiProvider/IProvider.h | 9 +- src/apiProvider/OpenAI.cpp | 261 +++++++++++++++++++++++++------- src/apiProvider/OpenAI.h | 9 +- src/schema/IServer.h | 60 ++++++++ test/TestApiProvider.cpp | 231 +++++++++++++++++++++++++--- 7 files changed, 550 insertions(+), 126 deletions(-) diff --git a/src/apiProvider/AzureOpenAI.cpp b/src/apiProvider/AzureOpenAI.cpp index c81e828..d22803d 100644 --- a/src/apiProvider/AzureOpenAI.cpp +++ b/src/apiProvider/AzureOpenAI.cpp @@ -48,7 +48,7 @@ void AzureOpenAI::Initialize(const nlohmann::json& params) _params = ParamsDefinition.Parse(params); } -RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const +RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream, const std::optional>& /*tools*/) const { RequestData data{}; data.url = _params.url; @@ -67,55 +67,64 @@ RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& his } for (const auto& message : history) { - auto messageJson = nlohmann::json::object(); - messageJson["content"] = nlohmann::json::array(); - if (message.get_role() == Schema::IServer::MessageRole::DEVELOPER) + if (std::holds_alternative(message)) { - messageJson["role"] = "system"; - } - else if(message.get_role() == Schema::IServer::MessageRole::USER) - { - messageJson["role"] = "user"; - } - else if(message.get_role() == Schema::IServer::MessageRole::ASSISTANT) - { - messageJson["role"] = "assistant"; - } - else - { - continue; - } - for (const auto& content : message.get_content()) - { - if (content.get_type() == Schema::IServer::Type::TEXT || content.get_type() == Schema::IServer::Type::REFUSAL) + const auto& chatMessage = std::get(message); + auto messageJson = nlohmann::json::object(); + messageJson["content"] = nlohmann::json::array(); + if (chatMessage.get_role() == Schema::IServer::ChatMessageRole::DEVELOPER) + { + messageJson["role"] = "system"; + } + else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::USER) { - /** Azure open ai does not seem to have a special REFUSAL message type */ - messageJson["content"].push_back({ - {"type", "text"}, - {"text", content.get_data()} - }); + messageJson["role"] = "user"; } - else if (content.get_type() == Schema::IServer::Type::IMAGE_URL) + else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::ASSISTANT) { - messageJson["content"].push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", content.get_data()} - }} - }); + messageJson["role"] = "assistant"; } else { continue; } + for (const auto& content : chatMessage.get_content()) + { + if (content.get_type() == Schema::IServer::MessageContentType::TEXT || content.get_type() == Schema::IServer::MessageContentType::REFUSAL) + { + /** Azure open ai does not seem to have a special REFUSAL message type */ + messageJson["content"].push_back({ + {"type", "text"}, + {"text", content.get_data()} + }); + } + else if (content.get_type() == Schema::IServer::MessageContentType::IMAGE_URL) + { + messageJson["content"].push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", content.get_data()} + }} + }); + } + else + { + continue; + } + } + body["messages"].push_back(messageJson); + } + else if (std::holds_alternative(message) + || std::holds_alternative(message)) + { + throw std::runtime_error("AzureOpenAI does not support function call messages"); } - body["messages"].push_back(messageJson); } data.body = body.dump(); return data; } -Schema::IServer::MessageContent AzureOpenAI::ParseResponse(const std::string& responseString) const +Schema::IServer::LinearHistory AzureOpenAI::ParseResponse(const std::string& responseString) const { Schema::AzureOpenAI::BulkResponse response; try @@ -131,14 +140,17 @@ Schema::IServer::MessageContent AzureOpenAI::ParseResponse(const std::string& re throw std::runtime_error("No choices in response"); } const auto& choice = response.get_choices().front(); - const auto& message = choice.get_message(); + const auto& msg = choice.get_message(); Schema::IServer::MessageContent content; - content.set_type(message.get_refusal() ? Schema::IServer::Type::REFUSAL : Schema::IServer::Type::TEXT); - content.set_data(message.get_content()); - return content; + content.set_type(msg.get_refusal() ? Schema::IServer::MessageContentType::REFUSAL : Schema::IServer::MessageContentType::TEXT); + content.set_data(msg.get_content()); + Schema::IServer::ChatMessage chatMessage; + chatMessage.set_role(Schema::IServer::ChatMessageRole::ASSISTANT); + chatMessage.set_content({content}); + return Schema::IServer::LinearHistory{std::move(chatMessage)}; } -std::optional AzureOpenAI::ParseStreamResponse(const StreamResponse::Event& event) const +std::optional AzureOpenAI::ParseStreamResponse(const StreamResponse::Event& event) const { if (!event.value.has_value()) { @@ -172,8 +184,5 @@ std::optional AzureOpenAI::ParseStreamResponse( { return std::nullopt; } - Schema::IServer::MessageContent content{}; - content.set_type(Schema::IServer::Type::TEXT); - content.set_data(delta.get_content().value()); - return content; + return delta.get_content().value(); } diff --git a/src/apiProvider/AzureOpenAI.h b/src/apiProvider/AzureOpenAI.h index 0a49898..e05733b 100644 --- a/src/apiProvider/AzureOpenAI.h +++ b/src/apiProvider/AzureOpenAI.h @@ -12,9 +12,12 @@ namespace TUI::ApiProvider ~AzureOpenAI() override = default; nlohmann::json GetParams() const override; void Initialize(const nlohmann::json& params) override; - Network::Http::RequestData FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const override; - Schema::IServer::MessageContent ParseResponse(const std::string& response) const override; - std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override; + Network::Http::RequestData FormatRequest( + const Schema::IServer::LinearHistory& history, + bool stream, + const std::optional>& tools = std::nullopt) const override; + Schema::IServer::LinearHistory ParseResponse(const std::string& response) const override; + std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override; private: struct Params { diff --git a/src/apiProvider/IProvider.h b/src/apiProvider/IProvider.h index d9880d3..41b16cb 100644 --- a/src/apiProvider/IProvider.h +++ b/src/apiProvider/IProvider.h @@ -22,8 +22,11 @@ namespace TUI::ApiProvider virtual ~IProvider() = default; virtual nlohmann::json GetParams() const = 0; virtual void Initialize(const nlohmann::json& params) = 0; - virtual Network::Http::RequestData FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const = 0; - virtual Schema::IServer::MessageContent ParseResponse(const std::string& response) const = 0; - virtual std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const = 0; + virtual Network::Http::RequestData FormatRequest( + const Schema::IServer::LinearHistory& history, + bool stream, + const std::optional>& tools = std::nullopt) const = 0; + virtual Schema::IServer::LinearHistory ParseResponse(const std::string& response) const = 0; + virtual std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const = 0; }; } diff --git a/src/apiProvider/OpenAI.cpp b/src/apiProvider/OpenAI.cpp index f65eef9..5adf076 100644 --- a/src/apiProvider/OpenAI.cpp +++ b/src/apiProvider/OpenAI.cpp @@ -60,7 +60,7 @@ void OpenAI::Initialize(const nlohmann::json& params) _params = ParamsDefinition.Parse(params); } -RequestData OpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const +RequestData OpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream, const std::optional>& tools) const { RequestData data{}; data.url = _params.url; @@ -83,96 +83,191 @@ RequestData OpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, }; } - for (const auto& message : history) + if (tools.has_value() && !tools.value().empty()) { - auto messageJson = nlohmann::json::object(); - messageJson["content"] = nlohmann::json::array(); - if (message.get_role() == Schema::IServer::MessageRole::DEVELOPER) - { - messageJson["role"] = "system"; - } - else if(message.get_role() == Schema::IServer::MessageRole::USER) - { - messageJson["role"] = "user"; - } - else if(message.get_role() == Schema::IServer::MessageRole::ASSISTANT) + body["tools"] = nlohmann::json::array(); + for (const auto& tool : tools.value()) { - messageJson["role"] = "assistant"; - } - else - { - continue; + body["tools"].push_back({ + {"type", "function"}, + {"name", tool.get_name()}, + {"description", tool.get_description()}, + {"parameters", tool.get_parameters()} + }); } - for (const auto& content : message.get_content()) + } + + for (const auto& message : history) + { + if (std::holds_alternative(message)) { - if (content.get_type() == Schema::IServer::Type::TEXT || content.get_type() == Schema::IServer::Type::REFUSAL) + const auto& chatMessage = std::get(message); + auto messageJson = nlohmann::json::object(); + messageJson["content"] = nlohmann::json::array(); + if (chatMessage.get_role() == Schema::IServer::ChatMessageRole::DEVELOPER) { - messageJson["content"].push_back({ - {"type", message.get_role() == Schema::IServer::MessageRole::ASSISTANT ? "output_text" : "input_text"}, - {"text", content.get_data()} - }); + messageJson["role"] = "system"; } - else if (content.get_type() == Schema::IServer::Type::IMAGE_URL) + else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::USER) { - messageJson["content"].push_back({ - {"type", "input_image"}, - {"image_url", content.get_data()} - }); + messageJson["role"] = "user"; + } + else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::ASSISTANT) + { + messageJson["role"] = "assistant"; } else { continue; } + for (const auto& content : chatMessage.get_content()) + { + if (content.get_type() == Schema::IServer::MessageContentType::TEXT || content.get_type() == Schema::IServer::MessageContentType::REFUSAL) + { + messageJson["content"].push_back({ + {"type", chatMessage.get_role() == Schema::IServer::ChatMessageRole::ASSISTANT ? "output_text" : "input_text"}, + {"text", content.get_data()} + }); + } + else if (content.get_type() == Schema::IServer::MessageContentType::IMAGE_URL) + { + messageJson["content"].push_back({ + {"type", "input_image"}, + {"image_url", content.get_data()} + }); + } + else + { + continue; + } + } + body["input"].push_back(messageJson); + } + else if (std::holds_alternative(message)) + { + const auto& functionCall = std::get(message); + auto functionCallJson = nlohmann::json::object(); + if (functionCall.get_extra().has_value()) + { + functionCallJson = functionCall.get_extra().value(); + } + functionCallJson["type"] = "function_call"; + functionCallJson["call_id"] = functionCall.get_call_id(); + functionCallJson["name"] = functionCall.get_name(); + functionCallJson["arguments"] = functionCall.get_arguments(); + body["input"].push_back(functionCallJson); + } + else if (std::holds_alternative(message)) + { + const auto& functionCallOutput = std::get(message); + auto functionCallOutputJson = nlohmann::json::object(); + if (functionCallOutput.get_extra().has_value()) + { + functionCallOutputJson = functionCallOutput.get_extra().value(); + } + functionCallOutputJson["type"] = "function_call_output"; + functionCallOutputJson["call_id"] = functionCallOutput.get_call_id(); + std::string outputText; + for (const auto& content : functionCallOutput.get_output()) + { + if (content.get_type() == Schema::IServer::MessageContentType::TEXT) + { + outputText += content.get_data(); + } + } + functionCallOutputJson["output"] = outputText; + body["input"].push_back(functionCallOutputJson); } - body["input"].push_back(messageJson); + else + { + throw std::runtime_error("Unknown message type in history"); + } + + } data.body = body.dump(); return data; } -Schema::IServer::MessageContent OpenAI::ParseResponse(const std::string& responseString) const +Schema::IServer::LinearHistory OpenAI::ParseResponse(const std::string& responseString) const { + Schema::IServer::LinearHistory results{}; + auto json = nlohmann::json::parse(responseString); if (!json.contains("output") || !json.at("output").is_array()) { throw std::invalid_argument("Invalid response: missing output array"); } - std::string text{}; for (const auto& item : json.at("output")) { - if (!item.contains("content") || !item.at("content").is_array()) + auto typeIt = item.find("type"); + if (typeIt == item.end()) { continue; } - for (const auto& content : item.at("content")) + auto itemType = typeIt->get(); + + if (itemType == "message") { - auto typeIt = content.find("type"); - if (typeIt == content.end()) + if (!item.contains("content") || !item.at("content").is_array()) { continue; } - auto typeStr = typeIt->get(); - if (content.contains("text") && (typeStr == "output_text" || typeStr == "text")) + Schema::IServer::ChatMessage chatMessage; + chatMessage.set_role(Schema::IServer::ChatMessageRole::ASSISTANT); + std::vector contents; + for (const auto& content : item.at("content")) + { + auto contentTypeIt = content.find("type"); + if (contentTypeIt == content.end()) + { + continue; + } + auto contentType = contentTypeIt->get(); + if (content.contains("text") && (contentType == "output_text" || contentType == "text")) + { + Schema::IServer::MessageContent messageContent; + messageContent.set_type(Schema::IServer::MessageContentType::TEXT); + messageContent.set_data(content.at("text").get()); + contents.push_back(std::move(messageContent)); + } + else if (contentType == "refusal" && content.contains("refusal")) + { + Schema::IServer::MessageContent messageContent; + messageContent.set_type(Schema::IServer::MessageContentType::REFUSAL); + messageContent.set_data(content.at("refusal").get()); + contents.push_back(std::move(messageContent)); + } + } + if (!contents.empty()) { - text += content.at("text").get(); + chatMessage.set_content(contents); + results.push_back(std::move(chatMessage)); } } + else if (itemType == "function_call") + { + Schema::IServer::FunctionCallMessage functionCall; + functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); + functionCall.set_call_id(item.at("call_id").get()); + functionCall.set_name(item.at("name").get()); + functionCall.set_arguments(item.at("arguments").get()); + functionCall.set_extra(std::make_optional(item)); + results.push_back(std::move(functionCall)); + } } - if (text.empty()) + if (results.empty()) { - throw std::runtime_error("No output text in response"); + throw std::runtime_error("No output in response"); } - Schema::IServer::MessageContent content; - content.set_type(Schema::IServer::Type::TEXT); - content.set_data(std::move(text)); - return content; + return results; } -std::optional OpenAI::ParseStreamResponse(const StreamResponse::Event& event) const +std::optional OpenAI::ParseStreamResponse(const StreamResponse::Event& event) const { if (!event.value.has_value()) { @@ -190,11 +285,6 @@ std::optional OpenAI::ParseStreamResponse(const const std::string eventType = trimLeading(event.type.value_or("")); const std::string valueString = event.value.value(); - if (eventType == "response.completed" || valueString == "[DONE]") - { - return std::nullopt; - } - if (eventType == "response.output_text.delta") { try @@ -202,10 +292,7 @@ std::optional OpenAI::ParseStreamResponse(const auto json = nlohmann::json::parse(valueString); if (json.contains("delta") && json.at("delta").is_string()) { - Schema::IServer::MessageContent content{}; - content.set_type(Schema::IServer::Type::TEXT); - content.set_data(json.at("delta").get()); - return content; + return json.at("delta").get(); } } catch(...) @@ -213,6 +300,70 @@ std::optional OpenAI::ParseStreamResponse(const return std::nullopt; } } + else if (eventType == "response.output_item.added") + { + try + { + auto json = nlohmann::json::parse(valueString); + if (json.contains("item") && json.at("item").contains("type") + && json.at("item").at("type").get() == "function_call") + { + const auto& item = json.at("item"); + Schema::IServer::FunctionCallMessage functionCall; + functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); + functionCall.set_call_id(item.at("call_id").get()); + functionCall.set_name(item.at("name").get()); + functionCall.set_arguments(item.value("arguments", "")); + Schema::IServer::ChatCompletionSegmentClass segment; + segment.set_event(Schema::IServer::Event::FUNCTION_CALL_START); + segment.set_data(std::move(functionCall)); + return segment; + } + } + catch (...) + { + return std::nullopt; + } + } + else if (eventType == "response.function_call_arguments.done") + { + try + { + auto json = nlohmann::json::parse(valueString); + Schema::IServer::FunctionCallMessage functionCall; + functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); + functionCall.set_call_id(json.value("call_id", "")); + functionCall.set_name(json.value("name", "")); + functionCall.set_arguments(json.value("arguments", "")); + Schema::IServer::ChatCompletionSegmentClass segment; + segment.set_event(Schema::IServer::Event::FUNCTION_CALL_END); + segment.set_data(std::move(functionCall)); + return segment; + } + catch (...) + { + return std::nullopt; + } + } + else if (eventType == "response.completed" || valueString == "[DONE]") + { + return std::nullopt; + } + else if (eventType == "response.failed") + { + std::string errorMsg = "Response failed"; + try + { + auto json = nlohmann::json::parse(valueString); + if (json.contains("response") && json.at("response").contains("error") + && json.at("response").at("error").contains("message")) + { + errorMsg = json.at("response").at("error").at("message").get(); + } + } + catch (...) {} + throw std::runtime_error(errorMsg); + } return std::nullopt; } diff --git a/src/apiProvider/OpenAI.h b/src/apiProvider/OpenAI.h index bde4f00..cbf04b6 100644 --- a/src/apiProvider/OpenAI.h +++ b/src/apiProvider/OpenAI.h @@ -19,9 +19,12 @@ namespace TUI::ApiProvider nlohmann::json GetParams() const override; void Initialize(const nlohmann::json& params) override; - Network::Http::RequestData FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const override; - Schema::IServer::MessageContent ParseResponse(const std::string& response) const override; - std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override; + Network::Http::RequestData FormatRequest( + const Schema::IServer::LinearHistory& history, + bool stream, + const std::optional>& tools = std::nullopt) const override; + Schema::IServer::LinearHistory ParseResponse(const std::string& response) const override; + std::optional ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override; private: struct Params diff --git a/src/schema/IServer.h b/src/schema/IServer.h index 086a301..1481948 100644 --- a/src/schema/IServer.h +++ b/src/schema/IServer.h @@ -18,6 +18,7 @@ // TreeHistory data = nlohmann::json::parse(jsonString); // GetChatListParams data = nlohmann::json::parse(jsonString); // GetChatListResult data = nlohmann::json::parse(jsonString); +// Tool data = nlohmann::json::parse(jsonString); // ChatCompletionParams data = nlohmann::json::parse(jsonString); // ChatCompletionSegment data = nlohmann::json::parse(jsonString); // ChatCompletionInfo data = nlohmann::json::parse(jsonString); @@ -405,6 +406,33 @@ namespace IServer { void set_metadata(std::optional> value) { this->metadata = value; } }; + class Tool { + public: + Tool() = default; + virtual ~Tool() = default; + + private: + std::string description; + std::string name; + nlohmann::json parameters; + + public: + const std::string & get_description() const { return description; } + std::string & get_mutable_description() { return description; } + void set_description(const std::string & value) { this->description = value; } + + const std::string & get_name() const { return name; } + std::string & get_mutable_name() { return name; } + void set_name(const std::string & value) { this->name = value; } + + /** + * JSON schema for the parameter + */ + const nlohmann::json & get_parameters() const { return parameters; } + nlohmann::json & get_mutable_parameters() { return parameters; } + void set_parameters(const nlohmann::json & value) { this->parameters = value; } + }; + class ChatCompletionParams { public: ChatCompletionParams() = default; @@ -415,6 +443,7 @@ namespace IServer { std::vector messages; std::string model_id; std::optional parent; + std::optional> tools; public: const std::string & get_id() const { return id; } @@ -434,6 +463,9 @@ namespace IServer { */ std::optional get_parent() const { return parent; } void set_parent(std::optional value) { this->parent = value; } + + std::optional> get_tools() const { return tools; } + void set_tools(std::optional> value) { this->tools = value; } }; enum class Event : int { FUNCTION_CALL_END, FUNCTION_CALL_START }; @@ -483,6 +515,7 @@ namespace IServer { private: std::vector messages; std::string model_id; + std::optional> tools; public: const std::vector & get_messages() const { return messages; } @@ -492,6 +525,9 @@ namespace IServer { const std::string & get_model_id() const { return model_id; } std::string & get_mutable_model_id() { return model_id; } void set_model_id(const std::string & value) { this->model_id = value; } + + std::optional> get_tools() const { return tools; } + void set_tools(std::optional> value) { this->tools = value; } }; class ExecuteGenerationTaskResult { @@ -952,6 +988,9 @@ void to_json(json & j, const GetChatListParams & x); void from_json(const json & j, GetChatListResultElement & x); void to_json(json & j, const GetChatListResultElement & x); +void from_json(const json & j, Tool & x); +void to_json(json & j, const Tool & x); + void from_json(const json & j, ChatCompletionParams & x); void to_json(json & j, const ChatCompletionParams & x); @@ -1209,11 +1248,25 @@ namespace IServer { } } + inline void from_json(const json & j, Tool& x) { + x.set_description(j.at("description").get()); + x.set_name(j.at("name").get()); + x.set_parameters(get_untyped(j, "parameters")); + } + + inline void to_json(json & j, const Tool & x) { + j = json::object(); + j["description"] = x.get_description(); + j["name"] = x.get_name(); + j["parameters"] = x.get_parameters(); + } + inline void from_json(const json & j, ChatCompletionParams& x) { x.set_id(j.at("id").get()); x.set_messages(j.at("messages").get>()); x.set_model_id(j.at("modelId").get()); x.set_parent(get_stack_optional(j, "parent")); + x.set_tools(get_stack_optional>(j, "tools")); } inline void to_json(json & j, const ChatCompletionParams & x) { @@ -1224,6 +1277,9 @@ namespace IServer { if (x.get_parent()) { j["parent"] = x.get_parent(); } + if (x.get_tools()) { + j["tools"] = x.get_tools(); + } } inline void from_json(const json & j, ChatCompletionSegmentClass& x) { @@ -1251,12 +1307,16 @@ namespace IServer { inline void from_json(const json & j, ExecuteGenerationTaskParams& x) { x.set_messages(j.at("messages").get>()); x.set_model_id(j.at("modelId").get()); + x.set_tools(get_stack_optional>(j, "tools")); } inline void to_json(json & j, const ExecuteGenerationTaskParams & x) { j = json::object(); j["messages"] = x.get_messages(); j["modelId"] = x.get_model_id(); + if (x.get_tools()) { + j["tools"] = x.get_tools(); + } } inline void from_json(const json & j, ExecuteGenerationTaskResult& x) { diff --git a/test/TestApiProvider.cpp b/test/TestApiProvider.cpp index a1b115c..bdf5050 100644 --- a/test/TestApiProvider.cpp +++ b/test/TestApiProvider.cpp @@ -16,7 +16,6 @@ using namespace TUI::Network; static Tev tev{}; static std::filesystem::path configFilePath; -static std::filesystem::path historyFilePath; static nlohmann::json LoadJsonFile(const std::filesystem::path& path) { @@ -29,12 +28,92 @@ static nlohmann::json LoadJsonFile(const std::filesystem::path& path) return nlohmann::json::parse(content); } -static Schema::IServer::LinearHistory LoadChatHistory(const std::filesystem::path& path) +static Schema::IServer::LinearHistory GetChatHistory() { - auto json = LoadJsonFile(path); + auto json = nlohmann::json::parse(R"([ + { + "role": "developer", + "content": [ + {"type": "text", "data": "You are a helpful assistant."} + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "data": "What is the capital of France?"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "data": "Paris"} + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "data": "What's the second largest city of France?"} + ] + } + ])"); + return json.get(); +} + +static Schema::IServer::LinearHistory GetToolCallHistory() +{ + auto json = nlohmann::json::parse(R"([ + { + "role": "user", + "content": [ + {"type": "text", "data": "What is the current room temperature?"} + ] + } + ])"); return json.get(); } +static std::vector GetTools() +{ + Schema::IServer::Tool tool; + tool.set_name("get_room_temperature"); + tool.set_description("Get the current room temperature."); + tool.set_parameters(nlohmann::json::parse(R"({ + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": false + })")); + return {tool}; +} + +static std::string ExecuteFunctionCall(const Schema::IServer::FunctionCallMessage& fcm) +{ + if (fcm.get_name() == "get_room_temperature") + { + return R"({"temperature": "22.5°C"})"; + } + return R"({"error": "unknown function"})"; +} + +static void AppendFunctionCallOutput( + Schema::IServer::LinearHistory& history, + const Schema::IServer::FunctionCallMessage& fcm) +{ + // Add the function call itself to history + history.push_back(fcm); + + // Execute and add the output + auto result = ExecuteFunctionCall(fcm); + Schema::IServer::FunctionCallOutputMessage fcom; + fcom.set_type(Schema::IServer::FunctionCallOutputMessageType::FUNCTION_CALL_OUTPUT); + fcom.set_call_id(fcm.get_call_id()); + Schema::IServer::MessageContent content; + content.set_type(Schema::IServer::MessageContentType::TEXT); + content.set_data(std::move(result)); + fcom.set_output({content}); + history.push_back(std::move(fcom)); +} + JS::Promise TestBulkChatAsync() { auto client = Http::Client::Create(tev); @@ -42,11 +121,11 @@ JS::Promise TestBulkChatAsync() auto provider = ApiProvider::Factory::CreateProvider( params["providerName"].get(), params["providerParams"]); - auto history = LoadChatHistory(historyFilePath); + auto history = GetChatHistory(); auto requestData = provider->FormatRequest(history, false); auto response = co_await client->MakeRequest(Http::Method::POST, requestData).GetResponseAsync(); - auto message = provider->ParseResponse(response); - std::cout << nlohmann::json(message).dump(4) << std::endl; + auto result = provider->ParseResponse(response); + std::cout << nlohmann::json(result).dump(4) << std::endl; } JS::Promise TestStreamChatAsync() @@ -56,7 +135,7 @@ JS::Promise TestStreamChatAsync() auto provider = ApiProvider::Factory::CreateProvider( params["providerName"].get(), params["providerParams"]); - auto history = LoadChatHistory(historyFilePath); + auto history = GetChatHistory(); auto requestData = provider->FormatRequest(history, true); auto streamEventParser = Http::StreamResponse::Parser{}; auto responseStream = client->MakeStreamRequest(Http::Method::POST, requestData).GetResponseStream(); @@ -70,11 +149,128 @@ JS::Promise TestStreamChatAsync() std::cout << std::endl; break; } - auto message = provider->ParseStreamResponse(event.value()); - if (message.has_value()) + auto segment = provider->ParseStreamResponse(event.value()); + if (segment.has_value()) { - std::cout << message.value().get_data(); - std::cout.flush(); + if (std::holds_alternative(segment.value())) + { + std::cout << std::get(segment.value()); + std::cout.flush(); + } + } + } +} + +JS::Promise TestBulkToolCallAsync() +{ + auto client = Http::Client::Create(tev); + auto params = LoadJsonFile(configFilePath); + auto provider = ApiProvider::Factory::CreateProvider( + params["providerName"].get(), + params["providerParams"]); + auto history = GetToolCallHistory(); + auto tools = GetTools(); + + while (true) + { + auto requestData = provider->FormatRequest(history, false, tools); + auto response = co_await client->MakeRequest(Http::Method::POST, requestData).GetResponseAsync(); + auto result = provider->ParseResponse(response); + + bool hasFunctionCall = false; + for (const auto& msg : result) + { + if (std::holds_alternative(msg)) + { + hasFunctionCall = true; + const auto& fcm = std::get(msg); + std::cout << "[tool call] " << fcm.get_name() << "(" << fcm.get_arguments() << ")" << std::endl; + AppendFunctionCallOutput(history, fcm); + } + else if (std::holds_alternative(msg)) + { + const auto& chatMsg = std::get(msg); + for (const auto& content : chatMsg.get_content()) + { + std::cout << content.get_data(); + } + std::cout << std::endl; + } + } + if (!hasFunctionCall) + { + break; + } + } +} + +JS::Promise TestStreamToolCallAsync() +{ + auto client = Http::Client::Create(tev); + auto params = LoadJsonFile(configFilePath); + auto provider = ApiProvider::Factory::CreateProvider( + params["providerName"].get(), + params["providerParams"]); + auto history = GetToolCallHistory(); + auto tools = GetTools(); + + while (true) + { + auto requestData = provider->FormatRequest(history, true, tools); + auto responseStream = client->MakeStreamRequest(Http::Method::POST, requestData).GetResponseStream(); + auto parser = Http::StreamResponse::AsyncParser(responseStream); + auto eventStream = parser.Parse(); + + bool hasFunctionCall = false; + std::map pendingCalls; + while (true) + { + auto event = co_await eventStream.NextAsync(); + if (!event.has_value()) + { + break; + } + auto segment = provider->ParseStreamResponse(event.value()); + if (!segment.has_value()) + { + continue; + } + if (std::holds_alternative(segment.value())) + { + std::cout << std::get(segment.value()); + std::cout.flush(); + } + else + { + const auto& segClass = std::get(segment.value()); + if (segClass.get_event() == Schema::IServer::Event::FUNCTION_CALL_START) + { + auto fcm = segClass.get_data().value(); + std::cout << "[tool call start] " << fcm.get_name() << std::endl; + pendingCalls[static_cast(pendingCalls.size())] = std::move(fcm); + } + else if (segClass.get_event() == Schema::IServer::Event::FUNCTION_CALL_END) + { + hasFunctionCall = true; + auto endFcm = segClass.get_data().value(); + // Find the pending call and fill in arguments + for (auto& [idx, pending] : pendingCalls) + { + if (pending.get_arguments().empty()) + { + pending.set_arguments(endFcm.get_arguments()); + std::cout << "[tool call end] " << pending.get_name() << "(" << pending.get_arguments() << ")" << std::endl; + AppendFunctionCallOutput(history, pending); + break; + } + } + } + } + } + std::cout << std::endl; + if (!hasFunctionCall) + { + break; } } } @@ -83,29 +279,28 @@ static JS::Promise TestAsync() { RunAsyncTest(TestBulkChatAsync()); RunAsyncTest(TestStreamChatAsync()); + RunAsyncTest(TestBulkToolCallAsync()); + RunAsyncTest(TestStreamToolCallAsync()); } int main(int argc, char const *argv[]) { int opt = 0; - while ((opt = getopt(argc, const_cast(argv), "c:h:")) != -1) + while ((opt = getopt(argc, const_cast(argv), "c:")) != -1) { switch (opt) { case 'c': configFilePath = optarg; break; - case 'h': - historyFilePath = optarg; - break; default: - std::cerr << "Usage: " << argv[0] << " -c -h " << std::endl; + std::cerr << "Usage: " << argv[0] << " -c " << std::endl; return 1; } } - if (configFilePath.empty() || historyFilePath.empty()) + if (configFilePath.empty()) { - std::cerr << "Usage: " << argv[0] << " -c -h " << std::endl; + std::cerr << "Usage: " << argv[0] << " -c " << std::endl; return 1; } From 7eaa0aa06fa67607094e7e4b8ca983cff3f48f49 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 18:20:16 +0800 Subject: [PATCH 07/11] update generation interfaces --- src/application/Service.cpp | 211 +++++++++++++++++++++------------- test/ServiceClient/ChatBot.js | 169 ++++++++++++++++++++++++--- test/TestDatabase.cpp | 41 ++++--- 3 files changed, 303 insertions(+), 118 deletions(-) diff --git a/src/application/Service.cpp b/src/application/Service.cpp index 5689f3f..5b165de 100644 --- a/src/application/Service.cpp +++ b/src/application/Service.cpp @@ -484,22 +484,20 @@ JS::Promise Service::DeleteChatAsync(CallerId callerId, nlohmann */ JS::AsyncGenerator Service::OnChatCompletionAsync(CallerId callerId, nlohmann::json paramsJson) { - using MessageRoleType = std::remove_reference>::type; - auto params = ParseParams(paramsJson); - if (params.get_user_message().get_role() != MessageRoleType::USER) + if (params.get_messages().empty()) { throw Schema::Rpc::Exception( Schema::Rpc::ErrorCode::BAD_REQUEST, - "The user message must have the role user"); + "Messages must not be empty"); } Common::Uuid chatId{params.get_id()}; auto lock = _resourceVersionManager->GetWriteLock( {"chat", static_cast(callerId.userId), static_cast(chatId)}, callerId); - int64_t userMessageTimestamp = Common::Timestamp::GetWallClock(); + int64_t inputTimestamp = Common::Timestamp::GetWallClock(); - /** Construct the linear history from the tree history and the new user message */ + /** Construct the linear history from the tree history and the new messages */ Common::Uuid parentId{nullptr}; Schema::IServer::LinearHistory history{}; { @@ -526,16 +524,11 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn historyList.push_front(node.get_message()); parentIdStr = node.get_parent(); } - /** The previous message should be a assistant message if it exists */ - if (!historyList.empty() && historyList.back().get_role() != MessageRoleType::ASSISTANT) + + for (const auto& msg : params.get_messages()) { - throw Schema::Rpc::Exception( - Schema::Rpc::ErrorCode::BAD_REQUEST, - "The parent message must be an assistant message"); + historyList.push_back(msg); } - - /** We need to use the user message later. So don't move it. */ - historyList.push_back(params.get_user_message()); history.reserve(historyList.size()); for (auto&& item : historyList) { @@ -543,12 +536,12 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn } } - /** Send the request */ - std::string wholeResponse{}; + /** Send the request and stream response */ + std::vector responseMessages{}; { Common::Uuid modelId{params.get_model_id()}; auto provider = GetProvider(modelId); - auto requestData = provider->FormatRequest(history, true); + auto requestData = provider->FormatRequest(history, true, params.get_tools()); auto request = _httpClient->MakeStreamRequest( Network::Http::Method::POST, requestData); @@ -557,6 +550,22 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn auto eventStream = parser.Parse(); auto streamBatcher = Common::StreamBatcher::BatchStream(_tev, std::move(eventStream), STREAM_BATCHING_INTERVAL_MS); + std::string currentText{}; + Schema::IServer::FunctionCallMessage pendingCall{}; + bool hasPendingCall = false; + + auto flushCurrentText = [&]() { + if (currentText.empty()) return; + Schema::IServer::ChatMessage chatMsg{}; + chatMsg.set_role(Schema::IServer::ChatMessageRole::ASSISTANT); + Schema::IServer::MessageContent content{}; + content.set_type(Schema::IServer::MessageContentType::TEXT); + content.set_data(std::move(currentText)); + chatMsg.set_content({std::move(content)}); + responseMessages.push_back(std::move(chatMsg)); + currentText.clear(); + }; + while (true) { auto events = co_await streamBatcher.NextAsync(); @@ -564,73 +573,126 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn { break; } - std::string segment = ""; + std::string batchText{}; for (const auto& event : events.value()) { - auto content = provider->ParseStreamResponse(event); - if (!content.has_value()) + auto segment = provider->ParseStreamResponse(event); + if (!segment.has_value()) { continue; } - using ContentTypeType = std::remove_reference::type; - if (content.value().get_type() != ContentTypeType::TEXT) + auto& seg = segment.value(); + if (std::holds_alternative(seg)) { - throw Schema::Rpc::Exception( - Schema::Rpc::ErrorCode::BAD_GATEWAY, - content.value().get_data()); + auto& text = std::get(seg); + batchText += text; + currentText += text; } - auto& data = content.value().get_data(); - segment += data; + else + { + if (!batchText.empty()) + { + co_yield static_cast( + Schema::IServer::ChatCompletionSegment{std::move(batchText)}); + batchText.clear(); + } + auto segClass = std::get(seg); + if (segClass.get_event() == Schema::IServer::Event::FUNCTION_CALL_START) + { + /** Flush any accumulated text as an assistant message */ + flushCurrentText(); + auto data = segClass.get_data(); + if (data.has_value()) + { + pendingCall = std::move(data.value()); + } + hasPendingCall = true; + } + else if (segClass.get_event() == Schema::IServer::Event::FUNCTION_CALL_END) + { + if (hasPendingCall) + { + auto data = segClass.get_data(); + if (data.has_value()) + { + pendingCall.set_arguments(data.value().get_arguments()); + } + responseMessages.push_back(std::move(pendingCall)); + hasPendingCall = false; + } + } + co_yield static_cast( + Schema::IServer::ChatCompletionSegment{std::move(segClass)}); + } + } + if (!batchText.empty()) + { + co_yield static_cast( + Schema::IServer::ChatCompletionSegment{std::move(batchText)}); } - wholeResponse += segment; - co_yield static_cast(segment); } + /** Flush any remaining text after stream ends */ + flushCurrentText(); } - /** Save changes to the database */ - Common::Uuid userMessageId{}; - Common::Uuid responseMessageId{}; + /** Save all new messages to the database as a chain */ + std::vector allMessageIds{}; { - Schema::IServer::MessageNode userNode{}; - userNode.set_id(static_cast(userMessageId)); - userNode.set_message(std::move(params.get_mutable_user_message())); - if (parentId != nullptr) + std::vector allNodes{}; + double inputTs = static_cast(inputTimestamp); + for (auto& msg : params.get_mutable_messages()) { - userNode.set_parent(static_cast(parentId)); + Schema::IServer::MessageNode node{}; + node.set_id(static_cast(Common::Uuid{})); + node.set_message(std::move(msg)); + node.set_timestamp(inputTs); + allNodes.push_back(std::move(node)); + } + double responseTs = static_cast(Common::Timestamp::GetWallClock()); + for (auto& msg : responseMessages) + { + Schema::IServer::MessageNode node{}; + node.set_id(static_cast(Common::Uuid{})); + node.set_message(std::move(msg)); + node.set_timestamp(responseTs); + allNodes.push_back(std::move(node)); + } + + /** Chain nodes together */ + for (size_t i = 0; i < allNodes.size(); ++i) + { + if (i == 0) + { + if (parentId != nullptr) + { + allNodes[i].set_parent(static_cast(parentId)); + } + } + else + { + allNodes[i].set_parent(allNodes[i - 1].get_id()); + } + + if (i + 1 < allNodes.size()) + { + allNodes[i].set_children({allNodes[i + 1].get_id()}); + } + } + + /** Write nodes to the database */ + for (size_t i = 0; i < allNodes.size(); ++i) + { + allMessageIds.push_back(allNodes[i].get_id()); + co_await _database->AppendChatHistoryAsync( + callerId.userId, chatId, + std::move(allNodes[i]), + i == 0); } - userNode.set_children({static_cast(responseMessageId)}); - userNode.set_timestamp(static_cast(userMessageTimestamp)); - co_await _database->AppendChatHistoryAsync( - callerId.userId, chatId, - std::move(userNode)); - } - { - Schema::IServer::MessageNode responseNode{}; - responseNode.set_id(static_cast(responseMessageId)); - Schema::IServer::Message responseMessage{}; - responseMessage.set_role(MessageRoleType::ASSISTANT); - using MessageContentType = std::remove_reference::type; - MessageContentType responseContents{}; - decltype(responseContents)::value_type responseContent{}; - using ContentTypeType = std::remove_reference::type; - responseContent.set_type(ContentTypeType::TEXT); - responseContent.set_data(wholeResponse); - responseContents.push_back(std::move(responseContent)); - responseMessage.set_content(std::move(responseContents)); - responseNode.set_message(std::move(responseMessage)); - responseNode.set_parent(static_cast(userMessageId)); - responseNode.set_timestamp(static_cast(Common::Timestamp::GetWallClock())); - co_await _database->AppendChatHistoryAsync( - callerId.userId, chatId, - std::move(responseNode), - /** Skip parent update. As the user node is already written */ - false); } /** Return completion info */ Schema::IServer::ChatCompletionInfo completionInfo{}; - completionInfo.set_user_message_id(static_cast(userMessageId)); - completionInfo.set_assistant_message_id(static_cast(responseMessageId)); + completionInfo.set_message_ids(std::move(allMessageIds)); co_return static_cast(completionInfo); } @@ -649,22 +711,15 @@ JS::Promise Service::OnExecuteGenerationTaskAsync(CallerId calle auto params = ParseParams(paramsJson); Common::Uuid modelId{params.get_model_id()}; auto provider = GetProvider(modelId); - Schema::IServer::LinearHistory history{}; - history.push_back(std::move(params.get_message())); - auto requestData = provider->FormatRequest(history, false); + auto requestData = provider->FormatRequest(params.get_messages(), false, params.get_tools()); auto request = _httpClient->MakeRequest( Network::Http::Method::POST, requestData); auto response = co_await request.GetResponseAsync(); - auto content = provider->ParseResponse(response); - using ContentTypeType = std::remove_reference::type; - if (content.get_type() != ContentTypeType::TEXT) - { - throw Schema::Rpc::Exception( - Schema::Rpc::ErrorCode::BAD_GATEWAY, - content.get_data()); - } - co_return static_cast(content.get_data()); + auto result = provider->ParseResponse(response); + Schema::IServer::ExecuteGenerationTaskResult taskResult; + taskResult.set_messages(std::move(result)); + co_return static_cast(taskResult); } /** diff --git a/test/ServiceClient/ChatBot.js b/test/ServiceClient/ChatBot.js index b44b5a0..2805033 100644 --- a/test/ServiceClient/ChatBot.js +++ b/test/ServiceClient/ChatBot.js @@ -2,6 +2,7 @@ import { TUIClient } from "./TUIClient.js"; import fs from "fs"; import { exit } from "process"; import readline from 'readline'; +import { execSync } from 'child_process'; if (process.argv.length < 3) { console.error("Usage: node ChatBot.js "); @@ -31,45 +32,175 @@ const lineReader = readline.createInterface({ output: process.stdout }); +/** @type {string|undefined} */ let parentId = undefined; +/** @type {Array<{type: string, call_id: string, name: string, arguments: string}>} */ +let pendingFunctionCalls = []; + +const tools = [ + { + name: 'run_python', + description: 'Execute Python code and return stdout/stderr output. Use print() to produce output; bare expressions are not displayed. Use this to run calculations, data processing, or any Python code.', + parameters: { + type: 'object', + properties: { + code: { + type: 'string', + description: 'The Python code to execute' + } + }, + required: ['code'] + } + } +]; + +/** + * @param {string} prompt + * @returns {Promise} + */ +function askUser(prompt) { + return new Promise((resolve) => { + lineReader.question(prompt, (input) => { + resolve(input.trim().toLowerCase()); + }); + }); +} + +/** + * @param {string} code + * @returns {string} + */ +function runPython(code) { + try { + const output = execSync('python3', { + input: code, + encoding: 'utf-8', + timeout: 30000, + stdio: ['pipe', 'pipe', 'pipe'] + }); + return output; + } catch (/** @type {any} */ err) { + let result = ''; + if (err.stdout) result += err.stdout; + if (err.stderr) result += err.stderr; + return result || err.message; + } +} + +/** + * @param {{type: string, call_id: string, name: string, arguments: string}} call + * @returns {Promise<{call_id: string, type: string, output: Array<{type: string, data: string}>}>} + */ +async function executeFunctionCall(call) { + if (call.name === 'run_python') { + let args; + try { + args = JSON.parse(call.arguments); + } catch { + return { + call_id: call.call_id, + type: 'function_call_output', + output: [{ type: 'text', data: 'Error: Invalid JSON arguments' }] + }; + } + const code = args.code ?? ''; + console.log(`\n[Python Code]\n${code}`); + const answer = await askUser('\nExecute this code? (y/n): '); + if (answer !== 'y' && answer !== 'yes') { + console.log('[Execution denied by user]'); + return { + call_id: call.call_id, + type: 'function_call_output', + output: [{ type: 'text', data: 'Error: Execution denied by user' }] + }; + } + const output = runPython(code); + console.log(`[Output]\n${output}`); + return { + call_id: call.call_id, + type: 'function_call_output', + output: [{ type: 'text', data: output }] + }; + } + + console.log(`\n[Unknown Function Call] ${call.name}(${call.arguments})`); + return { + call_id: call.call_id, + type: 'function_call_output', + output: [{ type: 'text', data: `Error: Unknown function "${call.name}"` }] + }; +} + while(true) { - const userMessage = await new Promise((resolve) => { - lineReader.question('User: \n', (input) => { - resolve(input); - }) - }); - console.log("") + /** @type {Array} */ + let newMessages; - const stream = client.makeStreamRequestAsync('chatCompletion', { - id: chatId, - modelId: modelId, - parent: parentId, - userMessage: { + if (pendingFunctionCalls.length > 0) { + /** Process pending function calls and send results */ + newMessages = await Promise.all(pendingFunctionCalls.map(call => executeFunctionCall(call))); + pendingFunctionCalls = []; + } else { + /** Get user input */ + const userMessage = await new Promise((resolve) => { + lineReader.question('User: \n', (input) => { + resolve(input); + }) + }); + console.log(""); + newMessages = [{ role: 'user', content:[{ type: 'text', data: userMessage }] - } + }]; + } + + const stream = client.makeStreamRequestAsync('chatCompletion', { + id: chatId, + modelId: modelId, + parent: parentId, + messages: newMessages, + tools: tools }); console.log("Assistant: "); let result = undefined; + pendingFunctionCalls = []; + /** @type {{call_id?: string, name?: string}|undefined} */ + let currentCall = undefined; + while(!(result = await stream.next()).done){ const chunk = result.value; - process.stdout.write(chunk); + if (typeof chunk === 'string') { + process.stdout.write(chunk); + } else if (chunk && typeof chunk === 'object') { + if (chunk.event === 'function_call_start') { + currentCall = chunk.data ?? {}; + } else if (chunk.event === 'function_call_end') { + const callData = chunk.data ?? {}; + pendingFunctionCalls.push({ + type: 'function_call', + call_id: currentCall?.call_id ?? callData.call_id ?? '', + name: currentCall?.name ?? callData.name ?? '', + arguments: callData.arguments ?? '' + }); + currentCall = undefined; + } + } } console.log("\n"); - /** The chat info */ + if (result === undefined) { console.error("No result received from chat completion."); exit(1); } const info = result.value; - parentId = info.assistantMessageId; -} - - - + const messageIds = info.messageIds; + parentId = messageIds[messageIds.length - 1]; + if (pendingFunctionCalls.length > 0) { + console.log(`[${pendingFunctionCalls.length} function call(s) to process]`); + } +} diff --git a/test/TestDatabase.cpp b/test/TestDatabase.cpp index 4878ba2..b61e9f0 100644 --- a/test/TestDatabase.cpp +++ b/test/TestDatabase.cpp @@ -170,32 +170,29 @@ JS::Promise TestChatAsync() auto metadata = db->GetChatMetadata(userId, chatId); AssertWithMessage(metadata == "test-chat-metadata", "Chat metadata should match"); { - IServer::Message message0{}; - using MessageRoleType = std::remove_reference::type; - message0.set_role(MessageRoleType::USER); - using MessageContentType = std::remove_reference::type::value_type; - MessageContentType content0{}; - using MessageContentTypeType = std::remove_reference::type; - content0.set_type(MessageContentTypeType::TEXT); + IServer::ChatMessage chatMsg0{}; + chatMsg0.set_role(IServer::ChatMessageRole::USER); + IServer::MessageContent content0{}; + content0.set_type(IServer::MessageContentType::TEXT); content0.set_data("Hello, this is a test message."); - message0.get_mutable_content().push_back(std::move(content0)); + chatMsg0.set_content({std::move(content0)}); IServer::MessageNode node0{}; node0.set_id("node0"); node0.set_timestamp(1.0); - node0.set_message(std::move(message0)); + node0.set_message(std::move(chatMsg0)); co_await db->AppendChatHistoryAsync(userId, chatId, node0); IServer::MessageNode node1{}; node1.set_id("node1"); node1.set_parent("node0"); node1.set_timestamp(2.0); - IServer::Message message1{}; - message1.set_role(MessageRoleType::ASSISTANT); - MessageContentType content1{}; - content1.set_type(MessageContentTypeType::TEXT); + IServer::ChatMessage chatMsg1{}; + chatMsg1.set_role(IServer::ChatMessageRole::ASSISTANT); + IServer::MessageContent content1{}; + content1.set_type(IServer::MessageContentType::TEXT); content1.set_data("This is a response message."); - message1.get_mutable_content().push_back(std::move(content1)); - node1.set_message(std::move(message1)); + chatMsg1.set_content({std::move(content1)}); + node1.set_message(std::move(chatMsg1)); co_await db->AppendChatHistoryAsync(userId, chatId, node1); auto history = db->GetChatHistory(userId, chatId); @@ -205,9 +202,10 @@ JS::Promise TestChatAsync() AssertWithMessage(retrievedNode0It != nodes.end(), "Node0 should be found"); auto& retrievedNode0 = retrievedNode0It->second; AssertWithMessage(retrievedNode0.get_timestamp() == 1.0, "Node0 timestamp should match"); - AssertWithMessage(retrievedNode0.get_message().get_role() == MessageRoleType::USER, "Node0 role should match"); - AssertWithMessage(retrievedNode0.get_message().get_content().size() == 1, "Node0 content size should match"); - AssertWithMessage(retrievedNode0.get_message().get_content().front().get_data() == "Hello, this is a test message.", "Node0 content data should match"); + auto& msg0 = std::get(retrievedNode0.get_message()); + AssertWithMessage(msg0.get_role() == IServer::ChatMessageRole::USER, "Node0 role should match"); + AssertWithMessage(msg0.get_content().size() == 1, "Node0 content size should match"); + AssertWithMessage(msg0.get_content().front().get_data() == "Hello, this is a test message.", "Node0 content data should match"); AssertWithMessage(retrievedNode0.get_parent().has_value() == false, "Node0 parent should be null"); AssertWithMessage(retrievedNode0.get_children().size() == 1, "Node0 should have 1 child"); AssertWithMessage(retrievedNode0.get_children().front() == "node1", "Node0 child ID should match"); @@ -215,9 +213,10 @@ JS::Promise TestChatAsync() AssertWithMessage(retrievedNode1It != nodes.end(), "Node1 should be found"); auto& retrievedNode1 = retrievedNode1It->second; AssertWithMessage(retrievedNode1.get_timestamp() == 2.0, "Node1 timestamp should match"); - AssertWithMessage(retrievedNode1.get_message().get_role() == MessageRoleType::ASSISTANT, "Node1 role should match"); - AssertWithMessage(retrievedNode1.get_message().get_content().size() == 1, "Node1 content size should match"); - AssertWithMessage(retrievedNode1.get_message().get_content().front().get_data() == "This is a response message.", "Node1 content data should match"); + auto& msg1 = std::get(retrievedNode1.get_message()); + AssertWithMessage(msg1.get_role() == IServer::ChatMessageRole::ASSISTANT, "Node1 role should match"); + AssertWithMessage(msg1.get_content().size() == 1, "Node1 content size should match"); + AssertWithMessage(msg1.get_content().front().get_data() == "This is a response message.", "Node1 content data should match"); AssertWithMessage(retrievedNode1.get_parent().has_value() == true, "Node1 parent should not be null"); AssertWithMessage(retrievedNode1.get_parent().value() == "node0", "Node1 parent ID should match"); AssertWithMessage(retrievedNode1.get_children().empty(), "Node1 should have no children"); From 3aaae0c4381a49f1a276f55a4eaa03eb29aedc3e Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 15 Mar 2026 18:21:52 +0800 Subject: [PATCH 08/11] update types --- src/schema/src/types | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schema/src/types b/src/schema/src/types index 2f92040..aa233cb 160000 --- a/src/schema/src/types +++ b/src/schema/src/types @@ -1 +1 @@ -Subproject commit 2f9204085a9b0307c1593d39bf4489ec9314b940 +Subproject commit aa233cb6777dc9936c8e376dc8187822aa8cc4df From 6da2db33627981539a1b96772163ae1313c93e16 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sat, 21 Mar 2026 17:49:10 +0800 Subject: [PATCH 09/11] fix ai slop --- src/apiProvider/OpenAI.cpp | 44 ++++++++++++++++++------------------- src/application/Service.cpp | 20 ++++------------- 2 files changed, 26 insertions(+), 38 deletions(-) diff --git a/src/apiProvider/OpenAI.cpp b/src/apiProvider/OpenAI.cpp index 5adf076..15a4682 100644 --- a/src/apiProvider/OpenAI.cpp +++ b/src/apiProvider/OpenAI.cpp @@ -161,10 +161,6 @@ RequestData OpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, { const auto& functionCallOutput = std::get(message); auto functionCallOutputJson = nlohmann::json::object(); - if (functionCallOutput.get_extra().has_value()) - { - functionCallOutputJson = functionCallOutput.get_extra().value(); - } functionCallOutputJson["type"] = "function_call_output"; functionCallOutputJson["call_id"] = functionCallOutput.get_call_id(); std::string outputText; @@ -254,7 +250,10 @@ Schema::IServer::LinearHistory OpenAI::ParseResponse(const std::string& response functionCall.set_call_id(item.at("call_id").get()); functionCall.set_name(item.at("name").get()); functionCall.set_arguments(item.at("arguments").get()); - functionCall.set_extra(std::make_optional(item)); + nlohmann::json extra; + if (item.contains("id")) extra["id"] = item["id"]; + if (item.contains("status")) extra["status"] = item["status"]; + if (!extra.empty()) functionCall.set_extra(std::make_optional(extra)); results.push_back(std::move(functionCall)); } } @@ -308,15 +307,8 @@ std::optional OpenAI::ParseStreamRespons if (json.contains("item") && json.at("item").contains("type") && json.at("item").at("type").get() == "function_call") { - const auto& item = json.at("item"); - Schema::IServer::FunctionCallMessage functionCall; - functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); - functionCall.set_call_id(item.at("call_id").get()); - functionCall.set_name(item.at("name").get()); - functionCall.set_arguments(item.value("arguments", "")); Schema::IServer::ChatCompletionSegmentClass segment; segment.set_event(Schema::IServer::Event::FUNCTION_CALL_START); - segment.set_data(std::move(functionCall)); return segment; } } @@ -325,20 +317,28 @@ std::optional OpenAI::ParseStreamRespons return std::nullopt; } } - else if (eventType == "response.function_call_arguments.done") + else if (eventType == "response.output_item.done") { try { auto json = nlohmann::json::parse(valueString); - Schema::IServer::FunctionCallMessage functionCall; - functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); - functionCall.set_call_id(json.value("call_id", "")); - functionCall.set_name(json.value("name", "")); - functionCall.set_arguments(json.value("arguments", "")); - Schema::IServer::ChatCompletionSegmentClass segment; - segment.set_event(Schema::IServer::Event::FUNCTION_CALL_END); - segment.set_data(std::move(functionCall)); - return segment; + if (json.contains("item") && json.at("item").contains("type") + && json.at("item").at("type").get() == "function_call") + { + Schema::IServer::ChatCompletionSegmentClass segment; + segment.set_event(Schema::IServer::Event::FUNCTION_CALL_END); + Schema::IServer::FunctionCallMessage functionCall; + functionCall.set_type(Schema::IServer::FunctionCallMessageType::FUNCTION_CALL); + const auto& item = json.at("item"); + functionCall.set_call_id(item.at("call_id").get()); + functionCall.set_name(item.at("name").get()); + functionCall.set_arguments(item.value("arguments", "")); + nlohmann::json extra; + extra["status"] = "completed"; + functionCall.set_extra(std::make_optional(extra)); + segment.set_data(std::move(functionCall)); + return segment; + } } catch (...) { diff --git a/src/application/Service.cpp b/src/application/Service.cpp index 5b165de..ef9334c 100644 --- a/src/application/Service.cpp +++ b/src/application/Service.cpp @@ -551,8 +551,6 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn auto streamBatcher = Common::StreamBatcher::BatchStream(_tev, std::move(eventStream), STREAM_BATCHING_INTERVAL_MS); std::string currentText{}; - Schema::IServer::FunctionCallMessage pendingCall{}; - bool hasPendingCall = false; auto flushCurrentText = [&]() { if (currentText.empty()) return; @@ -601,25 +599,15 @@ JS::AsyncGenerator Service::OnChatCompletionAsyn { /** Flush any accumulated text as an assistant message */ flushCurrentText(); - auto data = segClass.get_data(); - if (data.has_value()) - { - pendingCall = std::move(data.value()); - } - hasPendingCall = true; } else if (segClass.get_event() == Schema::IServer::Event::FUNCTION_CALL_END) { - if (hasPendingCall) + auto data = segClass.get_data(); + if (!data.has_value()) { - auto data = segClass.get_data(); - if (data.has_value()) - { - pendingCall.set_arguments(data.value().get_arguments()); - } - responseMessages.push_back(std::move(pendingCall)); - hasPendingCall = false; + throw Schema::Rpc::Exception(Schema::Rpc::ErrorCode::INTERNAL_SERVER_ERROR, "Function call segment missing data"); } + responseMessages.push_back(data.value()); } co_yield static_cast( Schema::IServer::ChatCompletionSegment{std::move(segClass)}); From 9813d3c55c305fc9c988b240fe0cecb11fa369b7 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 3 May 2026 11:32:06 +0800 Subject: [PATCH 10/11] Fix file path resolution --- src/application/Service.cpp | 1 + src/database/Database.cpp | 28 +++++++++++++++++++++++----- src/database/Database.h | 3 +++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/application/Service.cpp b/src/application/Service.cpp index ef9334c..4ee7c04 100644 --- a/src/application/Service.cpp +++ b/src/application/Service.cpp @@ -1086,6 +1086,7 @@ JS::Promise Service::OnListFileAsync(CallerId callerId, nlohmann using EntryType = decltype(result)::value_type; EntryType entry{}; entry.set_file_id(static_cast(item.fileId)); + entry.set_content_id(item.contentId); nlohmann::json metadata = nlohmann::json::parse(item.metadata); entry.set_file_metadata(std::move(metadata)); result.push_back(std::move(entry)); diff --git a/src/database/Database.cpp b/src/database/Database.cpp index 0dc11e4..3d7645c 100644 --- a/src/database/Database.cpp +++ b/src/database/Database.cpp @@ -16,7 +16,6 @@ JS::Promise> Database::CreateAsync( { auto db = std::shared_ptr(new Database()); - db->_fileDirectory = fileDirectory; if (!std::filesystem::exists(fileDirectory)) { std::filesystem::create_directories(fileDirectory); @@ -28,6 +27,7 @@ JS::Promise> Database::CreateAsync( throw std::runtime_error("File directory path exists but is not a directory"); } } + db->_fileDirectory = std::filesystem::canonical(fileDirectory); db->_db = co_await Sqlite::CreateAsync(tev, dbPath); /** Create tables */ @@ -646,6 +646,24 @@ std::string Database::GetStringFromChat( } } +std::filesystem::path Database::ResolveUserFilePath( + const Common::Uuid& userId, const std::string& contentId) const +{ + if (contentId.empty()) + { + throw std::runtime_error("Invalid contentId: empty"); + } + auto userDirectory = _fileDirectory / static_cast(userId); + auto filePath = userDirectory / contentId; + auto canonicalUserDir = std::filesystem::weakly_canonical(userDirectory); + auto canonicalFilePath = std::filesystem::weakly_canonical(filePath); + if (canonicalFilePath.parent_path() != canonicalUserDir) + { + throw std::runtime_error("Invalid contentId: path traversal"); + } + return canonicalFilePath; +} + JS::Promise Database::SaveFileAsync( const Common::Uuid& userId, std::string metadata, std::vector content) { @@ -658,7 +676,7 @@ JS::Promise Database::SaveFileAsync( { std::filesystem::create_directories(userDirectory); } - auto filePath = userDirectory / contentId; + auto filePath = ResolveUserFilePath(userId, contentId); if (!std::filesystem::exists(filePath)) { /** @todo: Make this async */ @@ -718,7 +736,7 @@ JS::Promise Database::DeleteFileAsync( { co_return; } - auto filePath = _fileDirectory / static_cast(userId) / contentId; + auto filePath = ResolveUserFilePath(userId, contentId); if (std::filesystem::exists(filePath)) { std::filesystem::remove(filePath); @@ -761,8 +779,8 @@ Database::FileMeta Database::GetFileMeta( std::vector Database::GetFileContent( const Common::Uuid& userId, const std::string& contentId) { - auto filePath = _fileDirectory / static_cast(userId) / contentId; - if (!std::filesystem::exists(filePath)) + auto filePath = ResolveUserFilePath(userId, contentId); + if (!std::filesystem::is_regular_file(filePath)) { throw std::runtime_error("File content not found"); } diff --git a/src/database/Database.h b/src/database/Database.h index 59a33cc..2876ea4 100644 --- a/src/database/Database.h +++ b/src/database/Database.h @@ -110,6 +110,9 @@ namespace TUI::Database private: Database() = default; + std::filesystem::path ResolveUserFilePath( + const Common::Uuid& userId, const std::string& contentId) const; + std::list ParseListTableIdWithMetadataResult( Sqlite::ExecResult& result); JS::Promise SetStringToTableById( From af19c0150a43b564e59faf62d813c953be341c82 Mon Sep 17 00:00:00 2001 From: Feng Wang Date: Sun, 3 May 2026 14:02:44 +0800 Subject: [PATCH 11/11] add missing dependency in pipeline Co-authored-by: Copilot --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62385bc..310570e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,8 @@ jobs: nlohmann-json3-dev \ libssl-dev \ valgrind \ - libzstd-dev + libzstd-dev \ + libxxhash-dev - name: Install libsodium (>=1.0.19) from source run: |