diff --git a/README.md b/README.md index 26b960cc..8cca9bdf 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ KTransformers is a research project focused on efficient inference and fine-tuni ## 🔥 Updates * **May 6, 2026**: KTransformers at [GOSIM Paris 2026](https://paris2026.gosim.org/zh/schedule/) — "Agentic AI on Edge" track. We'll present KT's inference performance on consumer hardware. +* **May 02, 2026**: DeepSeek-V4-Flash Support! ([Tutorial](./doc/en/DeepSeek-V4-Flash.md)) * **Apr 30, 2026**: KTransformers v0.6.1 refreshes kt-kernel inference and SFT docs with separate [Inference](./kt-kernel/README.md) and [SFT Quick Start](./doc/en/SFT/KTransformers-Fine-Tuning_Quick-Start.md) entry points. * **Mar 26, 2026**: Support AVX2-only CPU backend for KT-Kernel inference. ([Tutorial](./doc/en/kt-kernel/AVX2-Tutorial.md)) * **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md)) diff --git a/doc/en/DeepSeek-V4-Flash.md b/doc/en/DeepSeek-V4-Flash.md new file mode 100644 index 00000000..a860a23d --- /dev/null +++ b/doc/en/DeepSeek-V4-Flash.md @@ -0,0 +1,121 @@ +# Running DeepSeek-V4-Flash with SGLang and KT-Kernel + +This tutorial demonstrates how to run **DeepSeek-V4-Flash** model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. The hybrid path splits MXFP4 routed experts between CPU (KT-Kernel `cpuinfer`) and GPU (sglang `kt-num-gpu-experts`), enabling deployment on consumer-grade hardware. + +## Table of Contents + +- [Running DeepSeek-V4-Flash with SGLang and KT-Kernel](#running-deepseek-v4-flash-with-sglang-and-kt-kernel) + - [Table of Contents](#table-of-contents) + - [Hardware Requirements](#hardware-requirements) + - [Prerequisites](#prerequisites) + - [Step 1: Download Model Weights](#step-1-download-model-weights) + - [Step 2: Launch SGLang Server](#step-2-launch-sglang-server) + - [Launch Command (8× RTX 5090 Example)](#launch-command-8-rtx-5090-example) + - [Step 3: Send Inference Requests](#step-3-send-inference-requests) + - [Decode](#decode) + - [Interactive Chat (kt chat)](#interactive-chat-kt-chat) + +## Hardware Requirements + +**Validated Configuration (this tutorial):** +- **GPU**: 8× NVIDIA RTX 5090 (32GB VRAM each, SM_120) +- **CPU**: x86 CPU with AVX512 support +- **RAM**: ≥256GB system memory +- **Storage**: ~340GB for model weights + +**Supported GPU architectures** (auto-detected at startup; non-validated configurations should work but have not been benchmarked end-to-end): + +| Arch | Compute Cap | MXFP4 MoE | NSA sparse MLA | Validated | +|------|------------|-----------|----------------|-----------| +| Hopper (H100 / H200) | SM_90 | triton_kernels | flash_mla wheel | — | +| Datacenter Blackwell (B100 / B200) | SM_100 | trtllm-fp4 | Triton fallback | — | +| Consumer Blackwell (RTX 5090) | SM_120 | triton_kernels | Triton fallback | ✓ | +| Ada Lovelace (RTX 4090 / L20 / L40) | SM_89 | triton_kernels | Triton fallback | — | +| Ampere (A100 / A6000) | SM_80 / SM_86 | triton_kernels | Triton fallback | — | + + +## Prerequisites + +1. **KT-Kernel installed**: + ```bash + git clone https://github.com/kvcache-ai/ktransformers.git + cd ktransformers + git submodule update --init --recursive + cd kt-kernel && ./install.sh + ``` + +2. **SGLang installed** (kvcache-ai fork): + ```bash + ./install.sh # from ktransformers root + ``` + +3. **CUDA 12.8+** and **flashinfer ≥ 0.6.9** (`flashinfer-python` and `flashinfer-cubin` must be the same version): + ```bash + pip install --upgrade flashinfer-python flashinfer-cubin + ``` + This upgrade is required (even though `sglang-kt` pins `flashinfer_python==0.6.3`) because V4-Flash's MXFP4 MoE module imports `mxfp8_quantize`, `trtllm_fp4_block_scale_routed_moe`, etc., which only exist in flashinfer ≥ 0.6.9; + + +## Step 1: Download Model Weights + +```bash +mkdir -p /path/to/models +huggingface-cli download deepseek-ai/DeepSeek-V4-Flash \ + --local-dir /path/to/models/DeepSeek-V4-Flash +``` + +## Step 2: Launch SGLang Server + +### Launch Command (8× RTX 5090 Example) + +```bash +numactl --interleave=all python -m sglang.launch_server \ + --host 127.0.0.1 \ + --port 30000 \ + --model /path/to/models/DeepSeek-V4-Flash \ + --kt-weight-path /path/to/models/DeepSeek-V4-Flash \ + --kt-method MXFP4 \ + --kt-num-gpu-experts 144 \ + --kt-cpuinfer 8 \ + --kt-threadpool-count 2 \ + --kt-gpu-prefill-token-threshold 4096 \ + --kt-enable-dynamic-expert-update \ + --tensor-parallel-size 8 \ + --attention-backend flashinfer \ + --mem-fraction-static 0.80 \ + --chunked-prefill-size 2048 \ + --max-running-requests 4 \ + --max-total-tokens 32768 \ + --watchdog-timeout 3000 \ + --disable-shared-experts-fusion \ + --cuda-graph-bs 1 2 4 \ + --cuda-graph-max-bs 4 \ + --trust-remote-code +``` + +It takes about 4-5 minutes to start the server (weight load + CUDA Graph capture). + +See [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines. + +## Step 3: Send Inference Requests + +### Decode + +```bash +curl -s -X POST http://127.0.0.1:30000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "text": "Explain quantum computing in detail:", + "sampling_params": {"temperature": 0.0, "max_new_tokens": 256} + }' +``` + +### Interactive Chat (kt chat) + +The `kt` CLI ships with an OpenAI-compatible chat client that talks to the SGLang server's `/v1/chat/completions` endpoint: + +```bash +kt chat --host 127.0.0.1 --port 30000 --temperature 0.7 --max-tokens 2048 +``` + + diff --git a/kt-kernel/bench/bench_fp4_moe.py b/kt-kernel/bench/bench_fp4_moe.py new file mode 100644 index 00000000..01b78490 --- /dev/null +++ b/kt-kernel/bench/bench_fp4_moe.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python +# coding=utf-8 +"""Benchmark MXFP4 MoE kernel — V4-Flash shape, mat-vec / mat-mat 双路径覆盖。 + +Synthesizes V4-Flash-shaped MXFP4 weights (random nibbles + bf16 group scales), +runs the chosen backend over a list of batch sizes M, prints a throughput table. + +Routing modes (决定是否触发 mat-mat 路径): + balanced —— 每 token randperm(EXPERT_NUM)[:TOP_K]; 平均 per-expert m ≈ + M*top_k/expert_num. V4 真实路由分布. 大 batch (M=1024) 才 + 平均触发 mat-mat (per-expert m ≥ 4). + concentrated —— 所有 M token 共用同一组 top_k expert; per-expert m = M. + M=4 立即触发 mat-mat —— 用来直观放大 mat-mat 性能优势. + +Usage: + python bench/bench_fp4_moe.py --backend v1 + python bench/bench_fp4_moe.py --backend v1 --routing concentrated + python bench/bench_fp4_moe.py --all --routing concentrated # 所有可用 backend 对比 + +`--backend` 是预留扩展点; 当前编译只绑定 v1 (AMXFP4_KGroup_MOE)。要选 v2/v3 +需要 ext_bindings 里加新绑定。`--all` 会自动检测哪些 backend 可用。 +""" +import argparse +import json +import os +import platform +import subprocess +import sys +import time + +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) +from kt_kernel import kt_kernel_ext # noqa: E402 + +# ----- V4-Flash MoE shape ----- +HIDDEN = 4096 +INTER = 2048 +EXPERT_NUM = 256 +TOP_K = 6 +K_GROUP_SIZE = 32 + +# ----- bench knobs ----- +# M=1024 时平均 per-expert m ≈ M*6/256 = 24, balanced 路由也能触发 mat-mat。 +DEFAULT_M_LIST = [1, 4, 16, 64, 256, 1024] +WARMUP_ITER = 200 +TEST_ITER = 2000 + +# ----- WorkerPool: 2 NUMA × 40 thread (matches bench_k2_moe_amx.py) ----- +WORKER_NUMA = 2 +WORKER_THREADS_PER_NUMA = 40 + +# ----- Backend registry: name → kt_kernel_ext.moe class (None = not bound) ----- +BACKENDS = { + "v1": getattr(kt_kernel_ext.moe, "AMXFP4_KGroup_MOE", None), + # 预留扩展点; 加新 backend 时在 ext_bindings 绑定后这里加一行即可。 + "v2": getattr(kt_kernel_ext.moe, "AMXFP4_KGroup_MOE_V2", None), +} + +# OCP MXFP4 (E2M1) codepoints — same order as the kernel's LUT. +E2M1_VALUES = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, +) + + +def quantize_mxfp4_tensor(weights: torch.Tensor, group_size: int): + """[E, N, K] fp32/bf16 → packed nibbles uint8 [E, N, K/2] + bf16 scales [E, N, K/gs].""" + w = weights.to(torch.float32) + e, rows, cols = w.shape + assert cols % group_size == 0 and cols % 2 == 0 + reshaped = w.view(e, rows, cols // group_size, group_size) + max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + scales = (max_abs / 6.0).squeeze(-1) + normalized = reshaped / scales.unsqueeze(-1) + distances = torch.abs(normalized.unsqueeze(-1) - E2M1_VALUES.view(1, 1, 1, 1, 16)) + nibbles = distances.argmin(dim=-1).to(torch.uint8).view(e, rows, cols // 2, 2) + lo = nibbles[..., 0] + hi = nibbles[..., 1] + packed = ((hi << 4) | lo).contiguous() # uint8 [E, N, K/2] + scales = scales.to(torch.bfloat16).contiguous() # bf16 [E, N, K/gs] + return packed, scales + + +def build_synth_weights(): + torch.manual_seed(0) + gate = torch.randn((EXPERT_NUM, INTER, HIDDEN), dtype=torch.float32) / 100 + up = torch.randn((EXPERT_NUM, INTER, HIDDEN), dtype=torch.float32) / 100 + down = torch.randn((EXPERT_NUM, HIDDEN, INTER), dtype=torch.float32) / 100 + gw, gs = quantize_mxfp4_tensor(gate, K_GROUP_SIZE) + uw, us = quantize_mxfp4_tensor(up, K_GROUP_SIZE) + dw, ds = quantize_mxfp4_tensor(down, K_GROUP_SIZE) + return { + "gate_w": gw, "up_w": uw, "down_w": dw, + "gate_s": gs, "up_s": us, "down_s": ds, + } + + +def build_moe(backend: str, weights, cpu_infer): + cls = BACKENDS.get(backend) + if cls is None: + raise RuntimeError( + f"backend={backend} not bound in this build. Available: " + f"{[k for k, v in BACKENDS.items() if v is not None]}" + ) + cfg = kt_kernel_ext.moe.MOEConfig(EXPERT_NUM, TOP_K, HIDDEN, INTER, 0) + cfg.max_len = max(DEFAULT_M_LIST) + cfg.pool = cpu_infer.backend_ + cfg.quant_config.bits = 4 + cfg.quant_config.group_size = K_GROUP_SIZE + cfg.quant_config.zero_point = False + cfg.gate_projs = [[t.data_ptr() for t in weights["gate_w"]]] + cfg.up_projs = [[t.data_ptr() for t in weights["up_w"]]] + cfg.down_projs = [[t.data_ptr() for t in weights["down_w"]]] + cfg.gate_scales = [[t.data_ptr() for t in weights["gate_s"]]] + cfg.up_scales = [[t.data_ptr() for t in weights["up_s"]]] + cfg.down_scales = [[t.data_ptr() for t in weights["down_s"]]] + moe = cls(cfg) + p2l = torch.arange(EXPERT_NUM, dtype=torch.int64).contiguous() + cpu_infer.submit(moe.load_weights_task(p2l.data_ptr())) + cpu_infer.sync() + return moe + + +def make_expert_ids(M: int, routing: str) -> torch.Tensor: + """[M, TOP_K] int64 (kernel forward_binding casts to const int64_t*).""" + if routing == "concentrated": + # 所有 M token 共用同组 top_k expert → per-expert m = M + hot = torch.randperm(EXPERT_NUM)[:TOP_K] + return hot.unsqueeze(0).expand(M, TOP_K).contiguous().to(torch.int64) + # balanced: 每 token 独立 randperm + return torch.stack( + [torch.randperm(EXPERT_NUM)[:TOP_K] for _ in range(M)] + ).to(torch.int64).contiguous() + + +def bench_one_m(moe, cpu_infer, M: int, routing: str): + bsz = torch.tensor([M], dtype=torch.int32) + expert_ids = make_expert_ids(M, routing) + routing_w = torch.rand((M, TOP_K), dtype=torch.float32).contiguous() + x = (torch.randn((M, HIDDEN), dtype=torch.bfloat16) / 100).contiguous() + y = torch.empty((M, HIDDEN), dtype=torch.bfloat16).contiguous() + + for _ in range(WARMUP_ITER): + cpu_infer.submit(moe.forward_task( + bsz.data_ptr(), TOP_K, expert_ids.data_ptr(), + routing_w.data_ptr(), x.data_ptr(), y.data_ptr(), False)) + cpu_infer.sync() + + start = time.perf_counter() + for _ in range(TEST_ITER): + cpu_infer.submit(moe.forward_task( + bsz.data_ptr(), TOP_K, expert_ids.data_ptr(), + routing_w.data_ptr(), x.data_ptr(), y.data_ptr(), False)) + cpu_infer.sync() + total = time.perf_counter() - start + + per_iter_us = total / TEST_ITER * 1e6 + tok_per_s = M * TEST_ITER / total + unique_e = int(torch.unique(expert_ids).numel()) + avg_m_per_expert = float(M * TOP_K) / max(unique_e, 1) + return { + "M": M, "iters": TEST_ITER, "total_s": total, + "per_iter_us": per_iter_us, "tokens_per_s": tok_per_s, + "unique_experts": unique_e, "avg_m_per_expert": avg_m_per_expert, + } + + +def run_backend(backend: str, weights, cpu_infer, m_list, routing): + print(f"\n[bench-fp4] backend={backend} routing={routing}") + moe = build_moe(backend, weights, cpu_infer) + rows = [] + for M in m_list: + r = bench_one_m(moe, cpu_infer, M, routing) + rows.append(r) + print(f" M={M:>5} avg_m/expert={r['avg_m_per_expert']:>6.1f} " + f"per-iter={r['per_iter_us']:>9.1f} us tok/s={r['tokens_per_s']:>9.1f}") + return rows + + +def print_single_table(backend, rows, routing): + print(f"\n=== Summary ({backend}, routing={routing}) ===") + print(f"{'M':>5} {'avg_m':>6} {'per-iter us':>12} {'tok/s':>10}") + for r in rows: + print(f"{r['M']:>5} {r['avg_m_per_expert']:>6.1f} " + f"{r['per_iter_us']:>12.1f} {r['tokens_per_s']:>10.1f}") + + +def print_compare_table(all_rows: dict, routing: str): + backends = list(all_rows.keys()) + if len(backends) < 2: + print_single_table(backends[0], all_rows[backends[0]], routing) + return + base = backends[0] + print(f"\n=== {' vs '.join(backends)} (routing={routing}, base={base}) ===") + header = f"{'M':>5} {'avg_m':>6}" + for be in backends: + header += f" {be + ' us':>10}" + for be in backends[1:]: + header += f" {be + '/' + base:>8}" + print(header) + n_rows = len(all_rows[base]) + for i in range(n_rows): + line = f"{all_rows[base][i]['M']:>5} {all_rows[base][i]['avg_m_per_expert']:>6.1f}" + for be in backends: + line += f" {all_rows[be][i]['per_iter_us']:>10.1f}" + for be in backends[1:]: + ratio = all_rows[be][i]['per_iter_us'] / all_rows[base][i]['per_iter_us'] + line += f" {ratio:>8.3f}" + print(line) + + +def get_git_commit(): + try: + commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() + dirty = bool(subprocess.check_output(["git", "status", "--porcelain"]).decode().strip()) + return {"commit": commit, "dirty": dirty} + except Exception as e: + return {"commit": None, "error": str(e)} + + +def get_system_info(): + info = {"node": platform.node(), "system": platform.system()} + cpu_model = None + try: + with open("/proc/cpuinfo") as f: + for line in f: + if "model name" in line: + cpu_model = line.split(":", 1)[1].strip() + break + except Exception: + pass + info["cpu"] = cpu_model + info["cpu_cores"] = os.cpu_count() + return info + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--backend", choices=list(BACKENDS.keys()), default="v1", + help="单 backend 模式(被 --all 覆盖)。当前可用: " + + ",".join(k for k, v in BACKENDS.items() if v is not None)) + p.add_argument("--all", action="store_true", + help="跑所有已绑定的 backend 并打印对比表(自动跳过未绑定的)") + p.add_argument("--routing", choices=["balanced", "concentrated"], default="balanced", + help="balanced=每 token randperm (V4 真实); " + "concentrated=所有 token 共用同组 top_k (per-expert m=M, 放大 mat-mat)") + p.add_argument("--m-list", type=str, default=None, + help=f"Comma-separated M values, default: {','.join(map(str, DEFAULT_M_LIST))}") + p.add_argument("--numa", type=int, default=WORKER_NUMA) + p.add_argument("--threads-per-numa", type=int, default=WORKER_THREADS_PER_NUMA) + args = p.parse_args() + + m_list = [int(x) for x in args.m_list.split(",")] if args.m_list else DEFAULT_M_LIST + + if args.all: + backends = [k for k, v in BACKENDS.items() if v is not None] + if not backends: + raise RuntimeError("No MXFP4 backend bound in this build.") + print(f"[bench-fp4] --all: detected backends = {backends}") + else: + if BACKENDS.get(args.backend) is None: + raise RuntimeError( + f"backend={args.backend} not bound. Available: " + f"{[k for k, v in BACKENDS.items() if v is not None]}" + ) + backends = [args.backend] + + print(f"[bench-fp4] shape=H{HIDDEN}/I{INTER}/E{EXPERT_NUM}/k{TOP_K}/gs{K_GROUP_SIZE} routing={args.routing}") + print(f"[bench-fp4] WorkerPool: numa={args.numa} threads_per_numa={args.threads_per_numa}") + print(f"[bench-fp4] m_list: {m_list}") + + wp = kt_kernel_ext.WorkerPoolConfig() + wp.subpool_count = args.numa + wp.subpool_numa_map = list(range(args.numa)) + wp.subpool_thread_count = [args.threads_per_numa] * args.numa + cpu_infer = kt_kernel_ext.CPUInfer(wp) + + print("[bench-fp4] synthesizing MXFP4 weights …") + weights = build_synth_weights() + + all_rows = {} + for be in backends: + all_rows[be] = run_backend(be, weights, cpu_infer, m_list, args.routing) + + if len(backends) > 1: + print_compare_table(all_rows, args.routing) + else: + print_single_table(backends[0], all_rows[backends[0]], args.routing) + + # JSONL log + out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "bench_fp4_moe.jsonl") + ts = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + git = get_git_commit() + sys_info = get_system_info() + with open(out_path, "a") as f: + for be, rows in all_rows.items(): + record = { + "backend": be, + "routing": args.routing, + "shape": {"hidden": HIDDEN, "inter": INTER, "expert_num": EXPERT_NUM, + "top_k": TOP_K, "k_group_size": K_GROUP_SIZE}, + "worker_pool": {"numa": args.numa, "threads_per_numa": args.threads_per_numa}, + "rows": rows, + "git": git, "system": sys_info, "timestamp": ts, + } + f.write(json.dumps(record) + "\n") + print(f"\n[bench-fp4] appended {len(backends)} record(s) → {out_path}") + + +if __name__ == "__main__": + main() diff --git a/kt-kernel/examples/test_fp4_moe_amx.py b/kt-kernel/examples/test_fp4_moe_amx.py new file mode 100644 index 00000000..ec68bb5f --- /dev/null +++ b/kt-kernel/examples/test_fp4_moe_amx.py @@ -0,0 +1,354 @@ +import math +import os +import sys +from typing import Dict + +sys.path.insert(0, os.path.dirname(__file__) + "/../build") + +import torch +from kt_kernel import kt_kernel_ext + +torch.manual_seed(42) + +hidden_size = 7168 +intermediate_size = 2048 +max_len = 25600 + +expert_num = 16 +num_experts_per_tok = 8 + +layer_num = 1 +CPUInfer = kt_kernel_ext.CPUInfer(40) +validation_iter = 3 +k_group_size = 32 +debug_print_count = 16 + +# Forward dispatch in do_gate_up_gemm uses `qlen > 4 * expert_num / top_k` +# (= 8 with these constants), so qlen=1 hits mat-vec and qlen=32 hits the +# mat-mat 4×4 register tile (per-expert avg m = qlen*top_k/expert_num = 16). +QLEN_LIST = [1, 32] +DISPATCH_THRESHOLD = 4 * expert_num / num_experts_per_tok + +physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() + +# E2M1 values: {0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0} +E2M1_VALUES = torch.tensor([ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +], dtype=torch.float32) + +# Nibble encoding: index = 4-bit value +# 0b0000..0b0111 = positive, 0b1000..0b1111 = negative +E2M1_NIBBLE_MAP = torch.tensor([ + 0, # 0b0000 = 0.0 + 1, # 0b0001 = 0.5 + 2, # 0b0010 = 1.0 + 3, # 0b0011 = 1.5 + 4, # 0b0100 = 2.0 + 5, # 0b0101 = 3.0 + 6, # 0b0110 = 4.0 + 7, # 0b0111 = 6.0 + 8, # 0b1000 = -0.0 + 9, # 0b1001 = -0.5 + 10, # 0b1010 = -1.0 + 11, # 0b1011 = -1.5 + 12, # 0b1100 = -2.0 + 13, # 0b1101 = -3.0 + 14, # 0b1110 = -4.0 + 15, # 0b1111 = -6.0 +], dtype=torch.int32) + + +def _pattern_uniform(groups: int) -> torch.Tensor: + return torch.full((groups,), 0.02, dtype=torch.float32) + + +def _pattern_alternating(groups: int) -> torch.Tensor: + vals = torch.full((groups,), 0.015, dtype=torch.float32) + vals[1::2] = 0.03 + return vals + + +def _pattern_ramp(groups: int) -> torch.Tensor: + return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32) + + +WEIGHT_PATTERNS = { + "uniform_scale": ("All k-groups share the same abs max / scale", _pattern_uniform), + "alternating_scale": ("Alternate small / large abs max per k-group", _pattern_alternating), + "ramp_scale": ("Linearly increasing abs max per k-group", _pattern_ramp), + "random": ("Random bf16 weights (baseline)", None), +} + + +def act_fn(x): + return x / (1.0 + torch.exp(-x)) + + +def mlp_torch(input, gate_proj, up_proj, down_proj): + gate_buf = torch.mm(input, gate_proj.t()) + up_buf = torch.mm(input, up_proj.t()) + intermediate = act_fn(gate_buf) * up_buf + ret = torch.mm(intermediate, down_proj.t()) + return ret + + +def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): + cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) + cnts.scatter_(1, expert_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = expert_ids.view(-1).argsort() + sorted_tokens = input[idxs // expert_ids.shape[1]] + + outputs = [] + start_idx = 0 + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + t_output = ( + new_x.view(*expert_ids.shape, -1) + .type(weights.dtype) + .mul_(weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return t_output + + +def quantize_mxfp4_tensor(weights: torch.Tensor, group_size: int): + """ + MXFP4 E2M1 quantization per k-group. + For each block of group_size (32) elements in K dimension: + scale = max_abs / 6.0 + quantized = round(value / scale) to nearest E2M1 value + Args: + weights: [expert_num, rows (N), cols (K)] in bf16 + Returns: + packed: int32 tensor storing 8 FP4 nibbles per int32, shape [expert_num, rows * (cols // 8)] + scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)] + """ + weights_f32 = weights.to(torch.float32) + e, rows, cols = weights_f32.shape + if cols % group_size != 0 or cols % 2 != 0: + raise ValueError(f"cols ({cols}) must be divisible by group_size ({group_size}) and 2") + + reshaped = weights_f32.view(e, rows, cols // group_size, group_size) + max_abs = reshaped.abs().amax(dim=-1, keepdim=True) + max_abs = torch.clamp(max_abs, min=1e-8) + scales = (max_abs / 6.0).squeeze(-1) + + # Quantize: round(value / scale) to nearest E2M1 value + normalized = reshaped / scales.unsqueeze(-1) + + # For each normalized value, find the closest E2M1 value + # E2M1_VALUES shape: [16] + e2m1_vals = E2M1_VALUES.view(1, 1, 1, 1, 16) # broadcast over (e, rows, groups, group_size, 16) + normalized_expanded = normalized.unsqueeze(-1) # (e, rows, groups, group_size, 1) + distances = torch.abs(normalized_expanded - e2m1_vals) + closest_indices = distances.argmin(dim=-1) # (e, rows, groups, group_size) — indices 0..15 + + # Dequantized values for reference + dequant = E2M1_VALUES[closest_indices].to(torch.float32) * scales.unsqueeze(-1) + dequant = dequant.view(e, rows, cols) + + # Pack nibbles: each byte = (hi_nibble << 4) | lo_nibble + # Column-major: consecutive K elements are consecutive nibbles + # nibble at even K index goes to low nibble, odd K index goes to high nibble + # But wait — looking at the kernel's mxfp4_to_bf16_32: + # lo = packed & 0x0F, hi = (packed >> 4) & 0x0F + # And the interleaving is: [lo[0],hi[0], lo[1],hi[1], ...] = column order + # So for column-major layout: each byte has lo nibble = element at col c, hi nibble = element at col c+1 + # But the weight buffer layout is: b_row[k_block] = 16 packed bytes covering 32 K elements + # With column-major: K is the innermost dimension + # For 32 elements per k_group: byte 0 = [K_0 | K_1], byte 1 = [K_2 | K_3], ..., byte 15 = [K_30 | K_31] + # low nibble = even K index, high nibble = odd K index + + nibbles = closest_indices.to(torch.uint8) # 0..15, each is already a 4-bit value + nibbles = nibbles.view(e, rows, cols // 2, 2) + lo = nibbles[..., 0] # even K indices + hi = nibbles[..., 1] # odd K indices + packed_bytes = (hi << 4) | lo # low nibble first in memory (little-endian style) + + # Pack 4 bytes into int32 + bytes_view = packed_bytes.view(e, rows, cols // 8, 4) + packed_int32 = ( + bytes_view[..., 0].to(torch.int32) | + (bytes_view[..., 1].to(torch.int32) << 8) | + (bytes_view[..., 2].to(torch.int32) << 16) | + (bytes_view[..., 3].to(torch.int32) << 24) + ) + packed_int32 = packed_int32.view(e, rows, cols // 8).contiguous() + + scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous() + + return packed_int32, scales, dequant + + +def build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor: + if pattern == "random": + torch.manual_seed(42) + return (torch.randn(shape, dtype=torch.bfloat16, device="cpu") / 100.0).contiguous() + + e, rows, cols = shape + groups = cols // k_group_size + group_builder = WEIGHT_PATTERNS[pattern][1] + group_vals = group_builder(groups).to(torch.float32) + block = group_vals.view(1, 1, groups, 1).expand(e, rows, groups, k_group_size).clone() + row_signs = torch.where( + (torch.arange(rows) % 2 == 0), + torch.ones(rows, dtype=torch.float32), + -torch.ones(rows, dtype=torch.float32), + ).view(1, rows, 1, 1) + col_offsets = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32).view(1, 1, 1, k_group_size) + block = block * row_signs + col_offsets + return block.reshape(shape).to(torch.bfloat16).contiguous() + + +def prepare_mxfp4_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]: + if pattern not in WEIGHT_PATTERNS: + raise ValueError(f"Unknown weight pattern: {pattern}") + + gate_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern) + up_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern) + down_proj = build_structured_tensor((expert_num, hidden_size, intermediate_size), pattern) + + gate_q, gate_scales, gate_dequant = quantize_mxfp4_tensor(gate_proj, k_group_size) + up_q, up_scales, up_dequant = quantize_mxfp4_tensor(up_proj, k_group_size) + down_q, down_scales, down_dequant = quantize_mxfp4_tensor(down_proj, k_group_size) + + return { + "gate_qweight": gate_q.contiguous(), + "up_qweight": up_q.contiguous(), + "down_qweight": down_q.contiguous(), + "gate_scales": gate_scales.contiguous(), + "up_scales": up_scales.contiguous(), + "down_scales": down_scales.contiguous(), + "dequantized": { + "gate_proj": gate_dequant.to(torch.bfloat16).contiguous(), + "up_proj": up_dequant.to(torch.bfloat16).contiguous(), + "down_proj": down_dequant.to(torch.bfloat16).contiguous(), + }, + } + + +def build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]): + moes = [] + with torch.inference_mode(mode=True): + for _ in range(layer_num): + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) + config.max_len = max_len + config.quant_config.bits = 4 + config.quant_config.group_size = k_group_size + config.quant_config.zero_point = False + + config.gate_proj = quant_data["gate_qweight"].data_ptr() + config.up_proj = quant_data["up_qweight"].data_ptr() + config.down_proj = quant_data["down_qweight"].data_ptr() + + config.gate_scale = quant_data["gate_scales"].data_ptr() + config.up_scale = quant_data["up_scales"].data_ptr() + config.down_scale = quant_data["down_scales"].data_ptr() + config.pool = CPUInfer.backend_ + + moe = kt_kernel_ext.moe.AMXFP4_KGroup_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + moes.append(moe) + return moes + + +def run_case(pattern: str, qlen: int) -> Dict[str, float]: + print("\n" + "=" * 70) + desc = WEIGHT_PATTERNS[pattern][0] + path = "mat-vec" if qlen <= DISPATCH_THRESHOLD else "mat-mat" + print(f"Running case: {pattern} -> {desc} (qlen={qlen}, path={path})") + print("=" * 70) + + quant_data = prepare_mxfp4_quantized_weights(pattern) + moes = build_moes_from_quantized_data(quant_data) + + dequant_weights = quant_data["dequantized"] + gate_bf16 = dequant_weights["gate_proj"] + up_bf16 = dequant_weights["up_proj"] + down_bf16 = dequant_weights["down_proj"] + + diffs = [] + with torch.inference_mode(mode=True): + for i in range(validation_iter): + torch.manual_seed(100 + i) + bsz_tensor = torch.tensor([qlen], device="cpu") + expert_ids = torch.stack( + [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)] + ).contiguous() + weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() + input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100 + output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous() + + moe = moes[i % layer_num] + CPUInfer.submit( + moe.forward_task( + bsz_tensor.data_ptr(), + num_experts_per_tok, + expert_ids.data_ptr(), + weights.data_ptr(), + input_tensor.data_ptr(), + output.data_ptr(), + False, + ) + ) + CPUInfer.sync() + + # Torch reference: use dequantized weights + input_tensor_bf16 = input_tensor.to(torch.bfloat16) + t_output = moe_torch(input_tensor_bf16, expert_ids, weights, gate_bf16, up_bf16, down_bf16).to( + torch.bfloat16 + ) + + t_output = t_output.flatten() + output = output.flatten() + + diff = torch.mean(torch.abs(output.float() - t_output.float())) / ( + torch.mean(torch.abs(t_output.float())) + 1e-12 + ) + diffs.append(diff.item()) + print(f"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}") + print(f" output {output[:debug_print_count]}") + print(f" t_output {t_output[:debug_print_count]}") + + mean_diff = float(sum(diffs) / len(diffs)) + max_diff = float(max(diffs)) + min_diff = float(min(diffs)) + return {"case": pattern, "description": desc, "mean": mean_diff, "max": max_diff, "min": min_diff} + + +def run_fp4_moe_test(): + summary_rows = [] + for qlen in QLEN_LIST: + path = "mat-vec" if qlen <= DISPATCH_THRESHOLD else "mat-mat" + print(f"\n##### qlen={qlen} path={path} #####") + for case_name in WEIGHT_PATTERNS.keys(): + results = run_case(case_name, qlen) + results["qlen"] = qlen + results["path"] = path + summary_rows.append(results) + + print("\n=== Case vs. Relative Error Summary ===") + print(f"{'Case':<20} {'qlen':>5} {'path':<8} {'Mean':>10} {'Max':>10} {'Min':>10}") + for row in summary_rows: + print(f"{row['case']:<20} {row['qlen']:>5} {row['path']:<8} " + f"{row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%") + + +if __name__ == "__main__": + run_fp4_moe_test() diff --git a/kt-kernel/examples/test_fp4_moe_v4.py b/kt-kernel/examples/test_fp4_moe_v4.py new file mode 100644 index 00000000..30bf41ae --- /dev/null +++ b/kt-kernel/examples/test_fp4_moe_v4.py @@ -0,0 +1,178 @@ +"""End-to-end MXFP4 MoE validation against the native DeepSeek-V4-Flash ckpt. + +Loads layer-`LAYER_ID` experts via :class:`MXFP4SafeTensorLoader`, runs the AMX +FP4 backend, and compares against a torch reference that dequantizes the same +nibble-packed weights with the OCP E2M1 LUT. + +Usage: + python test_fp4_moe_v4.py --weight-path /path/to/DeepSeek-V4-Flash [--layer 1] +""" + +from __future__ import annotations + +import argparse +import os +import sys +from typing import Tuple + +import torch + +# Allow running from kt-kernel/examples without install. +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/build") +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/python") + +from kt_kernel import kt_kernel_ext # noqa: E402 +from kt_kernel.utils.loader import MXFP4SafeTensorLoader # noqa: E402 + +# OCP E2M1 codepoints in our LUT order (matches operators/amx/fp4-moe.hpp). +E2M1_VALUES = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + dtype=torch.float32, +) + + +def dequantize_mxfp4(weight_u8: torch.Tensor, scale_bf16: torch.Tensor, group_size: int) -> torch.Tensor: + """Decode a [N, K/2] uint8 tensor of nibble-packed E2M1 with [N, K/gs] bf16 + scales into a [N, K] bf16 weight tensor. + + Layout (matches kernel's mxfp4_to_bf16_32): byte `b` low nibble = element K=2b, + high nibble = element K=2b+1. + """ + n, k_packed = weight_u8.shape + k = k_packed * 2 + assert k % group_size == 0, f"K={k} must be divisible by group_size={group_size}" + assert scale_bf16.shape == (n, k // group_size) + + lo = (weight_u8 & 0x0F).to(torch.long) + hi = ((weight_u8 >> 4) & 0x0F).to(torch.long) + nibbles = torch.stack([lo, hi], dim=-1).view(n, k) # interleave back to K order + decoded = E2M1_VALUES.to(weight_u8.device)[nibbles] # [N, K] fp32 + + scale_fp32 = scale_bf16.to(torch.float32) + scale_full = scale_fp32.repeat_interleave(group_size, dim=-1) # [N, K] + return (decoded * scale_full).to(torch.bfloat16).contiguous() + + +def reference_mlp(x: torch.Tensor, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor: + g = torch.mm(x, gate.t()) + u = torch.mm(x, up.t()) + silu = g / (1.0 + torch.exp(-g.float())).to(g.dtype) + return torch.mm(silu * u, down.t()) + + +def reference_moe( + hidden: torch.Tensor, + expert_ids: torch.Tensor, + weights: torch.Tensor, + gate_w: torch.Tensor, # [E, N, K] + up_w: torch.Tensor, + down_w: torch.Tensor, +) -> torch.Tensor: + out = torch.zeros_like(hidden, dtype=torch.float32) + for tok in range(hidden.shape[0]): + for slot in range(expert_ids.shape[1]): + eid = int(expert_ids[tok, slot]) + w = float(weights[tok, slot]) + x = hidden[tok : tok + 1] + y = reference_mlp(x, gate_w[eid], up_w[eid], down_w[eid]) + out[tok] += w * y[0].float() + return out.to(hidden.dtype) + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--weight-path", required=True, help="Path to DeepSeek-V4-Flash safetensors directory.") + p.add_argument("--layer", type=int, default=1, help="Layer index to validate (default: 1).") + p.add_argument("--qlen", type=int, default=1, help="Number of tokens to test.") + p.add_argument("--top-k", type=int, default=6, help="num_experts_per_tok (V4 default 6).") + p.add_argument("--cpu-threads", type=int, default=32) + p.add_argument("--max-experts", type=int, default=0, help="Cap number of experts loaded (0 = all).") + return p.parse_args() + + +def main() -> int: + args = parse_args() + torch.manual_seed(0) + + print(f"[V4-MXFP4] Loading layer {args.layer} from {args.weight_path}") + loader = MXFP4SafeTensorLoader(args.weight_path) + weights = loader.load_experts(f"model.layers.{args.layer}") + + expert_num = len(weights["gate"]) + if args.max_experts and args.max_experts < expert_num: + for k in ("gate", "up", "down", "gate_scale", "up_scale", "down_scale"): + weights[k] = weights[k][: args.max_experts] + expert_num = args.max_experts + print(f"[V4-MXFP4] expert_num={expert_num}") + + gate0 = weights["gate"][0] + down0 = weights["down"][0] + intermediate_size = gate0.shape[0] + hidden_size = gate0.shape[1] * 2 # nibble-packed K + assert down0.shape == (hidden_size, intermediate_size // 2), f"unexpected down shape {down0.shape}" + + group_size = hidden_size // weights["gate_scale"][0].shape[1] + print(f"[V4-MXFP4] hidden={hidden_size} inter={intermediate_size} gs={group_size}") + assert group_size == 32, "MXFP4 backend hard-codes group_size=32" + + physical_to_logical = torch.arange(expert_num, dtype=torch.int64).contiguous() + + # ----- AMX FP4 forward ----- + cpu_infer = kt_kernel_ext.CPUInfer(args.cpu_threads) + cfg = kt_kernel_ext.moe.MOEConfig(expert_num, args.top_k, hidden_size, intermediate_size, 0) + cfg.layer_idx = args.layer + cfg.max_len = max(args.qlen, 1) + cfg.pool = cpu_infer.backend_ + cfg.quant_config.bits = 4 + cfg.quant_config.group_size = group_size + cfg.quant_config.zero_point = False + + cfg.gate_projs = [[t.data_ptr() for t in weights["gate"]]] + cfg.up_projs = [[t.data_ptr() for t in weights["up"]]] + cfg.down_projs = [[t.data_ptr() for t in weights["down"]]] + cfg.gate_scales = [[t.data_ptr() for t in weights["gate_scale"]]] + cfg.up_scales = [[t.data_ptr() for t in weights["up_scale"]]] + cfg.down_scales = [[t.data_ptr() for t in weights["down_scale"]]] + + moe = kt_kernel_ext.moe.AMXFP4_KGroup_MOE(cfg) + cpu_infer.submit(moe.load_weights_task(physical_to_logical.data_ptr())) + cpu_infer.sync() + + qlen = args.qlen + top_k = args.top_k + bsz = torch.tensor([qlen], dtype=torch.int32) + expert_ids = torch.stack([torch.randperm(expert_num)[:top_k] for _ in range(qlen)]).to(torch.int32).contiguous() + routing = torch.randn((qlen, top_k), dtype=torch.float32).contiguous() + x = (torch.randn((qlen, hidden_size), dtype=torch.bfloat16) / 100).contiguous() + y_amx = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous() + + cpu_infer.submit( + moe.forward_task( + bsz.data_ptr(), top_k, expert_ids.data_ptr(), routing.data_ptr(), + x.data_ptr(), y_amx.data_ptr(), False, + ) + ) + cpu_infer.sync() + + # ----- Torch reference (dequantize same nibbles + scales) ----- + print("[V4-MXFP4] Building torch reference (dequantizing all loaded experts)…") + gate_bf16 = torch.stack([dequantize_mxfp4(weights["gate"][i], weights["gate_scale"][i], group_size) for i in range(expert_num)]) + up_bf16 = torch.stack([dequantize_mxfp4(weights["up"][i], weights["up_scale"][i], group_size) for i in range(expert_num)]) + down_bf16 = torch.stack([dequantize_mxfp4(weights["down"][i], weights["down_scale"][i], group_size) for i in range(expert_num)]) + + y_ref = reference_moe(x, expert_ids, routing, gate_bf16, up_bf16, down_bf16) + + diff = (y_amx.float() - y_ref.float()).abs() + rel = diff.mean() / (y_ref.float().abs().mean() + 1e-12) + print(f"[V4-MXFP4] mean abs diff = {diff.mean().item():.4e}") + print(f"[V4-MXFP4] max abs diff = {diff.max().item():.4e}") + print(f"[V4-MXFP4] rel mean diff = {rel.item()*100:.3f}%") + print(f"[V4-MXFP4] amx[:8] = {y_amx.flatten()[:8]}") + print(f"[V4-MXFP4] ref[:8] = {y_ref.flatten()[:8]}") + + return 0 if rel.item() < 0.10 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index 34aae308..f22207f6 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -44,6 +44,7 @@ static const bool _is_plain_ = false; #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) #include "operators/amx/awq-moe.hpp" #include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern, with fallback for AVX512F +#include "operators/amx/fp4-moe.hpp" // MXFP4 MoE: FP4 E2M1 weights × BF16 activations #include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support, with fallback for AVX512F+BW #include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8 #include "operators/amx/k2-moe.hpp" @@ -788,6 +789,9 @@ PYBIND11_MODULE(kt_kernel_ext, m) { bind_moe_module>(moe_module, "AMXBF16_MOE"); bind_moe_module>(moe_module, "AMXFP8_MOE"); bind_moe_module>(moe_module, "AMXFP8PerChannel_MOE"); +#endif +#if defined(__AVX512BF16__) + bind_moe_module>(moe_module, "AMXFP4_KGroup_MOE"); #endif // SFT MoE with LoRA support (BF16, INT8, INT4, AWQ, K2) bind_moe_sft_module>(moe_module, "AMXBF16_SFT_MOE"); diff --git a/kt-kernel/operators/amx/fp4-moe.hpp b/kt-kernel/operators/amx/fp4-moe.hpp new file mode 100644 index 00000000..ca261f56 --- /dev/null +++ b/kt-kernel/operators/amx/fp4-moe.hpp @@ -0,0 +1,783 @@ +/** + * @Description : MXFP4 MoE operator — FP4 E2M1 weights × BF16 activations + * @Author : oql, Codex and Claude + * @Date : 2026-04-20 + * @Version : 1.0.0 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * + * Based on k2-moe.hpp (RAWINT4). Key differences from RAWINT4: + * Weight: FP4 E2M1 (nibble-packed, same layout) → PSHUFB lookup → BF16 + * Act: BF16 direct (BufferABF16Impl, no online INT8 quantization) + * Dot prod: _mm512_dpbf16_ps (BF16×BF16→FP32) instead of _mm512_dpbssd_epi32 + * Scale: FP32 per-group scale (weight only, no activation scale) + **/ +#ifndef CPUINFER_OPERATOR_AMX_FP4_MOE_H +#define CPUINFER_OPERATOR_AMX_FP4_MOE_H + +#include "la/amx_raw_buffers.hpp" // BufferABF16Impl +#include "moe_base.hpp" + +namespace amx { + +// ============================================================================ +// MXFP4 kernel: FP4 E2M1 weights × BF16 activations → FP32 output (AVX512) +// ============================================================================ +struct GemmKernel224MXFP4SmallKGroup { + using dt = uint8_t; + using output_t = float; + static constexpr double ELEMENT_SIZE = 0.5; + + static const int M_STEP = 1; + static const int N_STEP = 32; + static const int K_STEP = 32; + + static inline const int N_BLOCK = 256; + static inline const int K_BLOCK = 7168; + + static std::string name() { return "MXFP4_KGROUP"; } + static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; } + static std::pair split_range_n(int n, int ith, int nth) { + int n_start = N_BLOCK * ith; + int n_end = std::min(n, N_BLOCK * (ith + 1)); + return {n_start, n_end}; + } + static void config() {} + + // FP4 E2M1 → BF16 LUTs (16 entries each, for PSHUFB within 128-bit lanes) + // E2M1 values: {0, ±0.5, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0} + alignas(16) static constexpr uint8_t fp4_bf16_lo[16] = { + 0x00, 0x00, 0x80, 0xC0, 0x00, 0x40, 0x80, 0xC0, // 0..7 positive + 0x00, 0x00, 0x80, 0xC0, 0x00, 0x40, 0x80, 0xC0}; // 8..15 negative + alignas(16) static constexpr uint8_t fp4_bf16_hi[16] = { + 0x00, 0x3F, 0x3F, 0x3F, 0x40, 0x40, 0x40, 0x40, // 0..7 positive + 0x80, 0xBF, 0xBF, 0xBF, 0xC0, 0xC0, 0xC0, 0xC0}; // 8..15 negative + + // Convert 16 packed FP4 bytes (32 values = 1 k_group) → 32 BF16 values (__m512i) + // Output column order: [BF16(lo[0]),BF16(hi[0]), ..., BF16(lo[15]),BF16(hi[15])] + __attribute__((always_inline)) static inline __m512i mxfp4_to_bf16_32(__m128i packed) { + __m128i lo_mask = _mm_set1_epi8(0x0F); + __m128i lo = _mm_and_si128(packed, lo_mask); + __m128i hi = _mm_and_si128(_mm_srli_epi16(packed, 4), lo_mask); + + __m128i lut_lo = _mm_load_si128((__m128i*)fp4_bf16_lo); + __m128i lut_hi = _mm_load_si128((__m128i*)fp4_bf16_hi); + + // Look up low/high bytes for lo nibbles → 16 BF16 values + __m128i l_lo = _mm_shuffle_epi8(lut_lo, lo); + __m128i l_hi = _mm_shuffle_epi8(lut_hi, lo); + __m128i lo_bf16_0 = _mm_unpacklo_epi8(l_lo, l_hi); // BF16(lo[0..7]) + __m128i lo_bf16_1 = _mm_unpackhi_epi8(l_lo, l_hi); // BF16(lo[8..15]) + + // Look up low/high bytes for hi nibbles → 16 BF16 values + __m128i h_lo = _mm_shuffle_epi8(lut_lo, hi); + __m128i h_hi = _mm_shuffle_epi8(lut_hi, hi); + __m128i hi_bf16_0 = _mm_unpacklo_epi8(h_lo, h_hi); // BF16(hi[0..7]) + __m128i hi_bf16_1 = _mm_unpackhi_epi8(h_lo, h_hi); // BF16(hi[8..15]) + + // Interleave lo/hi at 16-bit: [lo[0],hi[0], lo[1],hi[1], ...] = column order + __m128i p0 = _mm_unpacklo_epi16(lo_bf16_0, hi_bf16_0); // cols 0..7 + __m128i p1 = _mm_unpackhi_epi16(lo_bf16_0, hi_bf16_0); // cols 8..15 + __m128i p2 = _mm_unpacklo_epi16(lo_bf16_1, hi_bf16_1); // cols 16..23 + __m128i p3 = _mm_unpackhi_epi16(lo_bf16_1, hi_bf16_1); // cols 24..31 + + __m256i q0 = _mm256_inserti128_si256(_mm256_castsi128_si256(p0), p1, 1); + __m256i q1 = _mm256_inserti128_si256(_mm256_castsi128_si256(p2), p3, 1); + return _mm512_inserti64x4(_mm512_castsi256_si512(q0), q1, 1); + } + + // Buffers + using BufferA = BufferABF16Impl; // raw BF16, no quant + using BufferB = BufferBInt4KGroupImpl; // nibble-packed FP4 + using BufferC = BufferCReduceImpl; // FP32 reduce + + // 4 个 zmm 的 horizontal reduce → 4 个连续 fp32。 + // 4 次 reduce_add_ps 之间无依赖,编译器/CPU 可并行调度。 + __attribute__((always_inline)) static inline void + reduce4(__m512 s0, __m512 s1, __m512 s2, __m512 s3, float* dst) { + dst[0] = _mm512_reduce_add_ps(s0); + dst[1] = _mm512_reduce_add_ps(s1); + dst[2] = _mm512_reduce_add_ps(s2); + dst[3] = _mm512_reduce_add_ps(s3); + } + + // mat-vec: M 个独立 token,N 维 4 行一组累加,摊销 horizontal reduce。 + static void fp4_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, + int ith, int nth) { + auto [n_start, n_end] = split_range_n(n, ith, nth); + if (n_start >= n_end) return; + const int kg_count = k / 32; + + for (int m_idx = 0; m_idx < m; m_idx++) { + float* c_row = bc->get_submat(m, n, m_idx, n_start); + __m512bh* a_row = (__m512bh*)ba->get_submat(m, k, m_idx, 0); + + int n_pos = n_start; + // 主循环: N 维 4 行一组 + for (; n_pos + 4 <= n_end; n_pos += 4) { + __m128i* w0 = (__m128i*)bb->get_submat(n, k, n_pos + 0, 0); + __m128i* w1 = (__m128i*)bb->get_submat(n, k, n_pos + 1, 0); + __m128i* w2 = (__m128i*)bb->get_submat(n, k, n_pos + 2, 0); + __m128i* w3 = (__m128i*)bb->get_submat(n, k, n_pos + 3, 0); + const float* s0 = bb->get_scale(n, n_pos + 0, k, 0); + const float* s1 = bb->get_scale(n, n_pos + 1, k, 0); + const float* s2 = bb->get_scale(n, n_pos + 2, k, 0); + const float* s3 = bb->get_scale(n, n_pos + 3, k, 0); + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int g = 0; g < kg_count; g++) { + const __m512bh a = a_row[g]; + const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32(w0[g]); + const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32(w1[g]); + const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32(w2[g]); + const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32(w3[g]); + acc0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d0), acc0); + acc1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d1), acc1); + acc2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d2), acc2); + acc3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d3), acc3); + } + reduce4(acc0, acc1, acc2, acc3, c_row + (n_pos - n_start)); + } + // N 尾巴: N % 4 != 0 时单行 fallback + for (; n_pos < n_end; n_pos++) { + __m128i* w = (__m128i*)bb->get_submat(n, k, n_pos, 0); + const float* s = bb->get_scale(n, n_pos, k, 0); + __m512 acc = _mm512_setzero_ps(); + for (int g = 0; g < kg_count; g++) { + const __m512bh a = a_row[g]; + const __m512bh d = (__m512bh)mxfp4_to_bf16_32(w[g]); + acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d), acc); + } + c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc); + } + } + } + + // mat-mat: 4×4 register tile (M_TILE=4, N_TILE=4 → 16 累加器)。 + // 每 K-group 解码 4 行 N 一次, 被 4 个 token 共享 → PSHUFB 解码开销 / 4。 + // M / N 尾巴回退到 mat-vec 单 token 内层 (V4 chunked-prefill 16/32/64 整数倍, 极少触发)。 + static void fp4_mat_mat_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb, BufferC* bc, + int ith, int nth) { + auto [n_start, n_end] = split_range_n(n, ith, nth); + if (n_start >= n_end) return; + const int kg_count = k / 32; + constexpr int MB = 4; + constexpr int NB = 4; + + int m_pos = 0; + for (; m_pos + MB <= m; m_pos += MB) { + __m512bh* a_rows[MB] = { + (__m512bh*)ba->get_submat(m, k, m_pos + 0, 0), + (__m512bh*)ba->get_submat(m, k, m_pos + 1, 0), + (__m512bh*)ba->get_submat(m, k, m_pos + 2, 0), + (__m512bh*)ba->get_submat(m, k, m_pos + 3, 0), + }; + + int n_pos = n_start; + for (; n_pos + NB <= n_end; n_pos += NB) { + __m128i* w0 = (__m128i*)bb->get_submat(n, k, n_pos + 0, 0); + __m128i* w1 = (__m128i*)bb->get_submat(n, k, n_pos + 1, 0); + __m128i* w2 = (__m128i*)bb->get_submat(n, k, n_pos + 2, 0); + __m128i* w3 = (__m128i*)bb->get_submat(n, k, n_pos + 3, 0); + const float* s0 = bb->get_scale(n, n_pos + 0, k, 0); + const float* s1 = bb->get_scale(n, n_pos + 1, k, 0); + const float* s2 = bb->get_scale(n, n_pos + 2, k, 0); + const float* s3 = bb->get_scale(n, n_pos + 3, k, 0); + + __m512 acc[MB][NB]; + for (int i = 0; i < MB; i++) + for (int j = 0; j < NB; j++) acc[i][j] = _mm512_setzero_ps(); + + for (int g = 0; g < kg_count; g++) { + // 4 行权重解码一次, MB 个 token 共享 + const __m512bh d0 = (__m512bh)mxfp4_to_bf16_32(w0[g]); + const __m512bh d1 = (__m512bh)mxfp4_to_bf16_32(w1[g]); + const __m512bh d2 = (__m512bh)mxfp4_to_bf16_32(w2[g]); + const __m512bh d3 = (__m512bh)mxfp4_to_bf16_32(w3[g]); + const __m512 sv0 = _mm512_set1_ps(s0[g]); + const __m512 sv1 = _mm512_set1_ps(s1[g]); + const __m512 sv2 = _mm512_set1_ps(s2[g]); + const __m512 sv3 = _mm512_set1_ps(s3[g]); + + #define V_FMA_ROW(M_I) do { \ + const __m512bh a = a_rows[M_I][g]; \ + acc[M_I][0] = _mm512_fmadd_ps(sv0, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d0), acc[M_I][0]); \ + acc[M_I][1] = _mm512_fmadd_ps(sv1, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d1), acc[M_I][1]); \ + acc[M_I][2] = _mm512_fmadd_ps(sv2, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d2), acc[M_I][2]); \ + acc[M_I][3] = _mm512_fmadd_ps(sv3, _mm512_dpbf16_ps(_mm512_setzero_ps(), a, d3), acc[M_I][3]); \ + } while (0) + V_FMA_ROW(0); + V_FMA_ROW(1); + V_FMA_ROW(2); + V_FMA_ROW(3); + #undef V_FMA_ROW + } + for (int i = 0; i < MB; i++) { + float* c_row = bc->get_submat(m, n, m_pos + i, n_start); + reduce4(acc[i][0], acc[i][1], acc[i][2], acc[i][3], c_row + (n_pos - n_start)); + } + } + // N 尾巴: 单 N 列 × MB token (V4 不触发) + for (; n_pos < n_end; n_pos++) { + __m128i* w = (__m128i*)bb->get_submat(n, k, n_pos, 0); + const float* s = bb->get_scale(n, n_pos, k, 0); + for (int i = 0; i < MB; i++) { + float* c_row = bc->get_submat(m, n, m_pos + i, n_start); + __m512 acc = _mm512_setzero_ps(); + for (int g = 0; g < kg_count; g++) { + acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), + a_rows[i][g], + (__m512bh)mxfp4_to_bf16_32(w[g])), + acc); + } + c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc); + } + } + } + // M 尾巴: M 不是 MB 倍数时余下 token, 退回单 token mat-vec 内层 (V4 不触发) + for (int mi = m_pos; mi < m; mi++) { + float* c_row = bc->get_submat(m, n, mi, n_start); + __m512bh* a_row = (__m512bh*)ba->get_submat(m, k, mi, 0); + int n_pos = n_start; + for (; n_pos + 4 <= n_end; n_pos += 4) { + __m128i* w0 = (__m128i*)bb->get_submat(n, k, n_pos + 0, 0); + __m128i* w1 = (__m128i*)bb->get_submat(n, k, n_pos + 1, 0); + __m128i* w2 = (__m128i*)bb->get_submat(n, k, n_pos + 2, 0); + __m128i* w3 = (__m128i*)bb->get_submat(n, k, n_pos + 3, 0); + const float* s0 = bb->get_scale(n, n_pos + 0, k, 0); + const float* s1 = bb->get_scale(n, n_pos + 1, k, 0); + const float* s2 = bb->get_scale(n, n_pos + 2, k, 0); + const float* s3 = bb->get_scale(n, n_pos + 3, k, 0); + __m512 a0 = _mm512_setzero_ps(), a1 = _mm512_setzero_ps(), + a2 = _mm512_setzero_ps(), a3 = _mm512_setzero_ps(); + for (int g = 0; g < kg_count; g++) { + const __m512bh a = a_row[g]; + a0 = _mm512_fmadd_ps(_mm512_set1_ps(s0[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w0[g])), a0); + a1 = _mm512_fmadd_ps(_mm512_set1_ps(s1[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w1[g])), a1); + a2 = _mm512_fmadd_ps(_mm512_set1_ps(s2[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w2[g])), a2); + a3 = _mm512_fmadd_ps(_mm512_set1_ps(s3[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), a, (__m512bh)mxfp4_to_bf16_32(w3[g])), a3); + } + reduce4(a0, a1, a2, a3, c_row + (n_pos - n_start)); + } + for (; n_pos < n_end; n_pos++) { + __m128i* w = (__m128i*)bb->get_submat(n, k, n_pos, 0); + const float* s = bb->get_scale(n, n_pos, k, 0); + __m512 acc = _mm512_setzero_ps(); + for (int g = 0; g < kg_count; g++) { + acc = _mm512_fmadd_ps(_mm512_set1_ps(s[g]), + _mm512_dpbf16_ps(_mm512_setzero_ps(), + a_row[g], + (__m512bh)mxfp4_to_bf16_32(w[g])), + acc); + } + c_row[n_pos - n_start] = _mm512_reduce_add_ps(acc); + } + } + } +}; + +// Dispatch functions +inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, + std::shared_ptr ba, + std::shared_ptr bb, + std::shared_ptr bc, int ith, int nth) { + GemmKernel224MXFP4SmallKGroup::fp4_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + +inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, + std::shared_ptr ba, + std::shared_ptr bb, + std::shared_ptr bc, int ith, int nth) { + GemmKernel224MXFP4SmallKGroup::fp4_mat_mat_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); +} + +} // namespace amx + +// ============================================================================ +// AMX_FP4_MOE_TP — CRTP class, identical structure to AMX_K2_MOE_TP +// ============================================================================ +template +class AMX_FP4_MOE_TP : public AMX_MOE_BASE> { + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::down_ba_; + using Base::down_bb_; + using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; + using Base::m_local_num_; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; + + public: + using typename Base::input_t; + using typename Base::output_t; + + AMX_FP4_MOE_TP() = default; + AMX_FP4_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {} + + void derived_init() { + auto& quant_config = config_.quant_config; + if (quant_config.group_size == 0 || quant_config.zero_point) { + throw std::runtime_error("MXFP4 MoE only supports KGroup FP4"); + } + printf("Creating AMX_FP4_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); + } + + ~AMX_FP4_MOE_TP() = default; + + // BufferA: raw BF16, no group_size needed + size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + return T::BufferB::required_size(n, k, config_.quant_config.group_size); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } + + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + return std::make_shared(n, k, config_.quant_config.group_size, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth); + } + } + + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + auto& group_size = config_.quant_config.group_size; + int m = m_local_num_[expert_idx]; + + if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { + amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } else { + amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx], + down_bb_[expert_idx], down_bc_[expert_idx], ith, nth); + } + } + + void load_weights() { + auto& quant_config = config_.quant_config; + const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; + auto pool = config_.pool->get_subpool(tp_part_idx); + + if (quant_config.group_size == 0 || quant_config.zero_point) + throw std::runtime_error("MXFP4 MoE only support KGroup FP4."); + if (config_.gate_scale == nullptr) throw std::runtime_error("MXFP4 MoE only support load native weight."); + + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + gate_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.gate_proj + + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1), + ith, nth); + up_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.up_proj + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1), + ith, nth); + }, + nullptr); + + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + down_bb_[expert_idx]->from_raw_mat( + (uint8_t*)config_.down_proj + + ((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1), + ith, nth); + }, + nullptr); + + pool->do_work_stealing_job( + config_.expert_num, nullptr, + [this, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + size_t scale_elem_count = (config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size; + convert_or_copy(gate_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count), scale_elem_count); + convert_or_copy(up_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count), scale_elem_count); + convert_or_copy(down_bb_[expert_idx]->d, + (ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), scale_elem_count); + }, + nullptr); + } + + static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) { + uint8_t* d = (uint8_t*)dst; + const uint8_t* s = (const uint8_t*)src; + size_t chunks = bytes / 64; + for (size_t i = 0; i < chunks; i++) { + __m512i data = _mm512_loadu_si512((__m512i*)s); + _mm512_storeu_si512((__m512i*)d, data); + d += 64; + s += 64; + } + if (bytes -= chunks * 64) std::memcpy(d, s, bytes); + } + + static inline void fast_fp32_to_bf16(ggml_bf16_t* __restrict dst, const float* __restrict src, size_t count) { + size_t i = 0; + for (; i + 32 <= count; i += 32) { + __m512 v0 = _mm512_loadu_ps(src + i); + __m512 v1 = _mm512_loadu_ps(src + i + 16); + __m512i i0 = _mm512_srli_epi32(_mm512_castps_si512(v0), 16); + __m512i i1 = _mm512_srli_epi32(_mm512_castps_si512(v1), 16); + __m512i packed = _mm512_packus_epi32(i0, i1); + __m512i permuted = _mm512_permutexvar_epi64(_mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0), packed); + _mm512_storeu_si512((__m512i*)(dst + i), permuted); + } + for (; i < count; i++) dst[i] = ggml_fp32_to_bf16(src[i]); + } + + void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int expert_id, const GeneralMOEConfig& full_config, + const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) const { + const int group_size = config_.quant_config.group_size; + auto pool = config_.pool->get_subpool(tp_part_idx); + + size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size; + size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2; + size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size; + + size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count; + size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2; + size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size; + + if (cpu_tp_count >= gpu_tp_count) { + int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count); + int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count); + + uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp]; + ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp]; + uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp]; + ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp]; + + size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes; + size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count; + + constexpr int NUM_WEIGHT_TASKS = 8; + constexpr int MIN_COLS_PER_TASK = 128; + int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK); + num_down_tasks = std::min(num_down_tasks, 32); + int total_tasks = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2; + + size_t weight_chunk_size = (cpu_tp_weight_bytes + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS; + weight_chunk_size = (weight_chunk_size + 63) & ~63ULL; + + pool->do_work_stealing_job( + total_tasks, nullptr, + [&, this, num_down_tasks, expert_id, weight_chunk_size, offset_in_gpu_weight, offset_in_gpu_scale, + gpu_tp_weight_bytes, gpu_tp_scale_elem_count, w13_weight_dst, w13_scale_dst, w2_weight_dst, w2_scale_dst, + group_size](int task_id) { + if (task_id < NUM_WEIGHT_TASKS) { + int chunk_idx = task_id; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes); + if (start < end) + fast_memcpy(w13_weight_dst + offset_in_gpu_weight + start, (uint8_t*)gate_bb_[expert_id]->b + start, + end - start); + } else if (task_id < NUM_WEIGHT_TASKS * 2) { + int chunk_idx = task_id - NUM_WEIGHT_TASKS; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes); + if (start < end) + fast_memcpy(w13_weight_dst + offset_in_gpu_weight + gpu_tp_weight_bytes + start, + (uint8_t*)up_bb_[expert_id]->b + start, end - start); + } else if (task_id < NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + int chunk_idx = task_id - NUM_WEIGHT_TASKS * 2; + size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks; + size_t col_start = chunk_idx * cols_per_chunk; + size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size); + + size_t weight_per_col = config_.intermediate_size >> 1; + size_t scale_per_col = config_.intermediate_size / group_size; + size_t gpu_weight_stride = (full_config.intermediate_size / gpu_tp_count) >> 1; + size_t gpu_scale_stride = (full_config.intermediate_size / gpu_tp_count) / group_size; + size_t gpu_weight_slice_offset = local_idx * weight_per_col; + size_t gpu_scale_slice_offset = local_idx * scale_per_col; + + for (size_t col = col_start; col < col_end; col++) { + fast_memcpy(w2_weight_dst + col * gpu_weight_stride + gpu_weight_slice_offset, + (uint8_t*)down_bb_[expert_id]->b + col * weight_per_col, weight_per_col); + fast_fp32_to_bf16(w2_scale_dst + col * gpu_scale_stride + gpu_scale_slice_offset, + down_bb_[expert_id]->d + col * scale_per_col, scale_per_col); + } + } else if (task_id == NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale, gate_bb_[expert_id]->d, cpu_tp_scale_elem_count); + } else { + fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale + gpu_tp_scale_elem_count, up_bb_[expert_id]->d, + cpu_tp_scale_elem_count); + } + }, + nullptr); + } else { + int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count; + int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp; + + size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp; + size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp; + + constexpr int NUM_WEIGHT_TASKS = 8; + constexpr int MIN_COLS_PER_TASK = 128; + int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK); + num_down_tasks = std::min(num_down_tasks, 32); + int tasks_per_gpu_tp = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2; + int total_tasks = tasks_per_gpu_tp * gpu_tps_per_cpu_tp; + + size_t weight_chunk_size = (data_per_gpu_tp_weight + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS; + weight_chunk_size = (weight_chunk_size + 63) & ~63ULL; + + pool->do_work_stealing_job( + total_tasks, nullptr, + [&, this, gpu_tps_per_cpu_tp, start_gpu_tp, data_per_gpu_tp_weight, data_per_gpu_tp_scale, num_down_tasks, + tasks_per_gpu_tp, expert_id, weight_chunk_size, gpu_tp_weight_bytes, gpu_tp_scale_elem_count, + group_size](int task_id) { + int local_gpu_idx = task_id / tasks_per_gpu_tp; + int task_type = task_id % tasks_per_gpu_tp; + int gpu_tp_idx = start_gpu_tp + local_gpu_idx; + + uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx]; + ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx]; + uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx]; + ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx]; + + size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight; + size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale; + + if (task_type < NUM_WEIGHT_TASKS) { + int chunk_idx = task_type; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight); + if (start < end) + fast_memcpy(w13_weight_dst + start, (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight + start, + end - start); + } else if (task_type < NUM_WEIGHT_TASKS * 2) { + int chunk_idx = task_type - NUM_WEIGHT_TASKS; + size_t start = chunk_idx * weight_chunk_size; + size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight); + if (start < end) + fast_memcpy(w13_weight_dst + gpu_tp_weight_bytes + start, + (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight + start, end - start); + } else if (task_type < NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + int chunk_idx = task_type - NUM_WEIGHT_TASKS * 2; + size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks; + size_t col_start = chunk_idx * cols_per_chunk; + size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size); + + size_t weight_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) >> 1; + size_t scale_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size; + + for (size_t col = col_start; col < col_end; col++) { + size_t col_offset_weight = (col * config_.intermediate_size / 2) + + (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size); + size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) + + (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size); + + fast_memcpy(w2_weight_dst + col * weight_per_gpu_col, + (uint8_t*)down_bb_[expert_id]->b + col_offset_weight, weight_per_gpu_col); + fast_fp32_to_bf16(w2_scale_dst + col * scale_per_gpu_col, down_bb_[expert_id]->d + col_offset_scale, + scale_per_gpu_col); + } + } else if (task_type == NUM_WEIGHT_TASKS * 2 + num_down_tasks) { + fast_fp32_to_bf16(w13_scale_dst, gate_bb_[expert_id]->d + cpu_offset_scale, data_per_gpu_tp_scale); + } else { + fast_fp32_to_bf16(w13_scale_dst + gpu_tp_scale_elem_count, up_bb_[expert_id]->d + cpu_offset_scale, + data_per_gpu_tp_scale); + } + }, + nullptr); + } + } +}; + +// ============================================================================ +// TP_MOE specialization for AMX_FP4_MOE_TP +// ============================================================================ +template +class TP_MOE> : public TP_MOE>> { + public: + using Base = TP_MOE>>; + using Base::Base; + + void load_weights() override { + auto& config = this->config; + auto& tps = this->tps; + auto& tp_count = this->tp_count; + auto pool = config.pool; + const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map; + + bool use_per_expert_ptrs = !config.gate_projs.empty(); + + if (config.gate_projs.empty() && config.gate_scale == nullptr) + throw std::runtime_error("MXFP4 MoE only supports Packed FP4 with KGroup Scale"); + + printf("From %s\n", use_per_expert_ptrs ? "per-expert pointers (gate_projs)" : "Packed FP4 with KGroup Scale"); + + int& group_size = config.quant_config.group_size; + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size; + size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size; + + tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2]; + tpc.gate_scale = new ggml_bf16_t[tpc.expert_num * scales_elem_count]; + tpc.up_scale = new ggml_bf16_t[tpc.expert_num * scales_elem_count]; + tpc.down_scale = new ggml_bf16_t[tpc.expert_num * scales_elem_count]; + + if (use_per_expert_ptrs) { + pool->get_subpool(i)->do_work_stealing_job( + tpc.expert_num, nullptr, + [&, i](int expert_id_) { + size_t expert_id = expert_map(physical_to_logical_map, expert_id_); + + uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id]; + uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id]; + uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id]; + ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id]; + ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id]; + ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id]; + + memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1), + src_gate + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1)); + memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1), + src_up + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1)); + memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count), + src_gate_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count); + memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count), + src_up_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count); + + for (size_t col = 0; col < config.hidden_size; col++) { + memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), + src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1), + (tpc.intermediate_size >> 1)); + memcpy((ggml_bf16_t*)tpc.down_scale + + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)), + src_down_scale + + (col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)), + sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size)); + } + }, + nullptr); + } else { + if (tpc.load == false) { + pool->get_subpool(i)->do_work_stealing_job( + tpc.expert_num, nullptr, + [&, i](int expert_id_) { + size_t expert_id = expert_map(physical_to_logical_map, expert_id_); + + memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1), + (uint8_t*)config.gate_proj + + ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), + (weight_elem_count >> 1)); + memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1), + (uint8_t*)config.up_proj + + ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1), + (weight_elem_count >> 1)); + memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count), + (ggml_bf16_t*)config.gate_scale + + (expert_id * (config.hidden_size / group_size) * config.intermediate_size + + i * scales_elem_count), + sizeof(ggml_bf16_t) * scales_elem_count); + memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count), + (ggml_bf16_t*)config.up_scale + + (expert_id * (config.hidden_size / group_size) * config.intermediate_size + + i * scales_elem_count), + sizeof(ggml_bf16_t) * scales_elem_count); + + for (size_t col = 0; col < config.hidden_size; col++) { + memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1), + (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size + + col * config.intermediate_size + i * tpc.intermediate_size) >> + 1), + (tpc.intermediate_size >> 1)); + memcpy((ggml_bf16_t*)tpc.down_scale + + (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)), + (ggml_bf16_t*)config.down_scale + + ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) + + col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)), + sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size)); + } + }, + nullptr); + } + } + printf("TP %d load weight done.\n", i); + }); + + DO_TPS_LOAD_WEIGHTS(pool); + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + delete[] (uint8_t*)(tpc.gate_proj); + delete[] (uint8_t*)(tpc.up_proj); + delete[] (uint8_t*)(tpc.down_proj); + delete[] (ggml_bf16_t*)(tpc.gate_scale); + delete[] (ggml_bf16_t*)(tpc.up_scale); + delete[] (ggml_bf16_t*)(tpc.down_scale); + }); + + this->weights_loaded = true; + } + + void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) { + if (!this->weights_loaded) throw std::runtime_error("Not Loaded"); + if (this->tps.empty()) throw std::runtime_error("No TP parts initialized"); + if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count || + w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) + throw std::runtime_error("Pointer arrays size must match gpu_tp_count"); + + this->config.pool->dispense_backend()->do_numa_job([&, this](int i) { + this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs, + w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs); + }); + } +}; + +#endif // CPUINFER_OPERATOR_AMX_FP4_MOE_H diff --git a/kt-kernel/python/cli/utils/model_registry.py b/kt-kernel/python/cli/utils/model_registry.py index 099cc6ef..8dcdc901 100644 --- a/kt-kernel/python/cli/utils/model_registry.py +++ b/kt-kernel/python/cli/utils/model_registry.py @@ -81,6 +81,26 @@ BUILTIN_MODELS: list[ModelInfo] = [ description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)", description_zh="DeepSeek R1-0528 推理模型(2025年5月,改进的推理深度)", ), + ModelInfo( + name="DeepSeek-V4-Flash", + hf_repo="deepseek-ai/DeepSeek-V4-Flash", + aliases=["deepseek-v4-flash", "deepseek-v4", "dsv4", "v4-flash", "v4"], + type="moe", + default_params={ + "kt-method": "MXFP4", + "kt-gpu-prefill-token-threshold": 4096, + "attention-backend": "flashinfer", + "max-total-tokens": 100000, + "max-running-requests": 16, + "chunked-prefill-size": 32768, + "mem-fraction-static": 0.80, + "watchdog-timeout": 3000, + "served-model-name": "DeepSeek-V4-Flash", + "disable-shared-experts-fusion": True, + }, + description="DeepSeek V4-Flash MoE model (native MXFP4 experts, MQA + sparse index attention)", + description_zh="DeepSeek V4-Flash MoE 模型(原生 MXFP4 专家,MQA + 稀疏索引注意力)", + ), ModelInfo( name="Kimi-K2-Thinking", hf_repo="moonshotai/Kimi-K2-Thinking", @@ -368,6 +388,19 @@ def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: return total_vram // 3 +def compute_deepseek_v4_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int: + """Compute kt-num-gpu-experts for DeepSeek-V4-Flash. + + V4 uses MXFP4 experts (~0.5 bytes/param vs V3 FP8's 1 byte/param) so each GPU + can hold ~2x more experts per VRAM unit than V3 at the same fragmentation. + """ + per_gpu_gb = 16 + if vram_per_gpu_gb < per_gpu_gb: + return 0 + total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb)) + return total_vram * 2 // 3 + + def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int: """Compute kt-num-gpu-experts for Kimi K2 Thinking.""" per_gpu_gb = 16 @@ -393,6 +426,7 @@ MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = { "DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts, "DeepSeek-V3.2": compute_deepseek_v3_gpu_experts, # Same as V3-0324 "DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts, # Same as V3-0324 + "DeepSeek-V4-Flash": compute_deepseek_v4_gpu_experts, "Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts, "MiniMax-M2": compute_minimax_m2_gpu_experts, "MiniMax-M2.1": compute_minimax_m2_gpu_experts, # Same as M2 diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py index 2a95e8b0..24a92dea 100644 --- a/kt-kernel/python/experts.py +++ b/kt-kernel/python/experts.py @@ -40,6 +40,7 @@ INFERENCE_METHODS = frozenset( "BF16", # BF16 native MoE "FP8_PERCHANNEL", # Per-channel FP8 "GPTQ_INT4", # GPTQ INT4 + "MXFP4", # MXFP4 (E2M1 nibble + ue8m0 group scale, e.g. DeepSeek-V4-Flash routed experts) "LLAMAFILE", # GGUF format "MOE_INT4", "MOE_INT8", # General kernel @@ -312,7 +313,7 @@ def _create_inference_wrapper( # Select backend based on method if method in ["AMXINT4", "AMXINT8"]: backend_cls = AMXMoEWrapper - elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4"]: + elif method in ["RAWINT4", "FP8", "BF16", "FP8_PERCHANNEL", "GPTQ_INT4", "MXFP4"]: backend_cls = NativeMoEWrapper elif method == "LLAMAFILE": backend_cls = LlamafileMoEWrapper diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index ea9af280..917cb0cc 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -11,6 +11,7 @@ from .loader import ( FP8SafeTensorLoader, BF16SafeTensorLoader, GPTQSafeTensorLoader, + MXFP4SafeTensorLoader, ) from kt_kernel_ext.moe import MOEConfig import kt_kernel_ext.moe as _moe_mod @@ -18,6 +19,7 @@ import kt_kernel_ext.moe as _moe_mod AMXInt4_MOE = getattr(_moe_mod, "AMXInt4_MOE", None) AMXInt8_MOE = getattr(_moe_mod, "AMXInt8_MOE", None) AMXInt4_KGroup_MOE = getattr(_moe_mod, "AMXInt4_KGroup_MOE", None) +AMXFP4_KGroup_MOE = getattr(_moe_mod, "AMXFP4_KGroup_MOE", None) AMXFP8_MOE = getattr(_moe_mod, "AMXFP8_MOE", None) AMXBF16_MOE = getattr(_moe_mod, "AMXBF16_MOE", None) AMXFP8PerChannel_MOE = getattr(_moe_mod, "AMXFP8PerChannel_MOE", None) @@ -31,6 +33,7 @@ AVXVNNI256RawInt4_MOE = getattr(_moe_mod, "AVXVNNI256RawInt4_MOE", None) _HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None _HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None _HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None +_HAS_MXFP4_SUPPORT = AMXFP4_KGroup_MOE is not None _HAS_FP8_SUPPORT = AMXFP8_MOE is not None _HAS_BF16_SUPPORT = AMXBF16_MOE is not None _HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None @@ -495,6 +498,12 @@ class NativeMoEWrapper(BaseMoEWrapper): "Please recompile kt_kernel_ext with GPTQ INT4 support enabled.\n" "AVX-VNNI-256 will be selected automatically when available on the current CPU." ) + if method == "MXFP4" and not _HAS_MXFP4_SUPPORT: + raise RuntimeError( + "MXFP4 backend not available. Required ISA:\n" + " - AVX512F + AVX512BW + AVX512_BF16\n" + "Please recompile kt_kernel_ext with AVX512 + BF16 enabled." + ) super().__init__( layer_idx=layer_idx, @@ -525,6 +534,8 @@ class NativeMoEWrapper(BaseMoEWrapper): NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path) elif method == "GPTQ_INT4": NativeMoEWrapper._native_loader_instance = GPTQSafeTensorLoader(weight_path) + elif method == "MXFP4": + NativeMoEWrapper._native_loader_instance = MXFP4SafeTensorLoader(weight_path) else: raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}") self.loader = NativeMoEWrapper._native_loader_instance @@ -592,6 +603,10 @@ class NativeMoEWrapper(BaseMoEWrapper): self.up_scales = [t.to(torch.float32).contiguous() for t in weights["up_scale"]] self.down_scales = [t.to(torch.float32).contiguous() for t in weights["down_scale"]] assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL" + elif self.method == "MXFP4": + # ue8m0 is losslessly representable in bf16 (8-bit exponent, 0 mantissa); + # the loader has already done that conversion. + assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for MXFP4" t2 = time.time() @@ -649,6 +664,14 @@ class NativeMoEWrapper(BaseMoEWrapper): f"{_AVXVNNI256_RAW_INT4_MAX_GROUP_SIZE}; AVX2 (AVX2RawInt4_MOE) is used as the final fallback." ) self.moe = backend_cls(moe_config) + elif self.method == "MXFP4": + # MXFP4: E2M1 nibble-packed weights, ue8m0/bf16 per-32 group scale + # (e.g. DeepSeek-V4-Flash routed experts) + group_size = self.hidden_size // self.gate_scales[0].shape[1] + moe_config.quant_config.bits = 4 + moe_config.quant_config.group_size = group_size + moe_config.quant_config.zero_point = False + self.moe = AMXFP4_KGroup_MOE(moe_config) elif self.method == "FP8": moe_config.quant_config.bits = 8 moe_config.quant_config.group_size = 128 diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py index 4771cb1c..ed4fd2b1 100644 --- a/kt-kernel/python/utils/loader.py +++ b/kt-kernel/python/utils/loader.py @@ -1231,3 +1231,91 @@ class GPTQSafeTensorLoader(FP8SafeTensorLoader): "up_scale": up_scales, "down_scale": down_scales, } + + +class MXFP4SafeTensorLoader(SafeTensorLoader): + """Loader for native MXFP4 expert weights (DeepSeek-V4-Flash format). + + Per expert layout: + {base}.ffn.experts.{i}.w1.weight I8 [N, K/2] nibble-packed E2M1 (gate) + {base}.ffn.experts.{i}.w1.scale F8_E8M0 [N, K/32] ue8m0 group scale + {base}.ffn.experts.{i}.w3.{weight,scale} up + {base}.ffn.experts.{i}.w2.{weight,scale} down + + V4 ckpt keys are not prefixed with ``model.``; we also probe the stripped form so + callers can keep passing ``base_key="model.layers.{L}"``. ue8m0 → bf16 is a lossless + bit shift (both have an 8-bit exponent and zero mantissa for ue8m0), and the AMX + FP4 backend already consumes bf16 scales. + """ + + EXPERTS_PATH_TPL = "{base}.ffn.experts" + PROJ_NAMES = ("w1", "w3", "w2") # (gate, up, down) + + def _experts_prefix_candidates(self, base_key: str) -> list[str]: + candidates = [self.EXPERTS_PATH_TPL.format(base=base_key)] + if base_key.startswith("model."): + candidates.append(self.EXPERTS_PATH_TPL.format(base=base_key[len("model.") :])) + return list(dict.fromkeys(candidates)) + + @staticmethod + def _ue8m0_to_bf16(scale_t: torch.Tensor) -> torch.Tensor: + if scale_t.dtype != torch.uint8: + scale_t = scale_t.view(torch.uint8) + # bf16 = [sign(1) | exp(8) | mant(7)]; setting mant=0, exp=e gives 2^(e-127), + # which is exactly the value encoded by ue8m0 for e ∈ [1, 254]. e=0 → bf16 +0 + # (acceptable: ue8m0=0 represents 2^-127, below bf16 normal range), e=255 → +inf. + # Compute in int32 then narrow to int16 (max value is 255<<7=32640, fits int16), + # because torch CPU has no lshift kernel for uint16. + return (scale_t.to(torch.int32) << 7).to(torch.int16).view(torch.bfloat16).contiguous() + + def load_experts(self, base_key: str, device: str = "cpu"): + gate_name, up_name, down_name = self.PROJ_NAMES + prefix = None + expert_count = 0 + for cand in self._experts_prefix_candidates(base_key): + expert_count = 0 + while self.has_tensor(f"{cand}.{expert_count}.{gate_name}.weight"): + expert_count += 1 + if expert_count > 0: + prefix = cand + break + if prefix is None: + raise ValueError( + f"No MXFP4 experts found under any of: {self._experts_prefix_candidates(base_key)}" + ) + + gate_weights = [None] * expert_count + up_weights = [None] * expert_count + down_weights = [None] * expert_count + gate_scales = [None] * expert_count + up_scales = [None] * expert_count + down_scales = [None] * expert_count + + for exp_id in range(expert_count): + for proj, dst in ( + (gate_name, gate_weights), + (up_name, up_weights), + (down_name, down_weights), + ): + w = self.load_tensor(f"{prefix}.{exp_id}.{proj}.weight", device).contiguous() + if w.dtype != torch.uint8: + w = w.view(torch.uint8) + dst[exp_id] = w + + for proj, dst in ( + (gate_name, gate_scales), + (up_name, up_scales), + (down_name, down_scales), + ): + s = self.load_tensor(f"{prefix}.{exp_id}.{proj}.scale", device) + dst[exp_id] = self._ue8m0_to_bf16(s) + + print(f"[MXFP4SafeTensorLoader] Loaded {expert_count} experts from {prefix}") + return { + "gate": gate_weights, + "up": up_weights, + "down": down_weights, + "gate_scale": gate_scales, + "up_scale": up_scales, + "down_scale": down_scales, + } diff --git a/third_party/sglang b/third_party/sglang index 537eb762..3cbd49c2 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 537eb762b0881071a0e098bd78666fe052b83deb +Subproject commit 3cbd49c291c5d1b968c81308647de34134790e07