mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-20 04:19:17 +00:00
Release/0.6.2.post3: carry kt-kernel SwiGLU clamp companion missing from post2
This commit is contained in:
@@ -143,6 +143,12 @@ class KTMoEWrapper:
|
||||
# Quantization config (for K-Group SFT methods)
|
||||
group_size: int = 128,
|
||||
zero_point: bool = True,
|
||||
# V4-Flash 2604B SwiGLU clamp limit. 0.0 = disabled (default for
|
||||
# every dtype except DSV4-2604B routed experts, which set this to
|
||||
# 10.0 to match trtllm gemm1_clamp_limit / deep_gemm
|
||||
# _apply_swiglu_limit). Plumbed into MOEConfig.swiglu_limit and
|
||||
# consumed by amx::act_fn. Origin: kt-sglang coupling.
|
||||
swiglu_limit: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Factory method to create the appropriate backend implementation.
|
||||
@@ -213,8 +219,17 @@ class KTMoEWrapper:
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
else: # mode == "sft"
|
||||
# SFT factory does not plumb swiglu_limit; reject non-zero
|
||||
# rather than dropping it on the floor. Origin: kt-sglang coupling.
|
||||
if swiglu_limit != 0.0:
|
||||
raise ValueError(
|
||||
f"swiglu_limit={swiglu_limit} is not supported in "
|
||||
f"mode='sft' (method={method!r}); SFT backends do not "
|
||||
f"implement the V4-2604B clamp."
|
||||
)
|
||||
return _create_sft_wrapper(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
@@ -300,6 +315,7 @@ def _create_inference_wrapper(
|
||||
max_deferred_experts_per_token: Optional[int],
|
||||
method: str,
|
||||
numa_nodes: Optional[List[int]] = None,
|
||||
swiglu_limit: float = 0.0,
|
||||
) -> BaseMoEWrapper:
|
||||
"""
|
||||
Create an inference wrapper based on the method.
|
||||
@@ -323,7 +339,24 @@ def _create_inference_wrapper(
|
||||
# This shouldn't happen due to validation in __new__
|
||||
raise NotImplementedError(f"Unsupported inference method: {method}")
|
||||
|
||||
# Create and return backend instance
|
||||
# Create and return backend instance.
|
||||
# `swiglu_limit != 0` is meaningful only on the MXFP4 path. NativeMoEWrapper
|
||||
# also serves RAWINT4 / FP8 / BF16 / FP8_PERCHANNEL / GPTQ_INT4, so a
|
||||
# `backend_cls is NativeMoEWrapper` test would silently forward a stale
|
||||
# 10.0 (e.g., from a leftover SGLANG_DSV4_2604_SUBMODE=2604B in the env)
|
||||
# into a non-MXFP4 backend; act_fn would then clamp gate/up to ±10 with
|
||||
# no warning. Gate strictly on method instead. Origin: kt-sglang coupling.
|
||||
extra_kwargs = {}
|
||||
if method == "MXFP4":
|
||||
extra_kwargs["swiglu_limit"] = swiglu_limit
|
||||
elif swiglu_limit != 0.0:
|
||||
raise ValueError(
|
||||
f"swiglu_limit={swiglu_limit} is only supported on method='MXFP4', "
|
||||
f"got method={method!r} (backend={backend_cls.__name__}). This "
|
||||
f"usually means SGLANG_DSV4_2604_SUBMODE=2604B is set in the "
|
||||
f"environment while the current launch does not actually use "
|
||||
f"MXFP4 weights — either unset the env or pass --kt-method MXFP4."
|
||||
)
|
||||
return backend_cls(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=num_experts,
|
||||
@@ -339,6 +372,7 @@ def _create_inference_wrapper(
|
||||
max_deferred_experts_per_token=max_deferred_experts_per_token,
|
||||
method=method,
|
||||
numa_nodes=numa_nodes,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user