From 9cecbd6f16e9f45cc2d9f138551cfc0162d6dab0 Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Mon, 11 May 2026 17:22:47 -0700
Subject: [PATCH 1/6] add dflash in specdec

Signed-off-by: Hao Guo
---
 examples/specdec_bench/dflash_kimi.yaml | 10 ++
 examples/specdec_bench/dflash_qwen.yaml |  7 ++
 examples/specdec_bench/run.py           |  2 +-
 .../specdec_bench/models/sglang.py      | 93 ++++++++++++-------
 4 files changed, 75 insertions(+), 37 deletions(-)
 create mode 100644 examples/specdec_bench/dflash_kimi.yaml
 create mode 100644 examples/specdec_bench/dflash_qwen.yaml

diff --git a/examples/specdec_bench/dflash_kimi.yaml b/examples/specdec_bench/dflash_kimi.yaml
new file mode 100644
index 00000000000..0e34c5359e7
--- /dev/null
+++ b/examples/specdec_bench/dflash_kimi.yaml
@@ -0,0 +1,10 @@
+chat_template_args:
+  thinking: true
+engine_args:
+  mem_fraction_static: 0.9
+  speculative_num_draft_tokens: 8
+  # cuda_graph_max_bs: 128
+  speculative_dflash_draft_window_size: 4096
+  disable_cuda_graph: true
+sampling_kwargs:
+  temperature: 0
diff --git a/examples/specdec_bench/dflash_qwen.yaml b/examples/specdec_bench/dflash_qwen.yaml
new file mode 100644
index 00000000000..457e1eadaca
--- /dev/null
+++ b/examples/specdec_bench/dflash_qwen.yaml
@@ -0,0 +1,7 @@
+engine_args:
+  mem_fraction_static: 0.9
+  speculative_num_draft_tokens: 8
+  speculative_dflash_draft_window_size: 4096
+  mamba_scheduler_strategy: extra_buffer
+sampling_kwargs:
+  temperature: 0
diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py
index f4fbf06c0e8..94932c787b8 100644
--- a/examples/specdec_bench/run.py
+++ b/examples/specdec_bench/run.py
@@ -265,7 +265,7 @@ def run_simple(args):
         type=str,
         required=False,
         default="EAGLE3",
-        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
+        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "DFLASH", "NONE"],
         help="Speculative algorithm to use",
     )
     parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index d5ff890ffd7..00ba1de1f44 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -41,46 +41,67 @@ def __init__(
             speculative_algorithm = "STANDALONE"
         elif speculative_algorithm == "NGRAM":
             speculative_algorithm = "LOOKAHEAD"
+        elif speculative_algorithm == "DFLASH":
+            pass  # SGLang native name, pass through
         elif speculative_algorithm == "NONE":
             speculative_algorithm = None
+
+        engine_kwargs = dict(
+            model_path=model_dir,
+            skip_tokenizer_init=True,
+            trust_remote_code=kwargs.get("trust_remote_code", False),
+            mem_fraction_static=kwargs.get("mem_fraction_static", 0.8),
+            disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
+            tp_size=kwargs.get("tensor_parallel_size", 1),
+            ep_size=kwargs.get("moe_expert_parallel_size", 1),
+            torch_compile_max_bs=max_concurrent_requests,
+            max_running_requests=max_concurrent_requests,
+            attention_backend=kwargs.get("attention_backend"),
+            enable_torch_compile=kwargs.get("enable_torch_compile", False),
+            cuda_graph_max_bs=max_concurrent_requests,
+            disable_cuda_graph=False,
+        )
         if speculative_algorithm is not None:
             # https://github.com/sgl-project/sglang/pull/3582
-            self.model = sgl.Engine(
-                model_path=model_dir,
-                skip_tokenizer_init=True,
-                trust_remote_code=kwargs.get("trust_remote_code", False),
-                mem_fraction_static=0.8,
-                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
-                tp_size=kwargs.get("tensor_parallel_size", 1),
-                ep_size=kwargs.get("moe_expert_parallel_size", 1),
-                speculative_algorithm=speculative_algorithm,
-                speculative_num_steps=kwargs.get("speculative_num_steps", 3),
-                speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
-                speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
-                speculative_draft_model_path=kwargs.get("draft_model_dir"),
-                torch_compile_max_bs=max_concurrent_requests,
-                max_running_requests=max_concurrent_requests,
-                attention_backend=kwargs.get("attention_backend"),
-                enable_torch_compile=kwargs.get("enable_torch_compile", False),
-                cuda_graph_max_bs=max_concurrent_requests,
-                disable_cuda_graph=False,
-            )
-        else:
-            self.model = sgl.Engine(
-                model_path=model_dir,
-                skip_tokenizer_init=True,
-                trust_remote_code=kwargs.get("trust_remote_code", False),
-                mem_fraction_static=0.8,
-                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", False),
-                tp_size=kwargs.get("tensor_parallel_size", 1),
-                ep_size=kwargs.get("moe_expert_parallel_size", 1),
-                torch_compile_max_bs=max_concurrent_requests,
-                max_running_requests=max_concurrent_requests,
-                attention_backend=kwargs.get("attention_backend"),
-                enable_torch_compile=kwargs.get("enable_torch_compile", False),
-                cuda_graph_max_bs=max_concurrent_requests,
-                disable_cuda_graph=False,
-            )
+            engine_kwargs["speculative_algorithm"] = speculative_algorithm
+            num_draft_tokens = kwargs.get("speculative_num_draft_tokens", 4)
+            engine_kwargs["speculative_num_draft_tokens"] = num_draft_tokens
+            engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir")
+            if speculative_algorithm == "DFLASH":
+                if "speculative_dflash_draft_window_size" in kwargs:
+                    engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[
+                        "speculative_dflash_draft_window_size"
+                    ]
+                print(
+                    f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / "
+                    f"speculative_eagle_topk; effective draft block = "
+                    f"speculative_num_draft_tokens={num_draft_tokens}"
+                )
+            else:
+                engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
+                engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)
+
+        # Forward any other kwargs (e.g. from runtime_params.engine_args) to
+        # sgl.Engine, letting yaml override the defaults set above. Skip only
+        # specdec_bench-internal routing keys that should never reach SGLang.
+        _internal_keys = frozenset({
+            "speculative_algorithm",
+            "draft_model_dir",
+            "speculative_num_steps",
+            "speculative_eagle_topk",
+            "speculative_num_draft_tokens",
+            "speculative_dflash_draft_window_size",
+            "tensor_parallel_size",
+            "moe_expert_parallel_size",
+            "tokenizer_path",
+            "use_draft_logits",
+        })
+        for _k, _v in kwargs.items():
+            if _k in _internal_keys:
+                continue
+            engine_kwargs[_k] = _v
+
+        self.model = sgl.Engine(**engine_kwargs)
         self.sampling_config = sampling_kwargs

From 334776e18e97a571c7d0420dd18b9a21c085dc34 Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Mon, 11 May 2026 18:06:06 -0700
Subject: [PATCH 2/6] dev

Signed-off-by: Hao Guo
---
 examples/specdec_bench/dflash_kimi.yaml                | 6 ------
 examples/specdec_bench/dflash_qwen.yaml                | 5 -----
 examples/specdec_bench/specdec_bench/models/sglang.py | 10 +++-------
 3 files changed, 3 insertions(+), 18 deletions(-)

diff --git a/examples/specdec_bench/dflash_kimi.yaml b/examples/specdec_bench/dflash_kimi.yaml
index 0e34c5359e7..0a284774740 100644
--- a/examples/specdec_bench/dflash_kimi.yaml
+++ b/examples/specdec_bench/dflash_kimi.yaml
@@ -1,10 +1,4 @@
 chat_template_args:
   thinking: true
 engine_args:
-  mem_fraction_static: 0.9
-  speculative_num_draft_tokens: 8
-  # cuda_graph_max_bs: 128
   speculative_dflash_draft_window_size: 4096
-  disable_cuda_graph: true
-sampling_kwargs:
-  temperature: 0
diff --git a/examples/specdec_bench/dflash_qwen.yaml b/examples/specdec_bench/dflash_qwen.yaml
index 457e1eadaca..01b315d6a17 100644
--- a/examples/specdec_bench/dflash_qwen.yaml
+++ b/examples/specdec_bench/dflash_qwen.yaml
@@ -1,7 +1,2 @@
 engine_args:
-  mem_fraction_static: 0.9
-  speculative_num_draft_tokens: 8
-  speculative_dflash_draft_window_size: 4096
   mamba_scheduler_strategy: extra_buffer
-sampling_kwargs:
-  temperature: 0
diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index 00ba1de1f44..59b6ff9bbb9 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -64,20 +64,16 @@ def __init__(
         if speculative_algorithm is not None:
             # https://github.com/sgl-project/sglang/pull/3582
             engine_kwargs["speculative_algorithm"] = speculative_algorithm
-            num_draft_tokens = kwargs.get("speculative_num_draft_tokens", 4)
-            engine_kwargs["speculative_num_draft_tokens"] = num_draft_tokens
             engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir")
             if speculative_algorithm == "DFLASH":
-                if "speculative_dflash_draft_window_size" in kwargs:
-                    engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[
-                        "speculative_dflash_draft_window_size"
-                    ]
+                engine_kwargs["speculative_num_draft_tokens"] = kwargs.get("speculative_num_draft_tokens", 8)
                 print(
                     f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / "
                     f"speculative_eagle_topk; effective draft block = "
-                    f"speculative_num_draft_tokens={num_draft_tokens}"
+                    f"speculative_num_draft_tokens={engine_kwargs['speculative_num_draft_tokens']}"
                 )
             else:
+                engine_kwargs["speculative_num_draft_tokens"] = kwargs.get("speculative_num_draft_tokens", 4)
                 engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
                 engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)
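
[Editor's note: taken together, patches 1-2 mean that a plain DFLASH run resolves to
roughly the sgl.Engine call sketched below. This is a reconstruction from the diffs for
orientation, not the bench's literal code: the two model paths and the concurrency
value are placeholders, the "import sglang as sgl" alias is assumed from the file's
sgl.Engine usage, and kwargs left at their defaults (attention_backend=None, etc.) are
omitted.]

    # Sketch: effective engine construction for --speculative_algorithm DFLASH
    # after patches 1-2, with all defaults. Paths are hypothetical.
    import sglang as sgl

    model = sgl.Engine(
        model_path="/models/target",                   # from --model_dir (placeholder)
        speculative_algorithm="DFLASH",                # passed through to SGLang unchanged
        speculative_draft_model_path="/models/draft",  # from draft_model_dir (placeholder)
        speculative_num_draft_tokens=8,                # DFLASH default set in patch 2
        skip_tokenizer_init=True,
        mem_fraction_static=0.8,
        tp_size=1,
        ep_size=1,
        max_running_requests=64,                       # max_concurrent_requests (example value)
        cuda_graph_max_bs=64,
        torch_compile_max_bs=64,
        disable_cuda_graph=False,
    )

Note that speculative_num_steps and speculative_eagle_topk are deliberately absent: the
DFLASH branch never sets them. At this point in the series, extra yaml engine_args such
as speculative_dflash_draft_window_size from dflash_kimi.yaml still reach the engine
through the generic kwargs-forwarding loop added in patch 1.
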
From cad8728b2aa9d1a8001126175ad24ab1be6f3a7a Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Mon, 11 May 2026 22:49:55 -0700
Subject: [PATCH 3/6] dev

Signed-off-by: Hao Guo
---
 examples/specdec_bench/dflash_kimi.yaml |  4 ---
 examples/specdec_bench/dflash_qwen.yaml |  2 --
 .../specdec_bench/models/sglang.py      | 26 +++++--------------
 3 files changed, 7 insertions(+), 25 deletions(-)
 delete mode 100644 examples/specdec_bench/dflash_kimi.yaml
 delete mode 100644 examples/specdec_bench/dflash_qwen.yaml

diff --git a/examples/specdec_bench/dflash_kimi.yaml b/examples/specdec_bench/dflash_kimi.yaml
deleted file mode 100644
index 0a284774740..00000000000
--- a/examples/specdec_bench/dflash_kimi.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-chat_template_args:
-  thinking: true
-engine_args:
-  speculative_dflash_draft_window_size: 4096
diff --git a/examples/specdec_bench/dflash_qwen.yaml b/examples/specdec_bench/dflash_qwen.yaml
deleted file mode 100644
index 01b315d6a17..00000000000
--- a/examples/specdec_bench/dflash_qwen.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-engine_args:
-  mamba_scheduler_strategy: extra_buffer
diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index 59b6ff9bbb9..3510886a1e4 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -67,6 +67,10 @@ def __init__(
             engine_kwargs["speculative_draft_model_path"] = kwargs.get("draft_model_dir")
             if speculative_algorithm == "DFLASH":
                 engine_kwargs["speculative_num_draft_tokens"] = kwargs.get("speculative_num_draft_tokens", 8)
+                if "speculative_dflash_draft_window_size" in kwargs:
+                    engine_kwargs["speculative_dflash_draft_window_size"] = kwargs[
+                        "speculative_dflash_draft_window_size"
+                    ]
                 print(
                     f"[specdec_bench] DFLASH ignores --draft_length / speculative_num_steps / "
                     f"speculative_eagle_topk; effective draft block = "
@@ -77,25 +81,9 @@ def __init__(
                 engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
                 engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)

-        # Forward any other kwargs (e.g. from runtime_params.engine_args) to
-        # sgl.Engine, letting yaml override the defaults set above. Skip only
-        # specdec_bench-internal routing keys that should never reach SGLang.
-        _internal_keys = frozenset({
-            "speculative_algorithm",
-            "draft_model_dir",
-            "speculative_num_steps",
-            "speculative_eagle_topk",
-            "speculative_num_draft_tokens",
-            "speculative_dflash_draft_window_size",
-            "tensor_parallel_size",
-            "moe_expert_parallel_size",
-            "tokenizer_path",
-            "use_draft_logits",
-        })
-        for _k, _v in kwargs.items():
-            if _k in _internal_keys:
-                continue
-            engine_kwargs[_k] = _v
+        # mamba_scheduler_strategy: extra_buffer needed for qwen3.5
+        if "mamba_scheduler_strategy" in kwargs:
+            engine_kwargs["mamba_scheduler_strategy"] = kwargs["mamba_scheduler_strategy"]

         self.model = sgl.Engine(**engine_kwargs)

From cc096230405297f2c89624b67a8bb9fe13a95601 Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Mon, 11 May 2026 23:09:28 -0700
Subject: [PATCH 4/6] disable cuda graph padding

Signed-off-by: Hao Guo
---
 examples/specdec_bench/specdec_bench/models/sglang.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index 3510886a1e4..36f754a8dd3 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -60,6 +60,7 @@ def __init__(
             enable_torch_compile=kwargs.get("enable_torch_compile", False),
             cuda_graph_max_bs=max_concurrent_requests,
             disable_cuda_graph=False,
+            disable_cuda_graph_padding=True,
         )
         if speculative_algorithm is not None:
             # https://github.com/sgl-project/sglang/pull/3582

From 7bbaae21acc0b7655eadfce598cc4adfa018fb1c Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Mon, 11 May 2026 23:11:27 -0700
Subject: [PATCH 5/6] polish

Signed-off-by: Hao Guo
---
 examples/specdec_bench/specdec_bench/models/sglang.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index 36f754a8dd3..c5f320e2c78 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -41,8 +41,6 @@ def __init__(
             speculative_algorithm = "STANDALONE"
         elif speculative_algorithm == "NGRAM":
             speculative_algorithm = "LOOKAHEAD"
-        elif speculative_algorithm == "DFLASH":
-            pass  # SGLang native name, pass through
         elif speculative_algorithm == "NONE":
             speculative_algorithm = None
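
[Editor's note: patch 5 can drop the explicit DFLASH elif because the name-mapping
block only rewrites bench names that differ from SGLang's (e.g. NGRAM -> LOOKAHEAD);
any unmatched name, DFLASH included, already falls through to sgl.Engine unchanged. A
minimal illustration of that dispatch pattern follows, as a hypothetical helper rather
than the bench's actual code:]

    from typing import Optional

    # Hypothetical sketch of the pass-through behavior: only names that need
    # renaming get a mapping entry; everything else is forwarded verbatim.
    def to_sglang_name(algo: str) -> Optional[str]:
        mapping = {"NGRAM": "LOOKAHEAD", "NONE": None}
        return mapping.get(algo, algo)

    assert to_sglang_name("DFLASH") == "DFLASH"   # no remap needed
    assert to_sglang_name("NGRAM") == "LOOKAHEAD"
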
From 61c7e797c10f67b772ccca3fddcb381c6e9bc0b9 Mon Sep 17 00:00:00 2001
From: Hao Guo
Date: Tue, 12 May 2026 00:08:43 -0700
Subject: [PATCH 6/6] add dflash to vllm backend

Signed-off-by: Hao Guo
---
 examples/specdec_bench/specdec_bench/models/sglang.py | 2 +-
 examples/specdec_bench/specdec_bench/models/vllm.py   | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py
index c5f320e2c78..5e39695aff7 100644
--- a/examples/specdec_bench/specdec_bench/models/sglang.py
+++ b/examples/specdec_bench/specdec_bench/models/sglang.py
@@ -80,7 +80,7 @@ def __init__(
                 engine_kwargs["speculative_num_steps"] = kwargs.get("speculative_num_steps", 3)
                 engine_kwargs["speculative_eagle_topk"] = kwargs.get("speculative_eagle_topk", 1)

-        # mamba_scheduler_strategy: extra_buffer needed for qwen3.5
+        # extra engine arg needed for qwen3.5
         if "mamba_scheduler_strategy" in kwargs:
             engine_kwargs["mamba_scheduler_strategy"] = kwargs["mamba_scheduler_strategy"]
diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py
index 2e312e7aec8..fc595c1d579 100644
--- a/examples/specdec_bench/specdec_bench/models/vllm.py
+++ b/examples/specdec_bench/specdec_bench/models/vllm.py
@@ -63,6 +63,12 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
             "method": "mtp",
             "num_speculative_tokens": kwargs.get("speculative_num_steps", 3),
         }
+        elif kwargs.get("speculative_algorithm") == "DFLASH":
+            specdec = {
+                "method": "dflash",
+                "model": kwargs.get("draft_model_dir"),
+                "num_speculative_tokens": kwargs.get("speculative_num_draft_tokens", 8),
+            }
         elif kwargs.get("speculative_algorithm") == "NONE":
             specdec = None
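
[Editor's note: the hunk above only builds the specdec dict; the call site that hands
it to vLLM lies outside the diff. For orientation, a standalone equivalent using
vLLM's speculative_config argument might look like the sketch below. The model paths
are placeholders, and the "dflash" method string requires a vLLM build that ships
DFLASH support.]

    # Minimal sketch, assuming the specdec dict is passed as vLLM's
    # speculative_config; not the bench's literal call site.
    from vllm import LLM

    llm = LLM(
        model="/models/target",           # placeholder target model path
        speculative_config={
            "method": "dflash",
            "model": "/models/draft",     # placeholder draft model path
            "num_speculative_tokens": 8,  # bench default for DFLASH
        },
    )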