Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/iceberg/catalog/rest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ set(ICEBERG_REST_SOURCES
auth/auth_properties.cc
auth/auth_session.cc
auth/oauth2_util.cc
auth/token_refresh_scheduler.cc
catalog_properties.cc
endpoint.cc
error_handlers.cc
Expand Down
194 changes: 188 additions & 6 deletions src/iceberg/catalog/rest/auth/auth_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@

#include "iceberg/catalog/rest/auth/auth_session.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <shared_mutex>
#include <thread>
#include <utility>

#include "iceberg/catalog/rest/auth/auth_properties.h"
#include "iceberg/catalog/rest/auth/oauth2_util.h"
#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h"
#include "iceberg/catalog/rest/http_client.h"

namespace iceberg::rest::auth {

Expand All @@ -44,6 +52,175 @@ class DefaultAuthSession : public AuthSession {
std::unordered_map<std::string, std::string> headers_;
};

/// \brief OAuth2 session with automatic token refresh.
class OAuth2AuthSession : public AuthSession,
public std::enable_shared_from_this<OAuth2AuthSession> {
public:
struct Config {
std::string token_endpoint;
std::string client_id;
std::string client_secret;
std::string scope;
bool keep_refreshed;
};

/// \brief Create an OAuth2 session and optionally schedule refresh.
static std::shared_ptr<OAuth2AuthSession> Create(
const OAuthTokenResponse& initial_token, Config config, HttpClient& client) {
auto session = std::shared_ptr<OAuth2AuthSession>(
new OAuth2AuthSession(std::move(config), client));
session->SetInitialToken(initial_token);
return session;
}

Status Authenticate(std::unordered_map<std::string, std::string>& headers) override {
std::shared_lock lock(mutex_);
for (const auto& [key, value] : headers_) {
headers.insert_or_assign(key, value);
}
return {};
}

Status Close() override {
bool expected = false;
if (!closed_.compare_exchange_strong(expected, true)) {
return {}; // Already closed
}
TokenRefreshScheduler::Instance().Cancel(scheduled_task_id_.load());
return {};
}

private:
OAuth2AuthSession(Config config, HttpClient& client)
: config_(std::move(config)), client_(client) {}

void SetInitialToken(const OAuthTokenResponse& token_response) {
token_ = token_response.access_token;
headers_ = {{std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}};

// Determine expiration time
if (token_response.expires_in_secs.has_value()) {
expires_at_ = std::chrono::steady_clock::now() +
std::chrono::seconds(*token_response.expires_in_secs);
} else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) {
// Convert absolute epoch millis to steady_clock time_point
auto now_sys = std::chrono::system_clock::now();
auto now_steady = std::chrono::steady_clock::now();
auto exp_sys =
std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms));
expires_at_ = now_steady + (exp_sys - now_sys);
}

if (config_.keep_refreshed &&
expires_at_ != std::chrono::steady_clock::time_point{}) {
ScheduleRefresh();
}
}

void DoRefresh() {
if (closed_.load()) return;

constexpr int kMaxRetries = 5;
constexpr auto kInitialBackoff = std::chrono::milliseconds(200);
constexpr auto kMaxBackoff = std::chrono::seconds(10);

// Build credential string for FetchToken
std::string credential = config_.client_id.empty()
? config_.client_secret
: config_.client_id + ":" + config_.client_secret;

auto backoff = kInitialBackoff;
for (int attempt = 0; attempt < kMaxRetries; ++attempt) {
if (closed_.load()) return;

// Use an empty session for the refresh request (no auth headers)
auto empty_session = AuthSession::MakeDefault({});

// Build properties for FetchToken
AuthProperties props;
props.Set(AuthProperties::kCredential, credential);
props.Set(AuthProperties::kScope, config_.scope);
props.Set(AuthProperties::kOAuth2ServerUri, config_.token_endpoint);

auto result = FetchToken(client_, *empty_session, props);
if (result.has_value()) {
auto& response = result.value();
{
std::unique_lock lock(mutex_);
token_ = response.access_token;
headers_ = {
{std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}};

// Update expiration
if (response.expires_in_secs.has_value()) {
expires_at_ = std::chrono::steady_clock::now() +
std::chrono::seconds(*response.expires_in_secs);
} else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) {
auto now_sys = std::chrono::system_clock::now();
auto now_steady = std::chrono::steady_clock::now();
auto exp_sys =
std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms));
expires_at_ = now_steady + (exp_sys - now_sys);
}
}
// Schedule next refresh
ScheduleRefresh();
return; // Success
}

// Retry with exponential backoff
if (attempt < kMaxRetries - 1) {
std::this_thread::sleep_for(backoff);
backoff =
std::min(std::chrono::duration_cast<std::chrono::milliseconds>(backoff * 2),
std::chrono::duration_cast<std::chrono::milliseconds>(kMaxBackoff));
}
}

// All retries failed — stop refreshing silently
// Next request will use the expired token; server returns 401
}

void ScheduleRefresh() {
if (!config_.keep_refreshed || closed_.load()) return;

auto delay = CalculateRefreshDelay();
if (delay <= std::chrono::milliseconds::zero()) return;

std::weak_ptr<OAuth2AuthSession> weak_self = shared_from_this();
auto new_id = TokenRefreshScheduler::Instance().Schedule(
delay, [weak_self = std::move(weak_self)] {
if (auto self = weak_self.lock()) {
self->DoRefresh();
}
});
scheduled_task_id_.store(new_id);
}

std::chrono::milliseconds CalculateRefreshDelay() const {
std::shared_lock lock(mutex_);
auto now = std::chrono::steady_clock::now();
if (expires_at_ <= now) return std::chrono::milliseconds::zero();

auto expires_in =
std::chrono::duration_cast<std::chrono::milliseconds>(expires_at_ - now);
// Refresh window: 10% of remaining time, capped at 5 minutes
auto refresh_window = std::min(expires_in / 10, std::chrono::milliseconds(300'000));
auto wait_time = expires_in - refresh_window;
return std::max(wait_time, std::chrono::milliseconds(10));
}

mutable std::shared_mutex mutex_; // protects token_, headers_, expires_at_
std::string token_;
std::unordered_map<std::string, std::string> headers_;
std::chrono::steady_clock::time_point expires_at_{};

Config config_;
HttpClient& client_;
std::atomic<uint64_t> scheduled_task_id_{0};
std::atomic<bool> closed_{false};
};

} // namespace

std::shared_ptr<AuthSession> AuthSession::MakeDefault(
Expand All @@ -52,12 +229,17 @@ std::shared_ptr<AuthSession> AuthSession::MakeDefault(
}

std::shared_ptr<AuthSession> AuthSession::MakeOAuth2(
const OAuthTokenResponse& initial_token, const std::string& /*token_endpoint*/,
const std::string& /*client_id*/, const std::string& /*client_secret*/,
const std::string& /*scope*/, HttpClient& /*client*/) {
// TODO(lishuxu): Create OAuth2AuthSession with auto-refresh support.
return MakeDefault({{std::string(kAuthorizationHeader),
std::string(kBearerPrefix) + initial_token.access_token}});
const OAuthTokenResponse& initial_token, const std::string& token_endpoint,
const std::string& client_id, const std::string& client_secret,
const std::string& scope, HttpClient& client) {
OAuth2AuthSession::Config config{
.token_endpoint = token_endpoint,
.client_id = client_id,
.client_secret = client_secret,
.scope = scope,
.keep_refreshed = true,
};
return OAuth2AuthSession::Create(initial_token, std::move(config), client);
}

} // namespace iceberg::rest::auth
45 changes: 45 additions & 0 deletions src/iceberg/catalog/rest/auth/oauth2_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "iceberg/catalog/rest/json_serde_internal.h"
#include "iceberg/json_serde_internal.h"
#include "iceberg/util/macros.h"
#include "iceberg/util/transform_util.h"

namespace iceberg::rest::auth {

Expand Down Expand Up @@ -74,4 +75,48 @@ Result<OAuthTokenResponse> FetchToken(HttpClient& client, AuthSession& session,
return token_response;
}

std::optional<int64_t> ExpiresAtMillis(const std::string& token) {
if (token.empty()) {
return std::nullopt;
}

// A JWT has exactly 3 dot-separated parts: header.payload.signature
auto first_dot = token.find('.');
if (first_dot == std::string::npos) {
return std::nullopt;
}
auto second_dot = token.find('.', first_dot + 1);
if (second_dot == std::string::npos) {
return std::nullopt;
}
// Ensure there's no third dot (exactly 3 parts)
if (token.find('.', second_dot + 1) != std::string::npos) {
return std::nullopt;
}

// Extract and decode the payload (second part)
std::string_view payload_b64 =
std::string_view(token).substr(first_dot + 1, second_dot - first_dot - 1);
std::string payload = TransformUtil::Base64UrlDecode(payload_b64);
if (payload.empty()) {
return std::nullopt;
}

// Parse JSON and extract "exp" claim
try {
auto json = nlohmann::json::parse(payload, nullptr, false);
if (json.is_discarded() || !json.is_object()) {
return std::nullopt;
}
auto it = json.find("exp");
if (it == json.end() || !it->is_number_integer()) {
return std::nullopt;
}
int64_t exp_seconds = it->get<int64_t>();
return exp_seconds * 1000; // Convert seconds to milliseconds
} catch (...) {
return std::nullopt;
}
}

} // namespace iceberg::rest::auth
12 changes: 12 additions & 0 deletions src/iceberg/catalog/rest/auth/oauth2_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#pragma once

#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
Expand Down Expand Up @@ -53,4 +55,14 @@ ICEBERG_REST_EXPORT Result<OAuthTokenResponse> FetchToken(
ICEBERG_REST_EXPORT std::unordered_map<std::string, std::string> AuthHeaders(
const std::string& token);

/// \brief Extract expiration time from a JWT token.
///
/// Decodes the JWT payload (base64url) and reads the "exp" claim.
/// Returns std::nullopt if the token is not a valid JWT or has no "exp" claim.
///
/// \param token A token string. If it is a JWT (three dot-separated base64url
/// segments), the "exp" claim is extracted from the payload.
/// \return Expiration time as milliseconds since epoch, or std::nullopt.
ICEBERG_REST_EXPORT std::optional<int64_t> ExpiresAtMillis(const std::string& token);

} // namespace iceberg::rest::auth
2 changes: 2 additions & 0 deletions src/iceberg/catalog/rest/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iceberg_rest_sources = files(
'auth/auth_properties.cc',
'auth/auth_session.cc',
'auth/oauth2_util.cc',
'auth/token_refresh_scheduler.cc',
'catalog_properties.cc',
'endpoint.cc',
'error_handlers.cc',
Expand Down Expand Up @@ -87,6 +88,7 @@ install_headers(
'auth/auth_properties.h',
'auth/auth_session.h',
'auth/oauth2_util.h',
'auth/token_refresh_scheduler.h',
],
subdir: 'iceberg/catalog/rest/auth',
)
Loading
Loading