Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ env:
TestChaCha20Poly1305
TestCounter
TestCryptoKdfHkdfSha256
TestDatabase /tmp/tui-test.db
TestDatabase /tmp/tui-test.db /tmp/files
TestEcdhePsk
TestEd25519
TestFakeCredentialGenerator
Expand Down Expand Up @@ -61,7 +61,8 @@ jobs:
nlohmann-json3-dev \
libssl-dev \
valgrind \
libzstd-dev
libzstd-dev \
libxxhash-dev

- name: Install libsodium (>=1.0.19) from source
run: |
Expand Down
6 changes: 4 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ target_link_libraries(tui-server
sqlite3
uuid
sodium
zstd)
zstd
xxhash)

add_executable(tui-register
${CMAKE_CURRENT_SOURCE_DIR}/src/tui-register.cpp
Expand All @@ -66,7 +67,8 @@ target_include_directories(tui-register
target_link_libraries(tui-register
PRIVATE
sqlite3
uuid)
uuid
xxhash)

install(TARGETS tui-server RUNTIME DESTINATION bin)
install(TARGETS tui-register RUNTIME DESTINATION bin)
97 changes: 53 additions & 44 deletions src/apiProvider/AzureOpenAI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void AzureOpenAI::Initialize(const nlohmann::json& params)
_params = ParamsDefinition.Parse(params);
}

RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const
RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& history, bool stream, const std::optional<std::vector<Schema::IServer::Tool>>& /*tools*/) const
{
RequestData data{};
data.url = _params.url;
Expand All @@ -67,55 +67,64 @@ RequestData AzureOpenAI::FormatRequest(const Schema::IServer::LinearHistory& his
}
for (const auto& message : history)
{
auto messageJson = nlohmann::json::object();
messageJson["content"] = nlohmann::json::array();
if (message.get_role() == Schema::IServer::MessageRole::DEVELOPER)
if (std::holds_alternative<Schema::IServer::ChatMessage>(message))
{
messageJson["role"] = "system";
}
else if(message.get_role() == Schema::IServer::MessageRole::USER)
{
messageJson["role"] = "user";
}
else if(message.get_role() == Schema::IServer::MessageRole::ASSISTANT)
{
messageJson["role"] = "assistant";
}
else
{
continue;
}
for (const auto& content : message.get_content())
{
if (content.get_type() == Schema::IServer::Type::TEXT || content.get_type() == Schema::IServer::Type::REFUSAL)
const auto& chatMessage = std::get<Schema::IServer::ChatMessage>(message);
auto messageJson = nlohmann::json::object();
messageJson["content"] = nlohmann::json::array();
if (chatMessage.get_role() == Schema::IServer::ChatMessageRole::DEVELOPER)
{
messageJson["role"] = "system";
}
else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::USER)
{
/** Azure open ai does not seem to have a special REFUSAL message type */
messageJson["content"].push_back({
{"type", "text"},
{"text", content.get_data()}
});
messageJson["role"] = "user";
}
else if (content.get_type() == Schema::IServer::Type::IMAGE_URL)
else if(chatMessage.get_role() == Schema::IServer::ChatMessageRole::ASSISTANT)
{
messageJson["content"].push_back({
{"type", "image_url"},
{"image_url", {
{"url", content.get_data()}
}}
});
messageJson["role"] = "assistant";
}
else
{
continue;
}
for (const auto& content : chatMessage.get_content())
{
if (content.get_type() == Schema::IServer::MessageContentType::TEXT || content.get_type() == Schema::IServer::MessageContentType::REFUSAL)
{
/** Azure open ai does not seem to have a special REFUSAL message type */
messageJson["content"].push_back({
{"type", "text"},
{"text", content.get_data()}
});
}
else if (content.get_type() == Schema::IServer::MessageContentType::IMAGE_URL)
{
messageJson["content"].push_back({
{"type", "image_url"},
{"image_url", {
{"url", content.get_data()}
}}
});
}
else
{
continue;
}
}
body["messages"].push_back(messageJson);
}
else if (std::holds_alternative<Schema::IServer::FunctionCallMessage>(message)
|| std::holds_alternative<Schema::IServer::FunctionCallOutputMessage>(message))
{
throw std::runtime_error("AzureOpenAI does not support function call messages");
}
body["messages"].push_back(messageJson);
}
data.body = body.dump();
return data;
}

Schema::IServer::MessageContent AzureOpenAI::ParseResponse(const std::string& responseString) const
Schema::IServer::LinearHistory AzureOpenAI::ParseResponse(const std::string& responseString) const
{
Schema::AzureOpenAI::BulkResponse response;
try
Expand All @@ -131,14 +140,17 @@ Schema::IServer::MessageContent AzureOpenAI::ParseResponse(const std::string& re
throw std::runtime_error("No choices in response");
}
const auto& choice = response.get_choices().front();
const auto& message = choice.get_message();
const auto& msg = choice.get_message();
Schema::IServer::MessageContent content;
content.set_type(message.get_refusal() ? Schema::IServer::Type::REFUSAL : Schema::IServer::Type::TEXT);
content.set_data(message.get_content());
return content;
content.set_type(msg.get_refusal() ? Schema::IServer::MessageContentType::REFUSAL : Schema::IServer::MessageContentType::TEXT);
content.set_data(msg.get_content());
Schema::IServer::ChatMessage chatMessage;
chatMessage.set_role(Schema::IServer::ChatMessageRole::ASSISTANT);
chatMessage.set_content({content});
return Schema::IServer::LinearHistory{std::move(chatMessage)};
}

std::optional<Schema::IServer::MessageContent> AzureOpenAI::ParseStreamResponse(const StreamResponse::Event& event) const
std::optional<Schema::IServer::ChatCompletionSegment> AzureOpenAI::ParseStreamResponse(const StreamResponse::Event& event) const
{
if (!event.value.has_value())
{
Expand Down Expand Up @@ -172,8 +184,5 @@ std::optional<Schema::IServer::MessageContent> AzureOpenAI::ParseStreamResponse(
{
return std::nullopt;
}
Schema::IServer::MessageContent content{};
content.set_type(Schema::IServer::Type::TEXT);
content.set_data(delta.get_content().value());
return content;
return delta.get_content().value();
}
9 changes: 6 additions & 3 deletions src/apiProvider/AzureOpenAI.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ namespace TUI::ApiProvider
~AzureOpenAI() override = default;
nlohmann::json GetParams() const override;
void Initialize(const nlohmann::json& params) override;
Network::Http::RequestData FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const override;
Schema::IServer::MessageContent ParseResponse(const std::string& response) const override;
std::optional<Schema::IServer::MessageContent> ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override;
Network::Http::RequestData FormatRequest(
const Schema::IServer::LinearHistory& history,
bool stream,
const std::optional<std::vector<Schema::IServer::Tool>>& tools = std::nullopt) const override;
Schema::IServer::LinearHistory ParseResponse(const std::string& response) const override;
std::optional<Schema::IServer::ChatCompletionSegment> ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const override;
private:
struct Params
{
Expand Down
9 changes: 6 additions & 3 deletions src/apiProvider/IProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ namespace TUI::ApiProvider
virtual ~IProvider() = default;
virtual nlohmann::json GetParams() const = 0;
virtual void Initialize(const nlohmann::json& params) = 0;
virtual Network::Http::RequestData FormatRequest(const Schema::IServer::LinearHistory& history, bool stream) const = 0;
virtual Schema::IServer::MessageContent ParseResponse(const std::string& response) const = 0;
virtual std::optional<Schema::IServer::MessageContent> ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const = 0;
virtual Network::Http::RequestData FormatRequest(
const Schema::IServer::LinearHistory& history,
bool stream,
const std::optional<std::vector<Schema::IServer::Tool>>& tools = std::nullopt) const = 0;
virtual Schema::IServer::LinearHistory ParseResponse(const std::string& response) const = 0;
virtual std::optional<Schema::IServer::ChatCompletionSegment> ParseStreamResponse(const Network::Http::StreamResponse::Event& event) const = 0;
};
}
Loading
Loading