mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-24 06:34:50 +00:00
Release/0.6.2.post3: carry kt-kernel SwiGLU clamp companion missing from post2
This commit is contained in:
@@ -757,6 +757,8 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||
.def_readwrite("down_type", &GeneralMOEConfig::down_type)
|
||||
.def_readwrite("hidden_type", &GeneralMOEConfig::hidden_type)
|
||||
.def_readwrite("max_cache_depth", &GeneralMOEConfig::max_cache_depth)
|
||||
// V4-Flash 2604B SwiGLU clamp limit (0.0 = disabled). See common.hpp.
|
||||
.def_readwrite("swiglu_limit", &GeneralMOEConfig::swiglu_limit)
|
||||
|
||||
;
|
||||
|
||||
|
||||
@@ -44,7 +44,23 @@ static inline __m512 exp_avx512(__m512 x) {
|
||||
return _mm512_mul_ps(two_pow_i, frac_exp);
|
||||
}
|
||||
|
||||
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
|
||||
static inline __m512 act_fn(__m512 gate_val, __m512 up_val, float swiglu_limit = 0.0f) {
|
||||
// DeepSeek V4-Flash 2604B asymmetric SwiGLU clamp. swiglu_limit > 0
|
||||
// applies the same clamp the trtllm `gemm1_clamp_limit` and the sglang
|
||||
// deep_gemm `_apply_swiglu_limit` use:
|
||||
// gate = clamp(gate, max=limit) // one-sided (pre-silu)
|
||||
// up = clamp(up, min=-limit, max=limit) // symmetric
|
||||
// The branch is on a runtime float; for swiglu_limit==0.0f (every non-
|
||||
// MXFP4 dtype today) the predictor stays on the fall-through path and
|
||||
// adds at most one cmp+jmp per act_fn call (one 16-lane fp32 vector).
|
||||
// Origin: kt-sglang coupling.
|
||||
if (swiglu_limit > 0.0f) {
|
||||
const __m512 pos_lim = _mm512_set1_ps(swiglu_limit);
|
||||
const __m512 neg_lim = _mm512_set1_ps(-swiglu_limit);
|
||||
gate_val = _mm512_min_ps(gate_val, pos_lim);
|
||||
up_val = _mm512_min_ps(up_val, pos_lim);
|
||||
up_val = _mm512_max_ps(up_val, neg_lim);
|
||||
}
|
||||
__m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
|
||||
// Clamp neg_gate_val to avoid exp overflow (exp(88) overflows for float32)
|
||||
const __m512 max_exp_input = _mm512_set1_ps(88.0f);
|
||||
|
||||
@@ -685,8 +685,8 @@ class AMX_MOE_BASE {
|
||||
__m512 gate_val0, gate_val1, up_val0, up_val1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
|
||||
__m512 result0 = amx::act_fn(gate_val0, up_val0);
|
||||
__m512 result1 = amx::act_fn(gate_val1, up_val1);
|
||||
__m512 result0 = amx::act_fn(gate_val0, up_val0, config_.swiglu_limit);
|
||||
__m512 result1 = amx::act_fn(gate_val1, up_val1, config_.swiglu_limit);
|
||||
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,6 +308,17 @@ struct GeneralMOEConfig {
|
||||
|
||||
int max_cache_depth = 1;
|
||||
|
||||
// SwiGLU asymmetric clamp applied to gate/up before silu*up. 0.0f =
|
||||
// disabled (default for all non-MXFP4 paths). Set to e.g. 10.0f for
|
||||
// DeepSeek V4-Flash 2604B routed experts, matching the trtllm
|
||||
// `gemm1_clamp_limit` and the sglang deep_gemm path's
|
||||
// `_apply_swiglu_limit`:
|
||||
// gate = clamp(gate, max=limit) // one-sided (silu input)
|
||||
// up = clamp(up, min=-limit, max=limit) // symmetric
|
||||
// Read by `act_fn` in la/amx.hpp; non-zero only for MXFP4 today.
|
||||
// Origin: kt-sglang coupling (carries the V4-2604B limit set by sglang side).
|
||||
float swiglu_limit = 0.0f;
|
||||
|
||||
GeneralMOEConfig() {}
|
||||
|
||||
GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)
|
||||
|
||||
@@ -143,6 +143,12 @@ class KTMoEWrapper:
|
||||
# Quantization config (for K-Group SFT methods)
|
||||
group_size: int = 128,
|
||||
zero_point: bool = True,
|
||||
# V4-Flash 2604B SwiGLU clamp limit. 0.0 = disabled (default for
|
||||
# every dtype except DSV4-2604B routed experts, which set this to
|
||||
# 10.0 to match trtllm gemm1_clamp_limit / deep_gemm
|
||||
# _apply_swiglu_limit). Plumbed into MOEConfig.swiglu_limit and
|
||||
# consumed by amx::act_fn. Origin: kt-sglang coupling.
|
||||
swiglu_limit: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Factory method to create the appropriate backend implementation.
|
||||
@@ -213,8 +219,17 @@ class KTMoEWrapper:
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
else: # mode == "sft"
|
||||
# SFT factory does not plumb swiglu_limit; reject non-zero
|
||||
# rather than dropping it on the floor. Origin: kt-sglang coupling.
|
||||
if swiglu_limit != 0.0:
|
||||
raise ValueError(
|
||||
f"swiglu_limit={swiglu_limit} is not supported in "
|
||||
f"mode='sft' (method={method!r}); SFT backends do not "
|
||||
f"implement the V4-2604B clamp."
|
||||
)
|
||||
return _create_sft_wrapper(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
@@ -300,6 +315,7 @@ def _create_inference_wrapper(
|
||||
max_deferred_experts_per_token: Optional[int],
|
||||
method: str,
|
||||
numa_nodes: Optional[List[int]] = None,
|
||||
swiglu_limit: float = 0.0,
|
||||
) -> BaseMoEWrapper:
|
||||
"""
|
||||
Create an inference wrapper based on the method.
|
||||
@@ -323,7 +339,24 @@ def _create_inference_wrapper(
|
||||
# This shouldn't happen due to validation in __new__
|
||||
raise NotImplementedError(f"Unsupported inference method: {method}")
|
||||
|
||||
# Create and return backend instance
|
||||
# Create and return backend instance.
|
||||
# `swiglu_limit != 0` is meaningful only on the MXFP4 path. NativeMoEWrapper
|
||||
# also serves RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4, so a
|
||||
# `backend_cls is NativeMoEWrapper` test would silently forward a stale
|
||||
# 10.0 (e.g., from a leftover SGLANG_DSV4_2604_SUBMODE=2604B in the env)
|
||||
# into a non-MXFP4 backend; act_fn would then clamp gate/up to ±10 with
|
||||
# no warning. Gate strictly on method instead. Origin: kt-sglang coupling.
|
||||
extra_kwargs = {}
|
||||
if method == "MXFP4":
|
||||
extra_kwargs["swiglu_limit"] = swiglu_limit
|
||||
elif swiglu_limit != 0.0:
|
||||
raise ValueError(
|
||||
f"swiglu_limit={swiglu_limit} is only supported on method='MXFP4', "
|
||||
f"got method={method!r} (backend={backend_cls.__name__}). This "
|
||||
f"usually means SGLANG_DSV4_2604_SUBMODE=2604B is set in the "
|
||||
f"environment while the current launch does not actually use "
|
||||
f"MXFP4 weights — either unset the env or pass --kt-method MXFP4."
|
||||
)
|
||||
return backend_cls(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
@@ -339,6 +372,7 @@ def _create_inference_wrapper(
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -248,6 +248,7 @@ class BaseMoEWrapper(_MoEBase, ABC):
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "AMXINT4",
|
||||
numa_nodes: Optional[List[int]] = None,
|
||||
swiglu_limit: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Initialize base MoE Wrapper.
|
||||
@@ -302,6 +303,11 @@ class BaseMoEWrapper(_MoEBase, ABC):
|
||||
|
||||
BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False
|
||||
self.method = method
|
||||
# V4-Flash 2604B SwiGLU clamp limit; 0.0 = disabled. NativeMoEWrapper
|
||||
# (MXFP4 path) reads this in load_weights() and writes it into
|
||||
# MOEConfig.swiglu_limit. Other backends ignore it (C++ act_fn skips
|
||||
# the clamp branch when limit==0). Origin: kt-sglang coupling.
|
||||
self.swiglu_limit = float(swiglu_limit)
|
||||
|
||||
# Initialize CPU inference engine (singleton via shared base class)
|
||||
self.cpu_infer = self._get_cpu_infer(cpuinfer_threads, threadpool_count, numa_nodes=numa_nodes)
|
||||
|
||||
@@ -461,7 +461,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "RAWINT4",
|
||||
numa_nodes: Optional[List[int]] = None,
|
||||
swiglu_limit: float = 0.0,
|
||||
):
|
||||
# Defence in depth: reject swiglu_limit on non-MXFP4 methods even
|
||||
# if the experts.py guard is bypassed (e.g., by a future caller
|
||||
# that constructs NativeMoEWrapper directly). Origin: kt-sglang coupling.
|
||||
if swiglu_limit != 0.0 and method != "MXFP4":
|
||||
raise ValueError(
|
||||
f"NativeMoEWrapper received swiglu_limit={swiglu_limit} with "
|
||||
f"method={method!r}; the V4-2604B clamp only applies to MXFP4. "
|
||||
f"This indicates a missing guard in the caller."
|
||||
)
|
||||
if method == "RAWINT4" and not (
|
||||
_HAS_RAWINT4_SUPPORT or _HAS_AVX2_RAWINT4_SUPPORT or _HAS_AVXVNNI256_RAW_INT4_SUPPORT
|
||||
):
|
||||
@@ -520,6 +530,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
if NativeMoEWrapper._native_loader_instance is None:
|
||||
@@ -637,6 +648,20 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
# V4-Flash 2604B SwiGLU clamp; 0.0 = disabled (default for non-MXFP4
|
||||
# paths). Read by `act_fn` in operators/amx/la/amx.hpp via
|
||||
# `apply_activation` in operators/amx/moe_base.hpp. Re-checked here
|
||||
# (defence in depth) so a future caller that bypasses both the
|
||||
# experts.py and the __init__ guards still cannot apply the clamp
|
||||
# on RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4 paths.
|
||||
# Origin: kt-sglang coupling.
|
||||
if self.swiglu_limit != 0.0 and self.method != "MXFP4":
|
||||
raise ValueError(
|
||||
f"NativeMoEWrapper.load_weights: swiglu_limit="
|
||||
f"{self.swiglu_limit} with method={self.method!r}; clamp is "
|
||||
f"only valid for MXFP4."
|
||||
)
|
||||
moe_config.swiglu_limit = self.swiglu_limit
|
||||
|
||||
# Use gate_projs instead of gate_proj for per-expert pointers
|
||||
moe_config.gate_projs = gate_ptrs
|
||||
|
||||
2
setup.py
2
setup.py
@@ -23,7 +23,7 @@ setup(
|
||||
"accelerate-kt==1.14.0.post1",
|
||||
],
|
||||
"sglang": [
|
||||
"sglang-kt==0.6.2.post2",
|
||||
"sglang-kt==0.6.2.post3",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -3,4 +3,4 @@ KTransformers version information.
|
||||
Shared across the top-level package and kt-kernel.
|
||||
"""
|
||||
|
||||
__version__ = "0.6.2.post2"
|
||||
__version__ = "0.6.2.post3"
|
||||
|
||||
Reference in New Issue
Block a user