diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 17596f849..c16a754c6 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -450,6 +450,14 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); +enum sd_cancel_mode_t { /* cancellation request passed to sd_cancel_generation() */ + SD_CANCEL_ALL, /* stops sampling, stops starting new batch items, and stops the latent decode loop */ + SD_CANCEL_NEW_LATENTS, /* stops sampling and new batch items; the latent decode loop is not interrupted */ + SD_CANCEL_RESET /* no cancellation pending; stored at the start of generate_image / generate_video */ +}; + +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode); /* stores `mode` on the context; out-of-range values are stored as SD_CANCEL_ALL */ + SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API bool generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 8ba4a463a..3d436e9f0 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -50,6 +50,8 @@ const char* sd_vae_format_name(enum sd_vae_format_t format); static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback); +#include <atomic> + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -155,6 +157,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) { /*=============================================== StableDiffusionGGML ================================================*/ +static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free, /* compile-time check that the flag never falls back to a lock */ + "sd_cancel_mode_t must be lock-free"); + class StableDiffusionGGML { public: std::vector mmap_tensor_store; @@ -225,6 +230,20 @@ class StableDiffusionGGML { return module_backend; } + std::atomic<sd_cancel_mode_t> cancellation_flag{SD_CANCEL_RESET}; /* explicit initial value: a default-constructed std::atomic is uninitialized before C++20 */ + + void set_cancel_flag(enum sd_cancel_mode_t flag) { /* release store; pairs with the acquire load in get_cancel_flag() */ + cancellation_flag.store(flag, std::memory_order_release); + } + + void reset_cancel_flag() { /* clears any pending request by storing SD_CANCEL_RESET */ + set_cancel_flag(SD_CANCEL_RESET); + } + + enum sd_cancel_mode_t get_cancel_flag() { /* acquire load of the current request */ + return cancellation_flag.load(std::memory_order_acquire); + } + bool 
ensure_backend_pair(SDBackendModule module) { if (backend_for(module) == nullptr) { return false; @@ -1968,6 +1987,12 @@ class StableDiffusionGGML { SamplePreviewContext preview = prepare_sample_preview_context(); auto denoise = [&](const sd::Tensor& x, float sigma, int step) -> sd::guidance::GuiderOutput { + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); /* checked at the top of every denoise call */ + if (cancel_flag != SD_CANCEL_RESET) { /* SD_CANCEL_ALL and SD_CANCEL_NEW_LATENTS both stop sampling here */ + LOG_DEBUG("cancelling generation"); + return {}; /* default-constructed GuiderOutput; NOTE(review): confirm the sampler treats this as an abort */ + } + if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); } @@ -3010,6 +3035,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) { /* records a cancellation request on the context */ + if (sd_ctx && sd_ctx->sd) { /* silently does nothing when the context or its implementation pointer is null */ + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { /* out-of-range values are treated as SD_CANCEL_ALL */ + mode = SD_CANCEL_ALL; + } + sd_ctx->sd->set_cancel_flag(mode); + } +} + static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd, const sd::Tensor& waveform) { if (sd == nullptr || waveform.empty()) { @@ -4196,6 +4230,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx, int64_t t0 = ggml_time_ms(); for (size_t i = 0; i < final_latents.size(); i++) { + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { /* only SD_CANCEL_ALL stops decoding; other modes keep decoding */ + LOG_ERROR("cancelling latent decodings"); + break; + } int64_t t1 = ggml_time_ms(); sd::Tensor image = sd_ctx->sd->decode_first_stage(final_latents[i]); if (image.empty()) { @@ -4410,6 +4448,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s return nullptr; } + sd_ctx->sd->reset_cancel_flag(); /* stores SD_CANCEL_RESET; NOTE(review): a cancel requested before this point is overwritten */ + int64_t t0 = ggml_time_ms(); sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; GenerationRequest request(sd_ctx, sd_img_gen_params); @@ -4445,6 +4485,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s std::vector> final_latents; int64_t denoise_start = ggml_time_ms(); for (int b = 0; b < request.batch_count; b++) { + sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag(); /* checked before each batch item starts */ + if (cancel == 
SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) { /* either mode stops further batch items from starting */ + LOG_ERROR("cancelling generation"); + break; + } + int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = request.seed + b; LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed); @@ -5218,6 +5264,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, if (audio_out != nullptr) { *audio_out = nullptr; } + + if (sd_ctx && sd_ctx->sd) { sd_ctx->sd->reset_cancel_flag(); } /* stores SD_CANCEL_RESET so this run starts with no cancel pending; null-guarded because the neighbouring lines only initialize out-parameters. NOTE(review): drop the guard if sd_ctx is validated earlier in this function */ + if (num_frames_out != nullptr) { *num_frames_out = 0; }