mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-16 10:29:42 +00:00
Release/0.6.2.post3: carry kt-kernel SwiGLU clamp companion missing from post2
This commit is contained in:
@@ -461,7 +461,17 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token: Optional[int] = None,
|
||||
method: str = "RAWINT4",
|
||||
numa_nodes: Optional[List[int]] = None,
|
||||
swiglu_limit: float = 0.0,
|
||||
):
|
||||
# Defence in depth: reject swiglu_limit on non-MXFP4 methods even
|
||||
# if the experts.py guard is bypassed (e.g., by a future caller
|
||||
# that constructs NativeMoEWrapper directly). Origin: kt-sglang coupling.
|
||||
if swiglu_limit != 0.0 and method != "MXFP4":
|
||||
raise ValueError(
|
||||
f"NativeMoEWrapper received swiglu_limit={swiglu_limit} with "
|
||||
f"method={method!r}; the V4-2604B clamp only applies to MXFP4. "
|
||||
f"This indicates a missing guard in the caller."
|
||||
)
|
||||
if method == "RAWINT4" and not (
|
||||
_HAS_RAWINT4_SUPPORT or _HAS_AVX2_RAWINT4_SUPPORT or _HAS_AVXVNNI256_RAW_INT4_SUPPORT
|
||||
):
|
||||
@@ -520,6 +530,7 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
if NativeMoEWrapper._native_loader_instance is None:
|
||||
@@ -637,6 +648,20 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
# V4-Flash 2604B SwiGLU clamp; 0.0 = disabled (default for non-MXFP4
|
||||
# paths). Read by `act_fn` in operators/amx/la/amx.hpp via
|
||||
# `apply_activation` in operators/amx/moe_base.hpp. Re-checked here
|
||||
# (defence in depth) so a future caller that bypasses both the
|
||||
# experts.py and the __init__ guards still cannot apply the clamp
|
||||
# on RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4 paths.
|
||||
# Origin: kt-sglang coupling.
|
||||
if self.swiglu_limit != 0.0 and self.method != "MXFP4":
|
||||
raise ValueError(
|
||||
f"NativeMoEWrapper.load_weights: swiglu_limit="
|
||||
f"{self.swiglu_limit} with method={self.method!r}; clamp is "
|
||||
f"only valid for MXFP4."
|
||||
)
|
||||
moe_config.swiglu_limit = self.swiglu_limit
|
||||
|
||||
# Use gate_projs instead of gate_proj for per-expert pointers
|
||||
moe_config.gate_projs = gate_ptrs
|
||||
|
||||
Reference in New Issue
Block a user