mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-05-19 11:59:19 +00:00
[New Model] DeepSeek-V4-Flash: kt-kernel MXFP4 MoE + sglang hybrid inference (#1970)
* [feat](kt-kernel): add MXFP4 MoE operator with E2M1 weights × BF16 activations Implements AMX_FP4_MOE_TP based on the RAWINT4 (k2-moe) CRTP pattern. FP4 E2M1 weights are nibble-packed and decoded via PSHUFB LUT, then computed with BF16 activations using _mm512_dpbf16_ps. Supports weight-only per-kgroup scaling (group_size=32) and tensor parallelism. Includes a Python validation test covering uniform, alternating, ramp, and random weight patterns. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * [feat](kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950) V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}` naming and converts ue8m0 → bf16 via a lossless bit-cast, register the model entry, and ship an end-to-end numerical validation script. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * [perf](kt-kernel): MXFP4 MoE add mat-mat 4×4 tile, refine mat-vec reduce (#1957) mat_mul_kgroup previously aliased to fp4_mat_vec_kgroup, leaving large batches stuck on the per-token path. Implement fp4_mat_mat_kgroup as a 4×4 register tile (MB=NB=4, 16 zmm accumulators) so each PSHUFB decode of four weight rows is reused across four tokens. Refactor fp4_mat_vec_kgroup to accumulate four N-rows in parallel and flush them with a new reduce4 helper, removing per-row reduce_add_ps calls from the hot loop. Mark mxfp4_to_bf16_32 always_inline. Add bench/bench_fp4_moe.py with --routing {balanced,concentrated} and a backend registry so future kernels can be added without changing the runner. Dispatch thresholds, derived_init, GeneralMOEConfig handling, load_weights, write_weights_to_buffer and the TP_MOE specialization are unchanged. 
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(loader): avoid uint16 lshift in ue8m0->bf16 conversion PyTorch CPU has no lshift kernel for UInt16, so the previous `(scale_t.to(torch.uint16) << 7)` raised NotImplementedError when loading any V4-Flash MXFP4 routed-expert scale tensor on the host. Switch to int32 for the shift (kernel exists) and narrow to int16 afterwards. The shifted value max is 255<<7 = 32640, well within int16 range, so the narrow is lossless. The .view(bfloat16) bit pattern is identical (bf16 sign bit is always 0 for ue8m0 values). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs(v4-flash): hybrid CPU/GPU recipe + bump kt-sglang submodule Bumps third_party/sglang to kvcache-ai/sglang main (3cbd49c29) which now contains DeepSeek V4 Flash model support + consumer-GPU (SM_120) portable Triton/TileLang fallbacks (kt-sglang PR #38). Adds doc/en/DeepSeek-V4-Flash.md tutorial: 8x RTX 5090 hybrid recipe with the full launch command, OpenAI-compatible /generate + /v1/chat/completions examples, and the kt chat CLI client. --------- Co-authored-by: ouqingliang <1692110604@qq.com> Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,26 @@ BUILTIN_MODELS: list[ModelInfo] = [
|
||||
description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)",
|
||||
description_zh="DeepSeek R1-0528 推理模型(2025年5月,改进的推理深度)",
|
||||
),
|
||||
ModelInfo(
|
||||
name="DeepSeek-V4-Flash",
|
||||
hf_repo="deepseek-ai/DeepSeek-V4-Flash",
|
||||
aliases=["deepseek-v4-flash", "deepseek-v4", "dsv4", "v4-flash", "v4"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-method": "MXFP4",
|
||||
"kt-gpu-prefill-token-threshold": 4096,
|
||||
"attention-backend": "flashinfer",
|
||||
"max-total-tokens": 100000,
|
||||
"max-running-requests": 16,
|
||||
"chunked-prefill-size": 32768,
|
||||
"mem-fraction-static": 0.80,
|
||||
"watchdog-timeout": 3000,
|
||||
"served-model-name": "DeepSeek-V4-Flash",
|
||||
"disable-shared-experts-fusion": True,
|
||||
},
|
||||
description="DeepSeek V4-Flash MoE model (native MXFP4 experts, MQA + sparse index attention)",
|
||||
description_zh="DeepSeek V4-Flash MoE 模型(原生 MXFP4 专家,MQA + 稀疏索引注意力)",
|
||||
),
|
||||
ModelInfo(
|
||||
name="Kimi-K2-Thinking",
|
||||
hf_repo="moonshotai/Kimi-K2-Thinking",
|
||||
@@ -368,6 +388,19 @@ def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb:
|
||||
return total_vram // 3
|
||||
|
||||
|
||||
def compute_deepseek_v4_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Estimate ``kt-num-gpu-experts`` for DeepSeek-V4-Flash.

    The first 16 GB of every GPU are left out of the expert budget
    (presumably room for non-expert weights and runtime buffers -- confirm
    against the V3 sibling function). The remaining VRAM, summed over all
    tensor-parallel ranks and truncated to whole GB, is converted at two
    experts per three GB. V4 experts are MXFP4 (~0.5 bytes/param) rather than
    V3's FP8 (1 byte/param), hence twice as many experts per VRAM unit.

    Args:
        tensor_parallel_size: Number of GPUs sharing the model.
        vram_per_gpu_gb: VRAM available on each GPU, in GB.

    Returns:
        The number of experts to place on GPU; 0 when a GPU has less than
        16 GB.
    """
    reserved_gb = 16
    if vram_per_gpu_gb < reserved_gb:
        return 0

    # Pool the spare VRAM of all ranks, truncating to whole GB.
    spare_gb = int(tensor_parallel_size * (vram_per_gpu_gb - reserved_gb))
    return (2 * spare_gb) // 3
|
||||
|
||||
|
||||
def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
|
||||
"""Compute kt-num-gpu-experts for Kimi K2 Thinking."""
|
||||
per_gpu_gb = 16
|
||||
@@ -393,6 +426,7 @@ MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
|
||||
"DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
|
||||
"DeepSeek-V3.2": compute_deepseek_v3_gpu_experts, # Same as V3-0324
|
||||
"DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts, # Same as V3-0324
|
||||
"DeepSeek-V4-Flash": compute_deepseek_v4_gpu_experts,
|
||||
"Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
|
||||
"MiniMax-M2": compute_minimax_m2_gpu_experts,
|
||||
"MiniMax-M2.1": compute_minimax_m2_gpu_experts, # Same as M2
|
||||
|
||||
@@ -40,6 +40,7 @@ INFERENCE_METHODS = frozenset(
|
||||
"BF16", # BF16 native MoE
|
||||
"FP8_PERCHANNEL", # Per-channel FP8
|
||||
"GPTQ_INT4", # GPTQ INT4
|
||||
"MXFP4", # MXFP4 (E2M1 nibble + ue8m0 group scale, e.g. DeepSeek-V4-Flash routed experts)
|
||||
"LLAMAFILE", # GGUF format
|
||||
"MOE_INT4",
|
||||
"MOE_INT8", # General kernel
|
||||
@@ -312,7 +313,7 @@ def _create_inference_wrapper(
|
||||
# Select backend based on method
|
||||
if method in ["AMXINT4", "AMXINT8"]:
|
||||
backend_cls = AMXMoEWrapper
|
||||
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]:
|
||||
elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4", "MXFP4"]:
|
||||
backend_cls = NativeMoEWrapper
|
||||
elif method == "LLAMAFILE":
|
||||
backend_cls = LlamafileMoEWrapper
|
||||
|
||||
@@ -11,6 +11,7 @@ from .loader import (
|
||||
FP8SafeTensorLoader,
|
||||
BF16SafeTensorLoader,
|
||||
GPTQSafeTensorLoader,
|
||||
MXFP4SafeTensorLoader,
|
||||
)
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
import kt_kernel_ext.moe as _moe_mod
|
||||
@@ -18,6 +19,7 @@ import kt_kernel_ext.moe as _moe_mod
|
||||
AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None)
|
||||
AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None)
|
||||
AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None)
|
||||
AMXFP4_KGroup_MOE = getattr(_moe_mod, "AMXFP4_KGroup_MOE", None)
|
||||
AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None)
|
||||
AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None)
|
||||
AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None)
|
||||
@@ -31,6 +33,7 @@ AVXVNNI256RawInt4_MOE = getattr(_moe_mod, "AVXVNNI256RawInt4_MOE", None)
|
||||
_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None
|
||||
_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None
|
||||
_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None
|
||||
_HAS_MXFP4_SUPPORT = AMXFP4_KGroup_MOE is not None
|
||||
_HAS_FP8_SUPPORT = AMXFP8_MOE is not None
|
||||
_HAS_BF16_SUPPORT = AMXBF16_MOE is not None
|
||||
_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None
|
||||
@@ -495,6 +498,12 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
"Please recompile kt_kernel_ext with GPTQ INT4 support enabled.\n"
|
||||
"AVX-VNNI-256 will be selected automatically when available on the current CPU."
|
||||
)
|
||||
if method == "MXFP4" and not _HAS_MXFP4_SUPPORT:
|
||||
raise RuntimeError(
|
||||
"MXFP4 backend not available. Required ISA:\n"
|
||||
" - AVX512F + AVX512BW + AVX512_BF16\n"
|
||||
"Please recompile kt_kernel_ext with AVX512 + BF16 enabled."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
@@ -525,6 +534,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
|
||||
elif method == "GPTQ_INT4":
|
||||
NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path)
|
||||
elif method == "MXFP4":
|
||||
NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
@@ -592,6 +603,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]]
|
||||
self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]]
|
||||
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
|
||||
elif self.method == "MXFP4":
|
||||
# ue8m0 is losslessly representable in bf16 (8-bit exponent, 0 mantissa);
|
||||
# the loader has already done that conversion.
|
||||
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for MXFP4"
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
@@ -649,6 +664,14 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
f"{_AVXVNNI256_RAW_INT4_MAX_GROUP_SIZE}; AVX2 (AVX2RawInt4_MOE) is used as the final fallback."
|
||||
)
|
||||
self.moe = backend_cls(moe_config)
|
||||
elif self.method == "MXFP4":
|
||||
# MXFP4: E2M1 nibble-packed weights, ue8m0/bf16 per-32 group scale
|
||||
# (e.g. DeepSeek-V4-Flash routed experts)
|
||||
group_size = self.hidden_size // self.gate_scales[0].shape[1]
|
||||
moe_config.quant_config.bits = 4
|
||||
moe_config.quant_config.group_size = group_size
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXFP4_KGroup_MOE(moe_config)
|
||||
elif self.method == "FP8":
|
||||
moe_config.quant_config.bits = 8
|
||||
moe_config.quant_config.group_size = 128
|
||||
|
||||
@@ -1231,3 +1231,91 @@ class GPTQSafeTensorLoader(FP8SafeTensorLoader):
|
||||
"up_scale": up_scales,
|
||||
"down_scale": down_scales,
|
||||
}
|
||||
|
||||
|
||||
class MXFP4SafeTensorLoader(SafeTensorLoader):
    """Load native MXFP4 routed-expert tensors (DeepSeek-V4-Flash checkpoint layout).

    Tensors expected for expert ``i``:
        {base}.ffn.experts.{i}.w1.weight   I8      [N, K/2]   nibble-packed E2M1 (gate)
        {base}.ffn.experts.{i}.w1.scale    F8_E8M0 [N, K/32]  ue8m0 group scale
        {base}.ffn.experts.{i}.w3.{weight,scale}              up
        {base}.ffn.experts.{i}.w2.{weight,scale}              down

    V4 checkpoint keys carry no ``model.`` prefix, so when ``base_key`` starts
    with ``model.`` the stripped key is probed as well; callers may keep
    passing ``base_key="model.layers.{L}"``.

    Scales are returned as bf16 because that is what the AMX FP4 backend
    consumes. The ue8m0 byte is moved into the bf16 exponent field with a bit
    shift, which is exact for exponent bytes 1..254 (ue8m0 has no mantissa and
    both formats use an 8-bit exponent).
    """

    EXPERTS_PATH_TPL = "{base}.ffn.experts"
    PROJ_NAMES = ("w1", "w3", "w2")  # (gate, up, down)

    def _experts_prefix_candidates(self, base_key: str) -> list[str]:
        """Return the expert-key prefixes to probe, in priority order, without duplicates.

        The prefix built from ``base_key`` as given comes first; if ``base_key``
        starts with ``model.``, the prefix built from the stripped key follows.
        """
        # A dict keeps insertion order and drops repeated prefixes.
        ordered = {self.EXPERTS_PATH_TPL.format(base=base_key): None}
        if base_key.startswith("model."):
            stripped_key = base_key[len("model.") :]
            ordered.setdefault(self.EXPERTS_PATH_TPL.format(base=stripped_key), None)
        return list(ordered)

    @staticmethod
    def _ue8m0_to_bf16(scale_t: torch.Tensor) -> torch.Tensor:
        """Reinterpret ue8m0 exponent bytes as a contiguous bf16 tensor of the same shape.

        bf16 is laid out as [sign(1) | exp(8) | mant(7)], so placing the byte
        ``e`` in the exponent field with a zero mantissa yields 2^(e-127),
        which is the ue8m0 value for e in [1, 254]. e=0 becomes bf16 +0 (ue8m0
        0 stands for 2^-127, below the bf16 normal range) and e=255 becomes
        +inf.
        """
        exponent_bytes = scale_t if scale_t.dtype == torch.uint8 else scale_t.view(torch.uint8)
        # torch CPU has no lshift kernel for uint16: shift in int32, then
        # narrow to int16 (the largest result, 255 << 7 = 32640, fits).
        widened = exponent_bytes.to(torch.int32)
        bf16_bits = (widened << 7).to(torch.int16)
        return bf16_bits.view(torch.bfloat16).contiguous()

    def load_experts(self, base_key: str, device: str = "cpu"):
        """Load every expert's packed weights and bf16 scales found under ``base_key``.

        Experts are counted by probing ``{prefix}.{i}.{gate}.weight`` for
        i = 0, 1, ... until a key is missing; the first candidate prefix with
        at least one expert is used.

        Args:
            base_key: Layer key, with or without a leading ``model.``.
            device: Device argument forwarded to ``load_tensor``.

        Returns:
            A dict with keys ``gate``, ``up``, ``down`` (lists of uint8 packed
            weight tensors, indexed by expert) and ``gate_scale``,
            ``up_scale``, ``down_scale`` (lists of bf16 scale tensors).

        Raises:
            ValueError: If no candidate prefix holds any expert.
        """
        gate_name, up_name, down_name = self.PROJ_NAMES
        proj_order = (gate_name, up_name, down_name)

        chosen = None
        num_experts = 0
        for candidate in self._experts_prefix_candidates(base_key):
            idx = 0
            while self.has_tensor(f"{candidate}.{idx}.{gate_name}.weight"):
                idx += 1
            num_experts = idx
            if idx > 0:
                chosen = candidate
                break

        if chosen is None:
            raise ValueError(
                f"No MXFP4 experts found under any of: {self._experts_prefix_candidates(base_key)}"
            )

        # One accumulator per projection, in (gate, up, down) order.
        weights = ([], [], [])
        scales = ([], [], [])

        for exp_id in range(num_experts):
            # All three weights of an expert are read before its scales.
            for slot, proj in enumerate(proj_order):
                packed = self.load_tensor(f"{chosen}.{exp_id}.{proj}.weight", device).contiguous()
                if packed.dtype != torch.uint8:
                    # Reinterpret the packed nibbles as raw bytes.
                    packed = packed.view(torch.uint8)
                weights[slot].append(packed)

            for slot, proj in enumerate(proj_order):
                raw_scale = self.load_tensor(f"{chosen}.{exp_id}.{proj}.scale", device)
                scales[slot].append(self._ue8m0_to_bf16(raw_scale))

        print(f"[MXFP4SafeTensorLoader] Loaded {num_experts} experts from {chosen}")

        gate_weights, up_weights, down_weights = weights
        gate_scales, up_scales, down_scales = scales
        return {
            "gate": gate_weights,
            "up": up_weights,
            "down": down_weights,
            "gate_scale": gate_scales,
            "up_scale": up_scales,
            "down_scale": down_scales,
        }
|
||||
|
||||
Reference in New Issue
Block a user