[Deepseek] Refactor deepseek server_args _handle_model_specific_adjustments (#13687)

This commit is contained in:
hlu1
2025-11-23 12:41:14 -08:00
committed by GitHub
parent 5c2915494c
commit 618ca23802

View File

@@ -928,20 +928,91 @@ class ServerArgs:
hf_config = self.get_hf_config()
model_arch = hf_config.architectures[0]
if model_arch in ["DeepseekV3ForCausalLM"] and not is_deepseek_nsa(hf_config):
if self.enable_piecewise_cuda_graph:
logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")
if is_cuda() and is_sm100_supported():
if model_arch in ["DeepseekV3ForCausalLM"]:
if is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "trtllm_mla"
logger.info(
"Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")
if not is_npu():
self.enable_dp_attention = True
logger.warning("DP attention is enabled for DeepSeek NSA.")
if self.enable_nsa_prefill_context_parallel:
# TODO Supports moe_dense_tp_size != 1, kv cache dtype = "fp8",moe_a2a_backend non-deepep and cross-machine operation .
self.moe_dense_tp_size = 1
self.moe_a2a_backend = "deepep"
self.ep_size = self.tp_size
self.kv_cache_dtype = "bf16"
assert (
self.tp_size == 8
), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)"
logger.warning(
f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} "
)
else:
self.dp_size = self.tp_size
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch
major, _ = torch.cuda.get_device_capability()
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
if self.kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends for Blackwell. The default settings are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"
# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
print_nsa_bool_env_vars,
)
print_nsa_bool_env_vars()
else:
if self.enable_piecewise_cuda_graph:
logger.info("Piecewise CUDA graph is enabled, use MLA for prefill.")
if is_cuda() and is_sm100_supported():
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "trtllm_mla"
logger.info(
"Use trtllm_mla as attention backend on sm100 for DeepseekV3ForCausalLM"
)
# common to all Deepseek MoE models
if is_cuda() and is_sm100_supported():
# workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006
if not self.enable_dp_attention and self.nnodes == 1:
self.enable_flashinfer_allreduce_fusion = True
@@ -1148,72 +1219,6 @@ class ServerArgs:
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for Qwen3NextForCausalLM"
)
if is_deepseek_nsa(hf_config):
if (
self.attention_backend is None
and self.prefill_attention_backend is None
and self.decode_attention_backend is None
):
self.attention_backend = "nsa"
logger.warning("Set nsa attention backend for DeepSeek NSA.")
if not is_npu():
self.enable_dp_attention = True
logger.warning("DP attention is enabled for DeepSeek NSA.")
if self.enable_nsa_prefill_context_parallel:
# TODO Supports moe_dense_tp_size != 1, kv cache dtype = "fp8",moe_a2a_backend non-deepep and cross-machine operation .
self.moe_dense_tp_size = 1
self.moe_a2a_backend = "deepep"
self.ep_size = self.tp_size
self.kv_cache_dtype = "bf16"
assert (
self.tp_size == 8
), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)"
logger.warning(
f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} "
)
else:
self.dp_size = self.tp_size
self.page_size = 64
logger.warning("Setting page size to 64 for DeepSeek NSA.")
# For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
import torch
major, _ = torch.cuda.get_device_capability()
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
if self.kv_cache_dtype == "fp8_e4m3":
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends for Blackwell. The default settings are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"
# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
print_nsa_bool_env_vars,
)
print_nsa_bool_env_vars()
def _handle_sampling_backend(self):
if self.sampling_backend is None: