[New Model] DeepSeek-V4-Flash: kt-kernel MXFP4 MoE + sglang hybrid inference (#1970)

* [feat](kt-kernel): add MXFP4 MoE operator with E2M1 weights × BF16 activations

Implements AMX_FP4_MOE_TP based on the RAWINT4 (k2-moe) CRTP pattern.
FP4 E2M1 weights are nibble-packed and decoded via PSHUFB LUT, then
computed with BF16 activations using _mm512_dpbf16_ps. Supports weight-only
per-kgroup scaling (group_size=32) and tensor parallelism.

Includes a Python validation test covering uniform, alternating, ramp,
and random weight patterns.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* [feat](kt-kernel): adapt MXFP4 MoE backend for DeepSeek-V4-Flash (#1950)

V4-Flash routed experts ship as native MXFP4 (E2M1 nibble + ue8m0 group
scale). Expose AMXFP4_KGroup_MOE through NativeMoEWrapper, add a loader
that handles V4's `layers.{L}.ffn.experts.{i}.{w1,w3,w2}.{weight,scale}`
naming and converts ue8m0 → bf16 via a lossless bit-cast, register the
model entry, and ship an end-to-end numerical validation script.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* [perf](kt-kernel): MXFP4 MoE add mat-mat 4×4 tile, refine mat-vec reduce (#1957)

mat_mul_kgroup previously aliased to fp4_mat_vec_kgroup, leaving large
batches stuck on the per-token path. Implement fp4_mat_mat_kgroup as a
4×4 register tile (MB=NB=4, 16 zmm accumulators) so each PSHUFB decode
of four weight rows is reused across four tokens.

Refactor fp4_mat_vec_kgroup to accumulate four N-rows in parallel and
flush them with a new reduce4 helper, removing per-row reduce_add_ps
calls from the hot loop. Mark mxfp4_to_bf16_32 always_inline.

Add bench/bench_fp4_moe.py with --routing {balanced,concentrated} and
a backend registry so future kernels can be added without changing the
runner.

Dispatch thresholds, derived_init, GeneralMOEConfig handling,
load_weights, write_weights_to_buffer and the TP_MOE specialization are
unchanged.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(loader): avoid uint16 lshift in ue8m0->bf16 conversion

PyTorch CPU has no lshift kernel for UInt16, so the previous
`(scale_t.to(torch.uint16) << 7)` raised NotImplementedError when
loading any V4-Flash MXFP4 routed-expert scale tensor on the host.

Switch to int32 for the shift (kernel exists) and narrow to int16
afterwards. The shifted value max is 255<<7 = 32640, well within
int16 range, so the narrow is lossless. The .view(bfloat16) bit
pattern is identical (bf16 sign bit is always 0 for ue8m0 values).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs(v4-flash): hybrid CPU/GPU recipe + bump kt-sglang submodule

Bumps third_party/sglang to kvcache-ai/sglang main (3cbd49c29) which now
contains DeepSeek V4 Flash model support + consumer-GPU (SM_120) portable
Triton/TileLang fallbacks (kt-sglang PR #38).

Adds doc/en/DeepSeek-V4-Flash.md tutorial: 8x RTX 5090 hybrid recipe with
the full launch command, OpenAI-compatible /generate + /v1/chat/completions
examples, and the kt chat CLI client.

---------

Co-authored-by: ouqingliang <1692110604@qq.com>
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Benjamin F
2026-05-03 10:48:31 +08:00
committed by GitHub
parent fe06c4d355
commit 041bdfc636
12 changed files with 1902 additions and 2 deletions

View File

@@ -1231,3 +1231,91 @@ class GPTQSafeTensorLoader(FP8SafeTensorLoader):
"up_scale": up_scales,
"down_scale": down_scales,
}
class MXFP4SafeTensorLoader(SafeTensorLoader):
"""Loader for native MXFP4 expert weights (DeepSeek-V4-Flash format).
Per expert layout:
{base}.ffn.experts.{i}.w1.weight I8 [N, K/2] nibble-packed E2M1 (gate)
{base}.ffn.experts.{i}.w1.scale F8_E8M0 [N, K/32] ue8m0 group scale
{base}.ffn.experts.{i}.w3.{weight,scale} up
{base}.ffn.experts.{i}.w2.{weight,scale} down
V4 ckpt keys are not prefixed with ``model.``; we also probe the stripped form so
callers can keep passing ``base_key="model.layers.{L}"``. ue8m0 → bf16 is a lossless
bit shift (both have an 8-bit exponent and zero mantissa for ue8m0), and the AMX
FP4 backend already consumes bf16 scales.
"""
EXPERTS_PATH_TPL = "{base}.ffn.experts"
PROJ_NAMES = ("w1", "w3", "w2") # (gate, up, down)
def _experts_prefix_candidates(self, base_key: str) -> list[str]:
candidates = [self.EXPERTS_PATH_TPL.format(base=base_key)]
if base_key.startswith("model."):
candidates.append(self.EXPERTS_PATH_TPL.format(base=base_key[len("model.") :]))
return list(dict.fromkeys(candidates))
@staticmethod
def _ue8m0_to_bf16(scale_t: torch.Tensor) -> torch.Tensor:
if scale_t.dtype != torch.uint8:
scale_t = scale_t.view(torch.uint8)
# bf16 = [sign(1) | exp(8) | mant(7)]; setting mant=0, exp=e gives 2^(e-127),
# which is exactly the value encoded by ue8m0 for e ∈ [1, 254]. e=0 → bf16 +0
# (acceptable: ue8m0=0 represents 2^-127, below bf16 normal range), e=255 → +inf.
# Compute in int32 then narrow to int16 (max value is 255<<7=32640, fits int16),
# because torch CPU has no lshift kernel for uint16.
return (scale_t.to(torch.int32) << 7).to(torch.int16).view(torch.bfloat16).contiguous()
def load_experts(self, base_key: str, device: str = "cpu"):
gate_name, up_name, down_name = self.PROJ_NAMES
prefix = None
expert_count = 0
for cand in self._experts_prefix_candidates(base_key):
expert_count = 0
while self.has_tensor(f"{cand}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count > 0:
prefix = cand
break
if prefix is None:
raise ValueError(
f"No MXFP4 experts found under any of: {self._experts_prefix_candidates(base_key)}"
)
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
gate_scales = [None] * expert_count
up_scales = [None] * expert_count
down_scales = [None] * expert_count
for exp_id in range(expert_count):
for proj, dst in (
(gate_name, gate_weights),
(up_name, up_weights),
(down_name, down_weights),
):
w = self.load_tensor(f"{prefix}.{exp_id}.{proj}.weight", device).contiguous()
if w.dtype != torch.uint8:
w = w.view(torch.uint8)
dst[exp_id] = w
for proj, dst in (
(gate_name, gate_scales),
(up_name, up_scales),
(down_name, down_scales),
):
s = self.load_tensor(f"{prefix}.{exp_id}.{proj}.scale", device)
dst[exp_id] = self._ue8m0_to_bf16(s)
print(f"[MXFP4SafeTensorLoader] Loaded {expert_count} experts from {prefix}")
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}