Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 164 additions & 18 deletions src/core/ggml_extend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,12 @@ struct GGMLRunner {
ggml_backend_buffer_t partial_runtime_params_buffer = nullptr;
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> partial_offload_pairs;

// Next segment's params prefetched during current segment's compute.
ggml_context* pending_offload_ctx = nullptr;
ggml_backend_buffer_t pending_runtime_params_buffer = nullptr;
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> pending_offload_pairs;
uint64_t pending_param_signature = 0;

// Params kept on the runtime backend across streaming segments.
ggml_context* resident_offload_ctx = nullptr;
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> resident_offload_pairs;
Expand Down Expand Up @@ -2159,36 +2165,66 @@ struct GGMLRunner {
return true;
}

bool offload_partial_params(const std::vector<ggml_tensor*>& tensors) {
restore_partial_params();
if (params_backend == runtime_backend) {
return true;
}
if (tensors.empty()) {
return true;
static uint64_t param_signature(const std::vector<ggml_tensor*>& tensors) {
uint64_t h = 0;
for (ggml_tensor* t : tensors) {
h ^= reinterpret_cast<uintptr_t>(t) * 0x9E3779B97F4A7C15ull;
}
GGML_ASSERT(!params_on_runtime_backend);
GGML_ASSERT(partial_runtime_params_buffer == nullptr);
return h;
}

std::vector<ggml_tensor*> unique_tensors;
std::unordered_set<ggml_tensor*> seen_tensors;
void dedup_runtime_params(const std::vector<ggml_tensor*>& tensors,
std::vector<ggml_tensor*>& unique_tensors) {
std::unordered_set<ggml_tensor*> seen;
unique_tensors.reserve(tensors.size());
seen_tensors.reserve(tensors.size());
seen.reserve(tensors.size());
for (ggml_tensor* tensor : tensors) {
if (tensor == nullptr) {
continue;
}
if (resident_param_set.find(tensor) != resident_param_set.end()) {
continue;
}
if (seen_tensors.insert(tensor).second) {
if (seen.insert(tensor).second) {
unique_tensors.push_back(tensor);
}
}
}

bool offload_partial_params(const std::vector<ggml_tensor*>& tensors) {
if (params_backend == runtime_backend) {
restore_pending_params();
restore_partial_params();
return true;
}
if (tensors.empty()) {
restore_pending_params();
restore_partial_params();
return true;
}

std::vector<ggml_tensor*> unique_tensors;
dedup_runtime_params(tensors, unique_tensors);
if (unique_tensors.empty()) {
restore_pending_params();
restore_partial_params();
return true;
}

// Fast path: if the prefetch already loaded these exact params, just
// swap the original tensors onto the pending buffer (no extra H2D).
if (pending_runtime_params_buffer != nullptr &&
pending_param_signature == param_signature(unique_tensors)) {
restore_partial_params();
promote_pending_to_partial();
return true;
}

restore_pending_params();
restore_partial_params();
GGML_ASSERT(!params_on_runtime_backend);
GGML_ASSERT(partial_runtime_params_buffer == nullptr);

ggml_init_params params;
params.mem_size = std::max<size_t>(1, unique_tensors.size()) * ggml_tensor_overhead();
params.mem_buffer = nullptr;
Expand Down Expand Up @@ -2303,6 +2339,95 @@ struct GGMLRunner {
}
}

bool offload_pending_params(const std::vector<ggml_tensor*>& tensors) {
restore_pending_params();
if (params_backend == runtime_backend) {
return true;
}
if (tensors.empty()) {
return true;
}

std::vector<ggml_tensor*> unique_tensors;
dedup_runtime_params(tensors, unique_tensors);
if (unique_tensors.empty()) {
return true;
}

ggml_init_params params;
params.mem_size = std::max<size_t>(1, unique_tensors.size()) * ggml_tensor_overhead();
params.mem_buffer = nullptr;
params.no_alloc = true;

pending_offload_ctx = ggml_init(params);
GGML_ASSERT(pending_offload_ctx != nullptr);
pending_offload_pairs.reserve(unique_tensors.size());

for (ggml_tensor* tensor : unique_tensors) {
GGML_ASSERT(tensor->view_src == nullptr);
ggml_tensor* offload_tensor = ggml_dup_tensor(pending_offload_ctx, tensor);
ggml_set_name(offload_tensor, tensor->name);
pending_offload_pairs.push_back({tensor, offload_tensor});
}

pending_runtime_params_buffer = ggml_backend_alloc_ctx_tensors(pending_offload_ctx, runtime_backend);
if (pending_runtime_params_buffer == nullptr) {
LOG_DEBUG("%s alloc pending runtime params backend buffer failed, num_tensors = %zu",
get_desc().c_str(),
pending_offload_pairs.size());
ggml_free(pending_offload_ctx);
pending_offload_ctx = nullptr;
pending_offload_pairs.clear();
return false;
}
ggml_backend_buffer_set_usage(pending_runtime_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);

// Original tensors stay pointed at the partial buffer until promote.
for (auto& pair : pending_offload_pairs) {
ggml_backend_tensor_copy(pair.first, pair.second);
}

pending_param_signature = param_signature(unique_tensors);
return true;
}

void restore_pending_params() {
pending_offload_pairs.clear();
if (pending_runtime_params_buffer != nullptr) {
ggml_backend_buffer_free(pending_runtime_params_buffer);
pending_runtime_params_buffer = nullptr;
}
if (pending_offload_ctx != nullptr) {
ggml_free(pending_offload_ctx);
pending_offload_ctx = nullptr;
}
pending_param_signature = 0;
}

// Caller must have already restore_partial_params()ed.
void promote_pending_to_partial() {
GGML_ASSERT(partial_runtime_params_buffer == nullptr);
GGML_ASSERT(partial_offload_ctx == nullptr);
GGML_ASSERT(partial_offload_pairs.empty());

for (auto& pair : pending_offload_pairs) {
ggml_tensor* tensor = pair.first;
ggml_tensor* offload_tensor = pair.second;
std::swap(tensor->buffer, offload_tensor->buffer);
std::swap(tensor->data, offload_tensor->data);
std::swap(tensor->extra, offload_tensor->extra);
}

partial_offload_ctx = pending_offload_ctx;
partial_runtime_params_buffer = pending_runtime_params_buffer;
partial_offload_pairs = std::move(pending_offload_pairs);

pending_offload_ctx = nullptr;
pending_runtime_params_buffer = nullptr;
pending_offload_pairs.clear();
pending_param_signature = 0;
}

bool offload_resident_params(const std::vector<ggml_tensor*>& tensors) {
if (params_backend == runtime_backend) {
return true;
Expand Down Expand Up @@ -2631,7 +2756,8 @@ struct GGMLRunner {
const std::vector<ggml_tensor*>& runtime_param_tensors,
bool preserve_backend_tensor_data_map,
bool no_return = false,
const std::unordered_set<std::string>* cache_keep_names = nullptr) {
const std::unordered_set<std::string>* cache_keep_names = nullptr,
const std::function<void()>& prefetch_cb = {}) {
int64_t t_execute_begin = ggml_time_ms();
const bool use_partial_param_offload = !runtime_param_tensors.empty();
int64_t t_offload_begin = ggml_time_ms();
Expand Down Expand Up @@ -2676,9 +2802,14 @@ struct GGMLRunner {
}

int64_t t_compute_begin = ggml_time_ms();
ggml_status status = ggml_backend_graph_compute(runtime_backend, gf);
int64_t t_compute_end = ggml_time_ms();
ggml_status status = ggml_backend_graph_compute_async(runtime_backend, gf);
if (prefetch_cb) {
prefetch_cb();
}
ggml_backend_synchronize(runtime_backend);
int64_t t_compute_end = ggml_time_ms();
if (status != GGML_STATUS_SUCCESS) {
restore_pending_params();
LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status));
if (free_compute_buffer_immediately) {
free_compute_buffer();
Expand Down Expand Up @@ -2955,15 +3086,29 @@ struct GGMLRunner {

ggml_context* segment_graph_ctx = nullptr;
ggml_cgraph* segment_graph = sd::ggml_graph_cut::build_segment_graph(gf, segment, &segment_graph_ctx);
auto segment_output = execute_graph<T>(segment_graph,

std::function<void()> prefetch_cb;
if (!is_last) {
const auto& next_segment = plan.segments[seg_idx + 1];
auto next_params = sd::ggml_graph_cut::runtime_param_tensors(gf, next_segment, get_desc().c_str());
if (!next_params.empty()) {
prefetch_cb = [this, next_params = std::move(next_params)]() {
offload_pending_params(next_params);
};
}
}

auto segment_output = execute_graph<T>(segment_graph,
n_threads,
/*free_compute_buffer_immediately=*/true,
sd::ggml_graph_cut::runtime_param_tensors(gf, segment, get_desc().c_str()),
/*preserve_backend_tensor_data_map=*/true,
/*no_return=*/!is_last || no_return,
&future_cut_names);
&future_cut_names,
prefetch_cb);
ggml_free(segment_graph_ctx);
if (!segment_output.has_value()) {
restore_pending_params();
free_cache_ctx_and_buffer();
free_compute_buffer();
free_compute_ctx();
Expand Down Expand Up @@ -3081,6 +3226,7 @@ struct GGMLRunner {

void free_params_buffer() {
// Restore swapped resident params before freeing their backing buffer.
restore_pending_params();
restore_resident_params();
if (params_buffer != nullptr) {
ggml_backend_buffer_free(params_buffer);
Expand Down
Loading