diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index 58331884..66f6fd31 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -757,6 +757,8 @@ PYBIND11_MODULE(kt_kernel_ext, m) { .def_readwrite("down_type", &GeneralMOEConfig::down_type) .def_readwrite("hidden_type", &GeneralMOEConfig::hidden_type) .def_readwrite("max_cache_depth", &GeneralMOEConfig::max_cache_depth) + // V4-Flash 2604B SwiGLU clamp limit (0.0 = disabled). See common.hpp. + .def_readwrite("swiglu_limit", &GeneralMOEConfig::swiglu_limit) ; diff --git a/kt-kernel/operators/amx/la/amx.hpp b/kt-kernel/operators/amx/la/amx.hpp index 6281ca05..59b16e4c 100644 --- a/kt-kernel/operators/amx/la/amx.hpp +++ b/kt-kernel/operators/amx/la/amx.hpp @@ -44,7 +44,23 @@ static inline __m512 exp_avx512(__m512 x) { return _mm512_mul_ps(two_pow_i, frac_exp); } -static inline __m512 act_fn(__m512 gate_val, __m512 up_val) { +static inline __m512 act_fn(__m512 gate_val, __m512 up_val, float swiglu_limit = 0.0f) { + // DeepSeek V4-Flash 2604B asymmetric SwiGLU clamp. swiglu_limit > 0 + // applies the same clamp the trtllm `gemm1_clamp_limit` and the sglang + // deep_gemm `_apply_swiglu_limit` use: + // gate = clamp(gate, max=limit) // one-sided (pre-silu) + // up = clamp(up, min=-limit, max=limit) // symmetric + // The branch is on a runtime float; for swiglu_limit==0.0f (every non- + // MXFP4 dtype today) the predictor stays on the fall-through path and + // adds one cmp+jmp per call. NOTE(review): each call handles one __m512 (16 fp32 lanes) and the moe_base.hpp caller invokes act_fn twice per 32 bf16 values, so the cost is per call rather than per 32-lane tile unless the compiler merges the two checks. + // Origin: kt-sglang coupling. NOTE(review): a negative or NaN swiglu_limit fails the `> 0.0f` test below, so the clamp is silently skipped. 
+ if (swiglu_limit > 0.0f) { + const __m512 pos_lim = _mm512_set1_ps(swiglu_limit); + const __m512 neg_lim = _mm512_set1_ps(-swiglu_limit); + gate_val = _mm512_min_ps(gate_val, pos_lim); + up_val = _mm512_min_ps(up_val, pos_lim); + up_val = _mm512_max_ps(up_val, neg_lim); + } __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val); // Clamp neg_gate_val to avoid exp overflow (exp(88) overflows for float32) const __m512 max_exp_input = _mm512_set1_ps(88.0f); diff --git a/kt-kernel/operators/amx/moe_base.hpp b/kt-kernel/operators/amx/moe_base.hpp index 09244384..e56e83af 100644 --- a/kt-kernel/operators/amx/moe_base.hpp +++ b/kt-kernel/operators/amx/moe_base.hpp @@ -685,8 +685,8 @@ class AMX_MOE_BASE { __m512 gate_val0, gate_val1, up_val0, up_val1; avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1); avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1); - __m512 result0 = amx::act_fn(gate_val0, up_val0); - __m512 result1 = amx::act_fn(gate_val1, up_val1); + __m512 result0 = amx::act_fn(gate_val0, up_val0, config_.swiglu_limit); + __m512 result1 = amx::act_fn(gate_val1, up_val1, config_.swiglu_limit); avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j)); } } diff --git a/kt-kernel/operators/common.hpp b/kt-kernel/operators/common.hpp index 0800f8c5..da7bbd4b 100644 --- a/kt-kernel/operators/common.hpp +++ b/kt-kernel/operators/common.hpp @@ -308,6 +308,17 @@ struct GeneralMOEConfig { int max_cache_depth = 1; + // SwiGLU asymmetric clamp applied to gate/up before silu*up. 0.0f = + // disabled (default for all non-MXFP4 paths). Set to e.g. 10.0f for + // DeepSeek V4-Flash 2604B routed experts, matching the trtllm + // `gemm1_clamp_limit` and the sglang deep_gemm path's + // `_apply_swiglu_limit`: + // gate = clamp(gate, max=limit) // one-sided (silu input) + // up = clamp(up, min=-limit, max=limit) // symmetric + // Read by `act_fn` in la/amx.hpp, which clamps only when the value is > 0; non-zero only for MXFP4 today. 
+ // Origin: kt-sglang coupling (carries the V4-2604B limit set by the sglang side). + float swiglu_limit = 0.0f; + GeneralMOEConfig() {} GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 24a92dea..2e99bf33 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -143,6 +143,12 @@ class KTMoEWrapper: # Quantization config (for K-Group SFT methods) group_size: int = 128, zero_point: bool = True, + # V4-Flash 2604B SwiGLU clamp limit. 0.0 = disabled (default for + # every dtype except DSV4-2604B routed experts, which set this to + # 10.0 to match trtllm gemm1_clamp_limit / deep_gemm + # _apply_swiglu_limit). Plumbed into MOEConfig.swiglu_limit and + # consumed by amx::act_fn. Origin: kt-sglang coupling. + swiglu_limit: float = 0.0, ): """ Factory method to create the appropriate backend implementation. @@ -213,8 +219,17 @@ class KTMoEWrapper: max_deferred_experts_per_token=max_deferred_experts_per_token, method=method, numa_nodes=numa_nodes, + swiglu_limit=swiglu_limit, ) else: # mode == "sft" + # SFT factory does not plumb swiglu_limit; reject non-zero + # rather than dropping it on the floor. Origin: kt-sglang coupling. + if swiglu_limit != 0.0: + raise ValueError( + f"swiglu_limit={swiglu_limit} is not supported in " + f"mode='sft' (method={method!r}); SFT backends do not " + f"implement the V4-2604B clamp." + ) return _create_sft_wrapper( layer_idx=layer_idx, num_experts=num_experts, @@ -300,6 +315,7 @@ def _create_inference_wrapper( max_deferred_experts_per_token: Optional[int], method: str, numa_nodes: Optional[List[int]] = None, + swiglu_limit: float = 0.0, ) -> BaseMoEWrapper: """ Create an inference wrapper based on the method. 
@@ -323,7 +339,24 @@ def _create_inference_wrapper( # This shouldn't happen due to validation in __new__ raise NotImplementedError(f"Unsupported inference method: {method}") - # Create and return backend instance + # Create and return backend instance. + # `swiglu_limit != 0` is meaningful only on the MXFP4 path. NativeMoEWrapper + # also serves RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4, so a + # `backend_cls is NativeMoEWrapper` test would silently forward a stale + # 10.0 (e.g., from a leftover SGLANG_DSV4_2604_SUBMODE=2604B in the env) + # into a non-MXFP4 backend; act_fn would then clamp gate to <=10 and up to ±10 with + # no warning. Gate strictly on method instead. Origin: kt-sglang coupling. + extra_kwargs = {} + if method == "MXFP4": + extra_kwargs["swiglu_limit"] = swiglu_limit + elif swiglu_limit != 0.0: + raise ValueError( + f"swiglu_limit={swiglu_limit} is only supported on method='MXFP4', " + f"got method={method!r} (backend={backend_cls.__name__}). This " + f"usually means SGLANG_DSV4_2604_SUBMODE=2604B is set in the " + f"environment while the current launch does not actually use " + f"MXFP4 weights — either unset the env or pass --kt-method MXFP4." + ) return backend_cls( layer_idx=layer_idx, num_experts=num_experts, @@ -339,6 +372,7 @@ def _create_inference_wrapper( max_deferred_experts_per_token=max_deferred_experts_per_token, method=method, numa_nodes=numa_nodes, + **extra_kwargs, ) diff --git a/kt-kernel/python/experts_base.py b/kt-kernel/python/experts_base.py index d8fb657b..bc58b98f 100644 --- a/kt-kernel/python/experts_base.py +++ b/kt-kernel/python/experts_base.py @@ -248,6 +248,7 @@ class BaseMoEWrapper(_MoEBase, ABC): max_deferred_experts_per_token: Optional[int] = None, method: str = "AMXINT4", numa_nodes: Optional[List[int]] = None, + swiglu_limit: float = 0.0, ): """ Initialize base MoE Wrapper. 
@@ -302,6 +303,11 @@ class BaseMoEWrapper(_MoEBase, ABC): BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False self.method = method + # V4-Flash 2604B SwiGLU clamp limit; 0.0 = disabled. NativeMoEWrapper + # (MXFP4 path) reads this in load_weights() and writes it into + # MOEConfig.swiglu_limit. Other backends ignore it (C++ act_fn skips + # the clamp branch when the limit is not > 0). Origin: kt-sglang coupling. + self.swiglu_limit = float(swiglu_limit) # Initialize CPU inference engine (singleton via shared base class) self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count, numa_nodes=numa_nodes) diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index 917cb0cc..955db324 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -461,7 +461,17 @@ class NativeMoEWrapper(BaseMoEWrapper): max_deferred_experts_per_token: Optional[int] = None, method: str = "RAWINT4", numa_nodes: Optional[List[int]] = None, + swiglu_limit: float = 0.0, ): + # Defence in depth: reject swiglu_limit on non-MXFP4 methods even + # if the experts.py guard is bypassed (e.g., by a future caller + # that constructs NativeMoEWrapper directly). Origin: kt-sglang coupling. + if swiglu_limit != 0.0 and method != "MXFP4": + raise ValueError( + f"NativeMoEWrapper received swiglu_limit={swiglu_limit} with " + f"method={method!r}; the V4-2604B clamp only applies to MXFP4. " + f"This indicates a missing guard in the caller." 
+ ) if method == "RAWINT4" and not ( _HAS_RAWINT4_SUPPORT or _HAS_AVX2_RAWINT4_SUPPORT or _HAS_AVXVNNI256_RAW_INT4_SUPPORT ): @@ -520,6 +530,7 @@ class NativeMoEWrapper(BaseMoEWrapper): max_deferred_experts_per_token=max_deferred_experts_per_token, method=method, numa_nodes=numa_nodes, + swiglu_limit=swiglu_limit, ) if NativeMoEWrapper._native_loader_instance is None: @@ -637,6 +648,20 @@ class NativeMoEWrapper(BaseMoEWrapper): moe_config.layer_idx = self.layer_idx moe_config.pool = self.cpu_infer.backend_ moe_config.max_len = self.chunked_prefill_size + # V4-Flash 2604B SwiGLU clamp; 0.0 = disabled (default for non-MXFP4 + # paths). Read by `act_fn` in operators/amx/la/amx.hpp via + # `apply_activation` in operators/amx/moe_base.hpp. Re-checked here + # (defence in depth) so a future caller that bypasses both the + # experts.py and the __init__ guards still cannot apply the clamp + # on RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4 paths. + # Origin: kt-sglang coupling. NOTE(review): a negative or NaN limit passes the `!= 0.0` guard below but fails act_fn's `> 0.0f` test, so it is silently treated as disabled. + if self.swiglu_limit != 0.0 and self.method != "MXFP4": + raise ValueError( + f"NativeMoEWrapper.load_weights: swiglu_limit=" + f"{self.swiglu_limit} with method={self.method!r}; clamp is " + f"only valid for MXFP4." + ) + moe_config.swiglu_limit = self.swiglu_limit # Use gate_projs instead of gate_proj for per-expert pointers moe_config.gate_projs = gate_ptrs diff --git a/setup.py b/setup.py index 05938c96..abc62d0b 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup( "accelerate-kt==1.14.0.post1", ], "sglang": [ - "sglang-kt==0.6.2.post2", + "sglang-kt==0.6.2.post3", ], }, ) diff --git a/version.py b/version.py index c64c6391..09dff297 100644 --- a/version.py +++ b/version.py @@ -3,4 +3,4 @@ KTransformers version information. Shared across the top-level package and kt-kernel. """ -__version__ = "0.6.2.post2" +__version__ = "0.6.2.post3"