mirror of https://github.com/kvcache-ai/sglang.git, synced 2026-05-14 01:34:58 +00:00
fix(v4-flash): bundle V4-2604B SwiGLU clamp + hybrid SWA chunked-prefill hang fix (#44)
* fix(kt-ep): match cpu_buf dtype to kt-kernel's bf16 scale write for MXFP4

  kt-kernel's write_weights_to_buffer (operators/amx/fp4-moe.hpp) writes
  gate/up scales as bf16 via fast_fp32_to_bf16, but mxfp4_deepseek allocates
  w13/w2_weight_scale_inv as fp32. The 2x element-size mismatch caused
  kt-kernel to fill only the first half of cpu_buf in fp32-element terms;
  after the Phase 3 .to(float8_e8m0fnu) cast, the second half (= the up_proj
  rows) became 2^-127, zeroing the dequantized up_proj weights for every
  expert loaded via the kt double-buffered pipeline. Single-chunk GPU
  prefill on V4-Flash MXFP4 produced mode-collapsed garbage as a result.

  Allocate cpu_buf with bf16 dtype for these two scale tensors so that
  kt-kernel's write fills it exactly; gpu_t[e].copy_(cpu_buf[slot]) then
  performs the bf16 -> fp32 dtype cast automatically.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(v4-2604b): apply SwiGLU clamp on the triton-kernels GPU MoE path

  The trtllm and deep_gemm paths both apply a 2604B-specific asymmetric
  gate/up clamp (gate.clamp(max=limit); up.clamp(-limit, limit)) on the
  gemm1 output before silu_and_mul. The triton-kernels path (the default
  GPU MoE on every capability outside _TRTLLM_FP4_CAPS, including SM_120
  RTX 5090) was constructing a bare matmul_ogs -> silu_and_mul -> matmul_ogs
  sequence with no clamp, leaving routed-expert outputs numerically
  inconsistent with the trtllm reference on long-prompt / large-activation
  tokens.

  Threads moe_runner_config.swiglu_limit through DeepSeekMxfp4MoEMethod.apply
  to apply_v4_triton_kernels_moe; the semantics match
  moe_runner/deep_gemm.py:_apply_swiglu_limit verbatim. No-op when
  submode != 2604B (swiglu_limit is None).

  Origin: sglang itself.

* feat(v4-2604b): pass swiglu_limit through KTEPWrapper to kt-kernel

  The kt CPU expert path was applying plain silu(g)*u with no clamp,
  diverging from the trtllm `gemm1_clamp_limit` and deep_gemm
  `_apply_swiglu_limit` references on long-prompt / large-activation tokens.
  Companion changes in kt-kernel (`feat/v4-2604b-swiglu-clamp:d10bd3d`)
  plumb a `swiglu_limit` field through `MOEConfig` into the AMX `act_fn`;
  this commit passes the value through the kt-sglang bridge.

  The KTMoEWrapper is constructed in `create_weights`, before
  `create_moe_runner` delivers the full `MoeRunnerConfig`, but the value is
  fully determined by SGLANG_DSV4_2604_SUBMODE, which is fixed at process
  start, so we read the env directly here. Mirrors the
  `assert swiglu_limit == 10` in moe_runner/deep_gemm.py and the
  `torch.full(..., swiglu_limit, ...)` constructor in
  mxfp4_deepseek.py:177-186.

  Origin: kt-sglang coupling.

* fix(scheduler): correct inverted chunked_req check that hangs hybrid SWA chunked prefill

  In _get_new_batch_prefill_raw the inline comment explicitly says "Ignore
  the check if self.chunked_req is not None", but the code below used
  `is not None`, which is the opposite. With --disable-radix-cache + hybrid
  SWA + a multi-chunk prompt, the chunked_req keeps holding its req_pool
  slot across chunks (ChunkCache.cache_unfinished_req does not release it),
  and ReqToTokenPool initialises free_slots = list(range(1, size)), wasting
  index 0, so once the chunked_req takes the only available slot the check
  fires forever and the scheduler returns None on every iteration -> silent
  hang (chunk1 prefill completes, chunk2 never starts; TP CPU 60-145% busy
  spin; the client request never returns). The sister check at line 2065
  (`and self.chunked_req is None: return None`) already has the correct
  polarity; this brings line 2082 in line with the comment and with that
  sister check.
  Repro (DeepSeek V4 Flash, hybrid SWA, page_size=256):

      --disable-radix-cache --chunked-prefill-size 2048 \
      --tensor-parallel-size 4 --max-running-requests 2

  plus a prompt > 2048 tokens (forces multi-chunk).

  Before: chunk1 prefill runs, then a silent hang or a false-positive
  "token_to_kv_pool_allocator memory leak detected" SIGQUIT (the hybrid
  leak check is also too strict; addressed in a follow-up commit).
  After: a 5001-token English prompt -> 3 chunks, HTTP 200 in 26.4s; a
  6695-token Chinese prompt -> 4 chunks, HTTP 200 in 52.2s.

  Origin: sglang itself (not kt-sglang coupling). Reproduces on
  pip-installed upstream sglang as well as on the kt third_party submodule.

* fix(scheduler): skip self_check_during_idle when in-flight work still holds KV slots

  Defensive guard for the same bug class as the previous commit. When the
  scheduler enters the idle branch with chunked_req != None or a non-empty
  running_batch / waiting_queue, the in-flight KV slots are not yet freed
  nor cached. _check_hybrid_memory then reports them as leaked because its
  formula `full_num_used != 0` does not subtract protected_size / in-flight
  usage the way _check_radix_cache_memory does. The result was a SIGQUIT on
  a false positive: 4 TP ranks raise simultaneously and the server dies
  mid-request.

  The other branches of self_check_during_idle (DisaggregationMode.PREFILL
  and .DECODE) already early-return on similar in-flight conditions; this
  patch adds the equivalent guard for DisaggregationMode.NULL, which had no
  such check. The same pattern is used at scheduler.py line 1372 and in
  process_input_requests around line 1370.

  This guard is no longer load-bearing once the scheduler.py:2082 fix is in
  (chunked prefill advances every iteration, so the scheduler never reaches
  batch=None mid-request), but it is kept as defence-in-depth against any
  future path that produces a double-None batch frame.

  Origin: sglang itself.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
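For intuition on the first fix, a minimal standalone sketch of the element-size mismatch (shapes are illustrative; the real buffers live in shared memory and hold per-expert MoE scale rows):

    import torch

    n = 8
    # Staging buffer allocated as fp32 (4 bytes/elem): 32 bytes total.
    cpu_buf = torch.zeros(n, dtype=torch.float32)
    # A writer that emits n bf16 values (2 bytes/elem) fills only 16 bytes,
    # i.e. the first half of the buffer in fp32-element terms.
    src = torch.arange(1, n + 1, dtype=torch.bfloat16)
    cpu_buf.view(torch.bfloat16)[:n].copy_(src)
    # The second half of cpu_buf still holds its original zero bytes.

    # The fix: allocate the staging buffer with the writer's dtype; the
    # GPU-side copy_ then performs the bf16 -> fp32 upcast for free.
    cpu_buf_bf16 = torch.empty(n, dtype=torch.bfloat16)
    cpu_buf_bf16.copy_(src)
    gpu_scale = torch.empty(n, dtype=torch.float32)
    gpu_scale.copy_(cpu_buf_bf16)  # dtype cast happens inside copy_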
@@ -556,9 +556,15 @@ class SharedFullContext:
         gpu_tensor = getattr(self.gpu_layer, name)
         # Only allocate 2 experts worth of buffer (double buffering)
         expert_shape = gpu_tensor.shape[1:]  # Shape per expert
-        expert_nbytes = (
-            gpu_tensor.numel() // num_experts * gpu_tensor.element_size()
-        )
+        if (
+            getattr(self, "is_mxfp4_quant", False)
+            and name in ("w13_weight_scale_inv", "w2_weight_scale_inv")
+        ):
+            buf_dtype = torch.bfloat16
+        else:
+            buf_dtype = gpu_tensor.dtype
+        element_size = torch.empty((), dtype=buf_dtype).element_size()
+        expert_nbytes = gpu_tensor.numel() // num_experts * element_size
         double_buf_nbytes = expert_nbytes * 2

         shm_name = f"kt_buf_{name}_r{tp_rank}_{self.shm_unique_id}"
@@ -568,7 +574,7 @@ class SharedFullContext:
         self.shm_handles[name] = shm

         # Shape: [2, ...expert_shape...]
-        cpu_buffer = torch.frombuffer(shm.buf, dtype=gpu_tensor.dtype).reshape(
+        cpu_buffer = torch.frombuffer(shm.buf, dtype=buf_dtype).reshape(
            (2,) + expert_shape
        )
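Why those stale zero bytes turn into 2^-127 rather than staying zero, as a sketch (assumes a torch build that exposes float8_e8m0fnu; the real Phase 3 cast runs over the full scale tensor):

    import torch

    # e8m0 is an 8-bit exponent-only format with no representation for 0;
    # the all-zero bit pattern decodes to 2**-127. So never-written fp32
    # zeros become tiny-but-nonzero scales that wipe out up_proj on dequant.
    stale = torch.zeros(4, dtype=torch.float32)
    scales = stale.to(torch.float8_e8m0fnu)
    print(scales.float())  # ~5.88e-39 everywhere, i.e. 2**-127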
@@ -2285,6 +2291,18 @@ class KTEPWrapperMethod(FusedMoEMethodBase):
         # 2. Initialize KT wrapper for CPU experts
         # CPU experts are identified by gpu_experts_mask=False
         if self.tp_rank == 0:
+            # V4-Flash 2604B SwiGLU clamp on routed experts. The full
+            # moe_runner_config (which carries swiglu_limit) does not arrive
+            # until create_moe_runner(), but the value is fully determined
+            # by the DSV4 submode env (fixed at process start), so we read
+            # it here without waiting. Matches the assert
+            # `swiglu_limit == 10` in moe_runner/deep_gemm.py:_apply_swiglu_limit
+            # and the default 10.0 set for 2604B in mxfp4_deepseek.py.
+            # Origin: kt-sglang coupling (carries V4-2604B limit into kt-kernel).
+            from sglang.srt.environ import envs as _envs
+
+            _kt_swiglu_limit = (
+                10.0 if _envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" else 0.0
+            )
             self.wrapper = KTMoEWrapper(
                 layer_idx=self.kt_config.layer_idx,
                 num_experts=num_experts,
@@ -2294,6 +2312,7 @@ class KTEPWrapperMethod(FusedMoEMethodBase):
                 gpu_experts_mask=self.gpu_experts_mask,
                 cpuinfer_threads=self.kt_config.cpuinfer_threads,
                 threadpool_count=self.kt_config.threadpool_count,
+                swiglu_limit=_kt_swiglu_limit,
                 numa_nodes=self.kt_config.numa_nodes,
                 weight_path=self.kt_config.weight_path,
                 chunked_prefill_size=self.kt_config.chunked_prefill_size,
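A Python reference for the math the companion kt-kernel change installs in the AMX act_fn (hypothetical helper; the real implementation is C++ in kt-kernel, and 0.0 encodes "disabled" on this bridge as above):

    import torch
    import torch.nn.functional as F

    def act_fn_ref(gate: torch.Tensor, up: torch.Tensor, swiglu_limit: float) -> torch.Tensor:
        # swiglu_limit == 0.0 means "no clamp" on the kt bridge (the GPU
        # paths use None instead); 10.0 is the V4-2604B value.
        if swiglu_limit > 0.0:
            gate = gate.clamp(max=swiglu_limit)                 # one-sided
            up = up.clamp(min=-swiglu_limit, max=swiglu_limit)  # symmetric
        return F.silu(gate) * up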
@@ -481,6 +481,12 @@ class DeepSeekMxfp4MoEMethod:
             topk_ids,
         )
         rsf = layer.moe_runner_config.routed_scaling_factor
+        # 2604B SwiGLU clamp: thread swiglu_limit through so the triton-
+        # kernels GPU MoE path applies the same gate/up clamp as
+        # trtllm's gemm1_clamp_limit and deep_gemm's _apply_swiglu_limit.
+        # When submode != 2604B, moe_runner_config.swiglu_limit is None
+        # and apply_v4_triton_kernels_moe skips the clamp.
+        # Origin: sglang itself.
         output = apply_v4_triton_kernels_moe(
             hidden_states=hidden_states,
             w13_swiz=layer._v4_tk_w13,
@@ -492,6 +498,7 @@ class DeepSeekMxfp4MoEMethod:
             intermediate_size=layer._v4_tk_intermediate_size,
             num_experts=layer._v4_tk_num_experts,
             routed_scaling_factor=rsf if rsf is not None else 1.0,
+            swiglu_limit=layer.moe_runner_config.swiglu_limit,
         )
         if envs.SGLANG_DSV4_2604_SUBMODE.get() == '2604B' and (
             self._gemm1_clamp_limit_tensor is not None
@@ -362,6 +362,7 @@ def apply_v4_triton_kernels_moe(
     intermediate_size: int,  # per-partition N
     num_experts: int,
     routed_scaling_factor: float = 1.0,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """Run V4 sparse MoE through `triton_kernels.matmul_ogs`.

@@ -369,6 +370,14 @@ def apply_v4_triton_kernels_moe(
     byte-level reproducible under same input.

     Activation: silu_and_mul (V4 default), applied between the two GEMMs.
+
+    `swiglu_limit` (DSV4 2604B): if not None, applies the same gate/up
+    asymmetric clamp the trtllm path's `gemm1_clamp_limit` and the
+    deep_gemm path's `_apply_swiglu_limit` use, on the gemm1 output before
+    silu_and_mul:
+        gate = clamp(gate, max=limit)            # one-sided
+        up   = clamp(up, min=-limit, max=limit)  # symmetric
+    Origin: sglang itself (matches `moe_runner/deep_gemm.py:_apply_swiglu_limit`).
     """
     from triton_kernels.matmul_ogs import matmul_ogs
     from sgl_kernel import silu_and_mul
@@ -396,7 +405,16 @@ def apply_v4_triton_kernels_moe(
         gather_indx=gather_indx,
         precision_config=w13_pcg,
     )
-    # intermediate1 shape: [M*topk, 2*N]
+    # intermediate1 shape: [M*topk, 2*N]; layout = [gate, up] along last dim.
+    # We skipped reorder_w1w3_to_w3w1 for this path so the natural [w1, w3]
+    # = [gate, up] order from the checkpoint is preserved.
+    if swiglu_limit is not None:
+        # 2604B asymmetric SwiGLU clamp. View slices and clamp_ in place to
+        # avoid chunk+cat copy; safe because intermediate1 is a fresh
+        # matmul_ogs output, not a cached buffer.
+        N_int = intermediate1.shape[-1] // 2
+        intermediate1[..., :N_int].clamp_(max=swiglu_limit)
+        intermediate1[..., N_int:].clamp_(min=-swiglu_limit, max=swiglu_limit)
     M_topk = intermediate1.shape[0]
     intermediate2 = torch.empty(
         (M_topk, N), device=hidden_states.device, dtype=hidden_states.dtype
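A quick self-check that the in-place slice clamp above matches the chunk-then-cat formulation it avoids (standalone; tensor sizes are arbitrary):

    import torch

    limit = 10.0
    x = torch.randn(4, 2 * 16) * 20  # [M*topk, 2*N] stand-in

    # chunk + cat reference (pays for extra copies):
    gate, up = x.chunk(2, dim=-1)
    ref = torch.cat(
        (gate.clamp(max=limit), up.clamp(min=-limit, max=limit)), dim=-1
    )

    # in-place slice clamp, as in the diff:
    y = x.clone()
    N = y.shape[-1] // 2
    y[..., :N].clamp_(max=limit)
    y[..., N:].clamp_(min=-limit, max=limit)
    assert torch.equal(y, ref)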
@@ -2071,9 +2071,17 @@ class Scheduler(
         # as the space for the chunked requests has just been released.
         # In PP case, chunked requests (or dllm requests) can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
         # Instead, we should always allow chunked requests to be added, otherwise, there will be a memory leak.
+        # issue1985 fix (sglang itself): the inline comment above explicitly says
+        # "Ignore the check if self.chunked_req is not None", but the code below
+        # used `is not None`, which is the opposite. With --disable-radix-cache +
+        # hybrid SWA, the chunked_req keeps holding its req_pool slot across
+        # chunks (ChunkCache.cache_unfinished_req does not release), so when the
+        # req pool only has one slot effectively (ReqToTokenPool initializes
+        # free_slots = list(range(1, size)), wasting index 0), available_size
+        # becomes 0 and this early-return fires forever -> silent hang.
         if (
             self.get_num_allocatable_reqs(running_bs) <= 0
-            and self.chunked_req is not None
+            and self.chunked_req is None
             and not self.try_preemption
         ):
             self.running_batch.batch_is_full = True
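A toy walk-through of the starvation the comment describes (pool internals simplified; the real ReqToTokenPool tracks more state):

    size = 2                            # effective req pool size
    free_slots = list(range(1, size))   # == [1]; index 0 is never handed out
    slot = free_slots.pop()             # chunked_req takes the only slot
    # ChunkCache.cache_unfinished_req keeps the slot between chunks, so on
    # the next iteration len(free_slots) == 0; with the inverted check the
    # scheduler refused to extend the very request it was waiting on.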
@@ -412,6 +412,17 @@ class SchedulerRuntimeCheckerMixin:
         self.tree_cache.sanity_check()

     def self_check_during_idle(self: Scheduler):
+        # issue1985 fix (sglang itself): skip self-check when any in-flight work
+        # is still holding KV slots. With --disable-radix-cache + hybrid SWA +
+        # chunked prefill, between chunks the chunked_req keeps the chunk's
+        # slots in `used` state but they do not count toward evictable nor
+        # protected. _check_hybrid_memory's `used != 0` then false-positives.
+        if (
+            self.chunked_req is not None
+            or not self.running_batch.is_empty()
+            or len(self.waiting_queue) > 0
+        ):
+            return
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             if len(self.disagg_prefill_inflight_queue) > 0:
                 return
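A hedged sketch of the check asymmetry the guard works around (signatures are illustrative, distilled from the commit text rather than the actual SchedulerRuntimeCheckerMixin code):

    def check_hybrid(full_num_used: int) -> bool:
        # treats any nonzero usage as a leak -> false-positives between
        # chunks, when in-flight slots are legitimately still held
        return full_num_used != 0

    def check_radix_cache(num_used: int, protected_size: int) -> bool:
        # subtracts what is still legitimately protected before comparing
        return (num_used - protected_size) != 0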