From 6277da4c2b826eb6354d6a0edb0dbe4af927563c Mon Sep 17 00:00:00 2001 From: Oql <1692110604@qq.com> Date: Tue, 13 Jan 2026 17:36:25 +0800 Subject: [PATCH] support GLM 4.7 (#1791) support GLM 4.7 --- kt-kernel/bench/bench_fp8_moe.py | 8 +- kt-kernel/bench/bench_fp8_perchannel_moe.py | 277 +++++++ kt-kernel/bench/bench_write_buffer.py | 139 +++- kt-kernel/examples/test_bf16_moe.py | 14 +- kt-kernel/examples/test_fp8_perchannel_moe.py | 408 ++++++++++ kt-kernel/examples/test_write_buffer.py | 239 +++++- kt-kernel/ext_bindings.cpp | 107 +-- kt-kernel/operators/amx/bf16-moe.hpp | 18 +- .../operators/amx/fp8-perchannel-moe.hpp | 716 ++++++++++++++++++ .../operators/amx/la/amx_raw_buffers.hpp | 113 +++ .../operators/amx/la/amx_raw_kernels.hpp | 229 +++++- kt-kernel/operators/common.hpp | 3 +- kt-kernel/python/utils/amx.py | 22 +- kt-kernel/python/utils/loader.py | 187 ++++- 14 files changed, 2336 insertions(+), 144 deletions(-) create mode 100644 kt-kernel/bench/bench_fp8_perchannel_moe.py create mode 100644 kt-kernel/examples/test_fp8_perchannel_moe.py create mode 100644 kt-kernel/operators/amx/fp8-perchannel-moe.hpp diff --git a/kt-kernel/bench/bench_fp8_moe.py b/kt-kernel/bench/bench_fp8_moe.py index a5108f5..a27a490 100644 --- a/kt-kernel/bench/bench_fp8_moe.py +++ b/kt-kernel/bench/bench_fp8_moe.py @@ -17,7 +17,7 @@ import platform sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) import torch -import kt_kernel_ext +from kt_kernel import kt_kernel_ext from tqdm import tqdm # Test parameters @@ -29,9 +29,9 @@ fp8_group_size = 128 max_len = 25600 layer_num = 2 -qlen = 1024 -warm_up_iter = 10 -test_iter = 30 +qlen = 1 +warm_up_iter = 1000 +test_iter = 3000 CPUINFER_PARAM = 80 CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM) diff --git a/kt-kernel/bench/bench_fp8_perchannel_moe.py b/kt-kernel/bench/bench_fp8_perchannel_moe.py new file mode 100644 index 0000000..a8f8811 --- /dev/null +++ b/kt-kernel/bench/bench_fp8_perchannel_moe.py @@ -0,0 
+1,277 @@ +""" +Performance benchmark for FP8 Per-Channel MoE kernel (GLM-4.7-FP8 style). + +This benchmark measures the performance of the FP8 Per-Channel MoE operator with: +- FP8 (E4M3) weights with per-channel scaling (one scale per output row) +- BF16 activations +- AVX-512 DPBF16 compute path +""" + +import os +import sys +import time +import json +import subprocess +import platform + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build")) + +import torch +from kt_kernel import kt_kernel_ext +from tqdm import tqdm + +# Test parameters +expert_num = 256 +hidden_size = 7168 +intermediate_size = 2048 +num_experts_per_tok = 8 +max_len = 25600 + +layer_num = 2 +qlen = 1 +warm_up_iter = 1000 +test_iter = 3000 +CPUINFER_PARAM = 80 + +CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM) + +# Result file path +script_path = os.path.abspath(__file__) +script_dir = os.path.dirname(script_path) +json_path = os.path.join(script_dir, "bench_results.jsonl") + + +def get_git_commit(): + """Get current git commit info""" + result = {} + try: + commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() + commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip() + result["commit"] = commit + result["commit_message"] = commit_msg + dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip() + result["dirty"] = bool(dirty_output) + if dirty_output: + result["dirty_files"] = dirty_output.splitlines() + except Exception as e: + result["commit"] = None + result["error"] = str(e) + return result + + +def get_system_info(): + """Get system information""" + info = {} + uname = platform.uname() + info["system_name"] = uname.system + info["node_name"] = uname.node + + cpu_model = None + if os.path.exists("/proc/cpuinfo"): + try: + with open("/proc/cpuinfo", "r") as f: + for line in f: + if "model name" in line: + cpu_model = line.split(":", 1)[1].strip() + break + 
except Exception: + pass + info["cpu_model"] = cpu_model + info["cpu_core_count"] = os.cpu_count() + return info + + +def record_results(result, filename=json_path): + """Append result to JSON file""" + with open(filename, "a") as f: + f.write(json.dumps(result) + "\n") + + +def generate_fp8_perchannel_weights_direct(shape: tuple): + """ + Directly generate random FP8 weights and per-channel scales. + + Args: + shape: (expert_num, n, k) - weight tensor shape + + Returns: + fp8_weights: uint8 tensor with random FP8 E4M3 values + scales: fp32 tensor with per-channel scales, shape [expert_num, n] + """ + e, n, k = shape + + # Directly generate random FP8 weights as uint8 + # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa + fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous() + + # Generate random per-channel scales (one per output row) + # Use reasonable scale range (e.g., 2^-8 to 2^8) + exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cuda").to("cpu").contiguous() + scales = (2.0 ** exponents.float()).to(torch.float32).contiguous() + + return fp8_weights, scales + + +def bench_fp8_perchannel_moe(): + """Benchmark FP8 Per-Channel MoE performance""" + with torch.inference_mode(): + print("=" * 70) + print("FP8 Per-Channel MoE Kernel Performance Benchmark") + print("=" * 70) + + # Generate FP8 weights with per-channel scales + print("\nGenerating FP8 weights with per-channel scales...") + torch.manual_seed(42) + gate_fp8, gate_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size)) + up_fp8, up_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size)) + down_fp8, down_scales = generate_fp8_perchannel_weights_direct((expert_num, hidden_size, intermediate_size)) + + physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous() + + # Build MoE layers + print("Building FP8 Per-Channel MoE 
layers...") + moes = [] + for _ in tqdm(range(layer_num), desc="Initializing MOEs"): + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) + config.max_len = max_len + config.quant_config.bits = 8 + config.quant_config.group_size = 0 # Not used for per-channel + config.quant_config.zero_point = False + config.quant_config.per_channel = True # Enable per-channel mode + + config.gate_proj = gate_fp8.data_ptr() + config.up_proj = up_fp8.data_ptr() + config.down_proj = down_fp8.data_ptr() + config.gate_scale = gate_scales.data_ptr() + config.up_scale = up_scales.data_ptr() + config.down_scale = down_scales.data_ptr() + config.pool = CPUInfer.backend_ + + moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + moes.append(moe) + + # Generate input data + print("Generating input data...") + gen_iter = 1000 + expert_ids = ( + torch.rand(gen_iter * qlen, expert_num, device="cpu") + .argsort(dim=-1)[:, :num_experts_per_tok] + .reshape(gen_iter, qlen * num_experts_per_tok) + .contiguous() + ) + weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous() + input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous() + output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous() + qlen_tensor = torch.tensor([qlen], dtype=torch.int32) + + # Warmup + print(f"Warming up ({warm_up_iter} iterations)...") + for i in tqdm(range(warm_up_iter), desc="Warm-up"): + CPUInfer.submit( + moes[i % layer_num].forward_task( + qlen_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor[i % layer_num].data_ptr(), + output_tensor[i % layer_num].data_ptr(), + False, + ) + ) + CPUInfer.sync() + + # Benchmark + print(f"Running benchmark 
({test_iter} iterations)...") + start = time.perf_counter() + for i in tqdm(range(test_iter), desc="Testing"): + CPUInfer.submit( + moes[i % layer_num].forward_task( + qlen_tensor.data_ptr(), + num_experts_per_tok, + expert_ids[i % gen_iter].data_ptr(), + weights[i % gen_iter].data_ptr(), + input_tensor[i % layer_num].data_ptr(), + output_tensor[i % layer_num].data_ptr(), + False, + ) + ) + CPUInfer.sync() + end = time.perf_counter() + total_time = end - start + + # Calculate metrics + time_per_iter_us = total_time / test_iter * 1e6 + + # FLOPS calculation: + # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate) + # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element) + # For vector-matrix multiply (qlen=1): 2 * n * k per matrix + flops_per_expert = ( + 2 * intermediate_size * hidden_size # gate + + 2 * intermediate_size * hidden_size # up + + 2 * hidden_size * intermediate_size # down + ) + total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter + tflops = total_flops / total_time / 1e12 + + # Bandwidth calculation (FP8 = 1 byte per element) + bytes_per_elem = 1.0 + # Weight memory: gate + up + down per expert + bandwidth = ( + hidden_size + * intermediate_size + * 3 + * num_experts_per_tok + * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen)) + * bytes_per_elem + * test_iter + / total_time + / 1e9 + ) + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results") + print("=" * 70) + print(f"Quant mode: FP8 (E4M3) with per-channel scaling") + print(f"Total time: {total_time:.4f} s") + print(f"Iterations: {test_iter}") + print(f"Time per iteration: {time_per_iter_us:.2f} us") + print(f"Bandwidth: {bandwidth:.2f} GB/s") + print(f"TFLOPS: {tflops:.4f}") + print("") + + # Record results + result = { + "test_name": os.path.basename(__file__), + "quant_mode": "fp8_e4m3_perchannel", + "total_time_seconds": total_time, + 
"iterations": test_iter, + "time_per_iteration_us": time_per_iter_us, + "bandwidth_GBs": bandwidth, + "flops_TFLOPS": tflops, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), + "test_parameters": { + "expert_num": expert_num, + "hidden_size": hidden_size, + "intermediate_size": intermediate_size, + "num_experts_per_tok": num_experts_per_tok, + "quant_type": "per_channel", + "layer_num": layer_num, + "qlen": qlen, + "warm_up_iter": warm_up_iter, + "test_iter": test_iter, + "CPUInfer_parameter": CPUINFER_PARAM, + }, + } + result.update(get_git_commit()) + result.update(get_system_info()) + record_results(result) + + return tflops, bandwidth + + +if __name__ == "__main__": + bench_fp8_perchannel_moe() diff --git a/kt-kernel/bench/bench_write_buffer.py b/kt-kernel/bench/bench_write_buffer.py index cb5ca44..e2ca330 100644 --- a/kt-kernel/bench/bench_write_buffer.py +++ b/kt-kernel/bench/bench_write_buffer.py @@ -4,12 +4,14 @@ Benchmark write_weight_scale_to_buffer for AMX MOE operators. 
Supports: -- FP8: FP8 weights (1 byte) + float32 scales +- FP8: FP8 weights (1 byte) + float32 scales (block-wise) +- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales - BF16: Native BF16 weights (2 bytes), no scales Usage: python bench_write_buffer.py # Run all modes python bench_write_buffer.py fp8 # Run FP8 only + python bench_write_buffer.py fp8_perchannel # Run FP8 per-channel only python bench_write_buffer.py bf16 # Run BF16 only """ import json @@ -31,8 +33,8 @@ expert_num = 256 num_experts_per_tok = 8 gpu_tp_count = 2 -warm_up_iter = 3 -test_iter = 7 +warm_up_iter = 30 +test_iter = 70 gpu_experts_num = expert_num @@ -147,6 +149,49 @@ def allocate_weights_fp8(): } +def allocate_weights_fp8_perchannel(): + per_mat_weight_bytes = hidden_size * intermediate_size + per_mat_scale_elems_gate_up = intermediate_size + per_mat_scale_elems_down = hidden_size + + gate_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + up_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + down_q = ( + torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda") + .to("cpu") + .contiguous() + ) + gate_scale = ( + torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + up_scale = ( + torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + down_scale = ( + torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous() + ) + + return { + "gate_q": gate_q, + "up_q": up_q, + "down_q": down_q, + "gate_scale": gate_scale, + "up_scale": up_scale, + "down_scale": down_scale, + "per_mat_weight_bytes": per_mat_weight_bytes, + "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up, + 
"per_mat_scale_elems_down": per_mat_scale_elems_down, + } + + def build_moe_fp8(layer_idx=0): """Build a single FP8 MOE instance.""" weights = allocate_weights_fp8() @@ -178,6 +223,38 @@ def build_moe_fp8(layer_idx=0): return moe, buffer_shapes, weights +def build_moe_fp8_perchannel(layer_idx=0): + """Build a single FP8 per-channel MOE instance.""" + weights = allocate_weights_fp8_perchannel() + + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) + config.max_len = max_len + config.layer_idx = layer_idx + config.quant_config.bits = 8 + config.quant_config.group_size = 0 + config.quant_config.zero_point = False + config.quant_config.per_channel = True + config.pool = CPUInfer.backend_ + config.gate_proj = weights["gate_q"].data_ptr() + config.up_proj = weights["up_q"].data_ptr() + config.down_proj = weights["down_q"].data_ptr() + config.gate_scale = weights["gate_scale"].data_ptr() + config.up_scale = weights["up_scale"].data_ptr() + config.down_scale = weights["down_scale"].data_ptr() + + moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + + buffer_shapes = { + "per_mat_weight_bytes": weights["per_mat_weight_bytes"], + "per_mat_scale_elems_gate_up": weights["per_mat_scale_elems_gate_up"], + "per_mat_scale_elems_down": weights["per_mat_scale_elems_down"], + } + + return moe, buffer_shapes, weights + + def allocate_buffers_fp8(buffer_shapes): """Allocate output buffers for FP8 single expert.""" per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"] @@ -212,6 +289,40 @@ def allocate_buffers_fp8(buffer_shapes): return buffer_ptrs, keep_tensors +def allocate_buffers_fp8_perchannel(buffer_shapes): + """Allocate output buffers for FP8 per-channel single expert.""" + per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"] + per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"] + 
per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"] + + weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count + scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count + scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down + + w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w13_scale_bufs = [ + torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count) + ] + w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)] + + buffer_ptrs = { + "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs], + "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs], + "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs], + "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs], + } + + keep_tensors = { + "w13_weight_bufs": w13_weight_bufs, + "w13_scale_bufs": w13_scale_bufs, + "w2_weight_bufs": w2_weight_bufs, + "w2_scale_bufs": w2_scale_bufs, + } + + return buffer_ptrs, keep_tensors + + # ============================================================================== # BF16 Functions # ============================================================================== @@ -320,6 +431,20 @@ def bench_write_buffer(quant_mode: str): ) bytes_per_call = total_weights + total_scale_bytes + elif quant_mode == "fp8_perchannel": + bytes_per_elem = 1.0 + moe_0, buffer_shapes, keep_tensors_0 = build_moe_fp8_perchannel(layer_idx=0) + moe_1, _, keep_tensors_1 = build_moe_fp8_perchannel(layer_idx=1) + buffer_ptrs, buffer_keep = allocate_buffers_fp8_perchannel(buffer_shapes) + + total_weights = hidden_size * intermediate_size * expert_num * 3 + total_scale_bytes = ( + (buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + 
buffer_shapes["per_mat_scale_elems_down"]) + * expert_num + * 4 + ) + bytes_per_call = total_weights + total_scale_bytes + elif quant_mode == "bf16": bytes_per_elem = 2.0 moe_0, buffer_shapes, keep_tensors_0 = build_moe_bf16(layer_idx=0) @@ -356,7 +481,7 @@ def bench_write_buffer(quant_mode: str): end = time.perf_counter() iter_time = end - start total_time += iter_time - print(f" Iter {iter_idx}: {iter_time*1000:.2f} ms") + # print(f" Iter {iter_idx}: {iter_time*1000:.2f} ms") time.sleep(0.3) # bytes_per_call is for one MOE, we have 2 MOEs @@ -400,7 +525,7 @@ def bench_write_buffer(quant_mode: str): def main(quant_modes=None): """Run benchmarks for specified quant modes.""" if quant_modes is None: - quant_modes = ["fp8", "bf16"] + quant_modes = ["fp8", "fp8_perchannel", "bf16"] results = {} for mode in quant_modes: @@ -423,10 +548,10 @@ def main(quant_modes=None): if __name__ == "__main__": if len(sys.argv) > 1: mode = sys.argv[1].lower() - if mode in ["fp8", "bf16"]: + if mode in ["fp8", "fp8_perchannel", "bf16"]: main([mode]) else: - print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'") + print(f"Unknown mode: {mode}. 
Use 'fp8', 'fp8_perchannel' or 'bf16'") sys.exit(1) else: main() diff --git a/kt-kernel/examples/test_bf16_moe.py b/kt-kernel/examples/test_bf16_moe.py index ec700cc..8c715cb 100644 --- a/kt-kernel/examples/test_bf16_moe.py +++ b/kt-kernel/examples/test_bf16_moe.py @@ -22,12 +22,12 @@ from kt_kernel import kt_kernel_ext torch.manual_seed(42) # Model config -hidden_size = 3072 -intermediate_size = 1536 +hidden_size = 2048 +intermediate_size = 768 max_len = 25600 -expert_num = 16 -num_experts_per_tok = 16 +expert_num = 128 +num_experts_per_tok = 8 qlen = 1 layer_num = 5 @@ -180,13 +180,13 @@ def run_bf16_moe_test(): diffs = [] with torch.inference_mode(mode=True): for i in range(validation_iter): - torch.manual_seed(100 + i) + torch.manual_seed(114514 + i) bsz_tensor = torch.tensor([qlen], device="cpu") expert_ids = torch.stack( [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)] ).contiguous() - weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100 - input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5 + weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 10 + input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 3 output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous() moe = moes[i % layer_num] diff --git a/kt-kernel/examples/test_fp8_perchannel_moe.py b/kt-kernel/examples/test_fp8_perchannel_moe.py new file mode 100644 index 0000000..54f7400 --- /dev/null +++ b/kt-kernel/examples/test_fp8_perchannel_moe.py @@ -0,0 +1,408 @@ +""" +Test script for FP8 Per-Channel MoE kernel validation (GLM-4.7-FP8 style). + +This script: +1. Generates random BF16 weights +2. Quantizes them to FP8 format with per-channel scales (one scale per output channel) +3. Runs the FP8 Per-Channel MoE kernel +4. 
Compares results with PyTorch reference using dequantized BF16 weights + +FP8 Per-Channel format notes: +- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k] +- Scale: FP32, shape [expert_num, n] (one scale per output row) +""" + +import os +import sys + +sys.path.insert(0, os.path.dirname(__file__) + "/../build") + +import torch +from kt_kernel import kt_kernel_ext + +torch.manual_seed(42) + +# Model config +hidden_size = 3072 +intermediate_size = 1536 +max_len = 25600 + +expert_num = 16 +num_experts_per_tok = 8 + +qlen = 100 +layer_num = 1 +CPUInfer = kt_kernel_ext.CPUInfer(40) +validation_iter = 1 +debug_print_count = 16 + +physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous() + + +def act_fn(x): + """SiLU activation function""" + return x / (1.0 + torch.exp(-x)) + + +def mlp_torch(input, gate_proj, up_proj, down_proj): + """Reference MLP computation in PyTorch""" + gate_buf = torch.mm(input, gate_proj.t()) + up_buf = torch.mm(input, up_proj.t()) + intermediate = act_fn(gate_buf) * up_buf + ret = torch.mm(intermediate, down_proj.t()) + return ret + + +def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj): + """Reference MoE computation in PyTorch""" + cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num)) + cnts.scatter_(1, expert_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = expert_ids.view(-1).argsort() + sorted_tokens = input[idxs // expert_ids.shape[1]] + + outputs = [] + start_idx = 0 + + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + t_output = ( + 
new_x.view(*expert_ids.shape, -1) + .type(weights.dtype) + .mul_(weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return t_output + + +# FP8 E4M3 constants +FP8_E4M3_MAX = 448.0 # Maximum representable value in FP8 E4M3 + + +def fp8_e4m3_to_float(fp8_val: int) -> float: + """ + Convert FP8 E4M3 value to float. + FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits + """ + sign = (fp8_val >> 7) & 1 + exp = (fp8_val >> 3) & 0xF + mant = fp8_val & 0x7 + + if exp == 0: + # Subnormal or zero + if mant == 0: + return -0.0 if sign else 0.0 + # Subnormal: value = (-1)^sign * 2^(-6) * (0.mant) + return ((-1) ** sign) * (2**-6) * (mant / 8.0) + elif exp == 15: + # NaN (FP8 E4M3 doesn't have Inf, all exp=15 are NaN) + return float("nan") + else: + # Normal: value = (-1)^sign * 2^(exp-7) * (1.mant) + return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0) + + +def float_to_fp8_e4m3(val: float) -> int: + """ + Convert float to FP8 E4M3 value. + """ + if val != val: # NaN + return 0x7F # NaN representation + + sign = 1 if val < 0 else 0 + val = abs(val) + + if val == 0: + return sign << 7 + + # Clamp to max representable value + val = min(val, FP8_E4M3_MAX) + + # Find exponent + import math + + if val < 2**-9: # Subnormal threshold + # Subnormal + mant = int(round(val / (2**-9))) + mant = min(mant, 7) + return (sign << 7) | mant + + exp = int(math.floor(math.log2(val))) + 7 + exp = max(1, min(exp, 14)) # Clamp exponent to valid range + + # Calculate mantissa + mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8)) + mant = max(0, min(mant, 7)) + + # Handle overflow to next exponent + if mant > 7: + mant = 0 + exp += 1 + if exp > 14: + exp = 14 + mant = 7 + + return (sign << 7) | (exp << 3) | mant + + +def quantize_to_fp8_perchannel(weights: torch.Tensor): + """ + Quantize BF16/FP32 weights to FP8 with per-channel scaling. 
+ + Args: + weights: [expert_num, n, k] tensor in BF16/FP32 + + Returns: + fp8_weights: [expert_num, n, k] uint8 tensor + scales: [expert_num, n] FP32 tensor (one scale per output row) + """ + weights_f32 = weights.to(torch.float32) + e, n, k = weights_f32.shape + + # Calculate max abs per row (per output channel) + max_abs = weights_f32.abs().amax(dim=-1, keepdim=True) # [e, n, 1] + max_abs = torch.clamp(max_abs, min=1e-12) + + # Scale to FP8 range: scale = max_abs / FP8_MAX + scales = (max_abs / FP8_E4M3_MAX).squeeze(-1) # [e, n] + + # Quantize: q = round(val / scale) + scaled = weights_f32 / (scales.unsqueeze(-1) + 1e-12) + + # Clamp to FP8 representable range + scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX) + + # Vectorized FP8 quantization + fp8_q = torch.zeros_like(scaled, dtype=torch.uint8) + + sign_mask = (scaled < 0).to(torch.uint8) << 7 + abs_scaled = scaled.abs() + + # Handle different ranges + # Subnormal: 0 < |x| < 2^-6 + subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6) + subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8) + + # Normal values + normal_mask = abs_scaled >= 2**-6 + log2_val = torch.log2(abs_scaled.clamp(min=2**-9)) + exp = (log2_val.floor() + 7).clamp(1, 14).to(torch.int32) + mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8) + + # Combine + fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q) + fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q) + + return fp8_q.contiguous(), scales.to(torch.float32).contiguous() + + +def dequantize_fp8_perchannel(fp8_weights: torch.Tensor, scales: torch.Tensor): + """ + Dequantize FP8 weights back to BF16 for reference computation. 
+ + Args: + fp8_weights: [expert_num, n, k] uint8 tensor + scales: [expert_num, n] FP32 tensor + + Returns: + dequantized: [expert_num, n, k] BF16 tensor + """ + # Build lookup table for FP8 E4M3 -> float + fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32) + + # Use lookup table + fp8_float = fp8_lut[fp8_weights.to(torch.int64)] + + # Apply per-channel scales + scales_expanded = scales.unsqueeze(-1) # [e, n, 1] + dequantized = fp8_float * scales_expanded + + return dequantized.to(torch.bfloat16).contiguous() + + +def build_random_fp8_perchannel_weights(): + """ + Generate random BF16 weights and quantize to FP8 with per-channel scales. + + Returns: + dict with fp8 weights, scales, and original bf16 for reference + """ + torch.manual_seed(42) + + # Generate random BF16 weights with small values + gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to( + torch.bfloat16 + ) + up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to( + torch.bfloat16 + ) + down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to( + torch.bfloat16 + ) + + # Quantize to FP8 with per-channel scales + gate_fp8, gate_scales = quantize_to_fp8_perchannel(gate_proj) + up_fp8, up_scales = quantize_to_fp8_perchannel(up_proj) + down_fp8, down_scales = quantize_to_fp8_perchannel(down_proj) + + # Dequantize for reference computation + gate_deq = dequantize_fp8_perchannel(gate_fp8, gate_scales) + up_deq = dequantize_fp8_perchannel(up_fp8, up_scales) + down_deq = dequantize_fp8_perchannel(down_fp8, down_scales) + + print(f"FP8 Per-Channel weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}") + print(f"Per-Channel scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}") + + # Debug: Print FP8 weight and scale info for expert 0 + print("\n=== DEBUG: FP8 Per-Channel 
Weight and Scale Info (Expert 0) ===") + print(f"gate_fp8[0] first 8x8 block:") + for i in range(8): + print(f" row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}") + print(f"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}") + print(f"gate_scales[0] first 8 channels: {gate_scales[0, :8]}") + print(f"gate_scales[0] stats: min={gate_scales[0].min():.6f}, max={gate_scales[0].max():.6f}") + + return { + "gate_fp8": gate_fp8.contiguous(), + "up_fp8": up_fp8.contiguous(), + "down_fp8": down_fp8.contiguous(), + "gate_scales": gate_scales.contiguous(), + "up_scales": up_scales.contiguous(), + "down_scales": down_scales.contiguous(), + "gate_deq": gate_deq.contiguous(), + "up_deq": up_deq.contiguous(), + "down_deq": down_deq.contiguous(), + } + + +def build_moes_from_fp8_perchannel_data(fp8_data: dict): + """ + Build FP8 Per-Channel MoE modules from quantized data. + """ + moes = [] + with torch.inference_mode(mode=True): + for _ in range(layer_num): + config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0) + config.max_len = max_len + config.quant_config.bits = 8 + config.quant_config.group_size = 0 # Not used for per-channel + config.quant_config.zero_point = False + config.quant_config.per_channel = True # Enable per-channel mode + + # Set FP8 weight pointers + config.gate_proj = fp8_data["gate_fp8"].data_ptr() + config.up_proj = fp8_data["up_fp8"].data_ptr() + config.down_proj = fp8_data["down_fp8"].data_ptr() + + # Set per-channel scale pointers + config.gate_scale = fp8_data["gate_scales"].data_ptr() + config.up_scale = fp8_data["up_scales"].data_ptr() + config.down_scale = fp8_data["down_scales"].data_ptr() + config.pool = CPUInfer.backend_ + + moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config) + CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + CPUInfer.sync() + moes.append(moe) + return moes + + +def run_fp8_perchannel_moe_test(): + """ + Run FP8 Per-Channel MoE 
validation test. + """ + print("\n" + "=" * 70) + print("FP8 Per-Channel MoE Kernel Validation Test") + print("=" * 70) + + # Build FP8 per-channel weights + print("\nGenerating and quantizing weights with per-channel scales...") + fp8_data = build_random_fp8_perchannel_weights() + + # Build MoE modules + print("\nBuilding FP8 Per-Channel MoE modules...") + moes = build_moes_from_fp8_perchannel_data(fp8_data) + + # Get dequantized weights for reference + gate_deq = fp8_data["gate_deq"] + up_deq = fp8_data["up_deq"] + down_deq = fp8_data["down_deq"] + + diffs = [] + with torch.inference_mode(mode=True): + for i in range(validation_iter): + torch.manual_seed(100 + i) + bsz_tensor = torch.tensor([qlen], device="cpu") + expert_ids = torch.stack( + [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)] + ).contiguous() + weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100 + input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5 + output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous() + + moe = moes[i % layer_num] + CPUInfer.submit( + moe.forward_task( + bsz_tensor.data_ptr(), + num_experts_per_tok, + expert_ids.data_ptr(), + weights.data_ptr(), + input_tensor.data_ptr(), + output.data_ptr(), + False, + ) + ) + CPUInfer.sync() + + assert not torch.isnan(output).any(), "NaN values detected in CPU expert output." + assert not torch.isinf(output).any(), "Inf values detected in CPU expert output." 
+ + # Reference computation using dequantized weights + t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq) + + t_output_flat = t_output.flatten() + output_flat = output.flatten() + + diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12) + diffs.append(diff.item()) + print(f"Iteration {i}: relative L1 diff = {diff:.6f}") + + if i < 3: # Print detailed output for first few iterations + print(f" kernel output: {output_flat[:debug_print_count]}") + print(f" torch output: {t_output_flat[:debug_print_count]}") + + mean_diff = float(sum(diffs) / len(diffs)) + max_diff = float(max(diffs)) + min_diff = float(min(diffs)) + + print("\n" + "=" * 70) + print("FP8 Per-Channel MoE Test Results") + print("=" * 70) + print(f"Mean relative L1 diff: {mean_diff*100:.4f}%") + print(f"Max relative L1 diff: {max_diff*100:.4f}%") + print(f"Min relative L1 diff: {min_diff*100:.4f}%") + + # Pass/Fail criteria + threshold = 15.0 # 15% relative error threshold for FP8 + if mean_diff * 100 < threshold: + print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold") + else: + print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold") + + return {"mean": mean_diff, "max": max_diff, "min": min_diff} + + +if __name__ == "__main__": + run_fp8_perchannel_moe_test() diff --git a/kt-kernel/examples/test_write_buffer.py b/kt-kernel/examples/test_write_buffer.py index c0582b5..86097d9 100644 --- a/kt-kernel/examples/test_write_buffer.py +++ b/kt-kernel/examples/test_write_buffer.py @@ -2,12 +2,14 @@ Test write_weight_scale_to_buffer for AMX MOE operators. 
Supports: -- FP8: FP8 weights (1 byte) + float32 scales +- FP8: FP8 weights (1 byte) + float32 scales (block-wise) +- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales - BF16: Native BF16 weights (2 bytes), no scales Usage: python test_write_buffer.py # Run all modes python test_write_buffer.py fp8 # Run FP8 only + python test_write_buffer.py fp8_perchannel # Run FP8 per-channel only python test_write_buffer.py bf16 # Run BF16 only """ @@ -41,6 +43,17 @@ def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, int return cfg +def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size): + cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) + cfg.max_len = 1 + cfg.quant_config.bits = 8 # FP8 + cfg.quant_config.group_size = 0 # Not used for per-channel + cfg.quant_config.zero_point = False + cfg.quant_config.per_channel = True + cfg.pool = cpuinfer.backend_ + return cfg + + def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size): cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size) cfg.max_len = 1 @@ -83,6 +96,33 @@ def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size) } +def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size): + """Allocate FP8 per-channel weights and scales for testing""" + per_mat_weight_bytes = hidden_size * intermediate_size + per_mat_scale_elems_gate_up = intermediate_size # one scale per output channel + per_mat_scale_elems_down = hidden_size + + gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8) + + gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, 
dtype=torch.float32) + up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32) + down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32) + + return { + "gate_q": gate_q, + "up_q": up_q, + "down_q": down_q, + "gate_scale": gate_scale, + "up_scale": up_scale, + "down_scale": down_scale, + "per_mat_weight_bytes": per_mat_weight_bytes, + "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up, + "per_mat_scale_elems_down": per_mat_scale_elems_down, + } + + def allocate_weights_bf16(expert_num, hidden_size, intermediate_size): """Allocate BF16 weights for testing (no scales)""" # BF16 weights: 2 bytes per element @@ -312,6 +352,195 @@ def test_fp8_write_buffer(gpu_tp_count): return True +def test_fp8_perchannel_write_buffer(gpu_tp_count): + """Test write_weight_scale_to_buffer with FP8 per-channel weights""" + torch.manual_seed(123) + + expert_num = 256 + gpu_experts = expert_num + num_experts_per_tok = 8 + hidden_size = 3072 + intermediate_size = 1536 + + cpuinfer = make_cpu_infer() + cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size) + weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size) + + cfg.gate_proj = weights["gate_q"].data_ptr() + cfg.up_proj = weights["up_q"].data_ptr() + cfg.down_proj = weights["down_q"].data_ptr() + cfg.gate_scale = weights["gate_scale"].data_ptr() + cfg.up_scale = weights["up_scale"].data_ptr() + cfg.down_scale = weights["down_scale"].data_ptr() + + moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg) + + physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous() + cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr())) + cpuinfer.sync() + + per_mat_weight_bytes = weights["per_mat_weight_bytes"] + per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"] + per_mat_scale_elems_down = weights["per_mat_scale_elems_down"] + + 
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count + gpu_n_w13 = intermediate_size // gpu_tp_count + scale_elems_per_expert_per_tp_gate_up = gpu_n_w13 + scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down + + total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp + total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up + total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down + + w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w13_scale_bufs = [ + torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count) + ] + w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)] + w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)] + + print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}") + print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}") + print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}") + + def get_expert_ptrs(expert_id): + w13_weight_ptrs = [] + w13_scale_ptrs = [] + w2_weight_ptrs = [] + w2_scale_ptrs = [] + for tp_idx in range(gpu_tp_count): + w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp + w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up + w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp + w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down + + w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset) + w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4) + w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset) + 
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4) + return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs + + for _ in range(2): + for expert_id in range(gpu_experts): + w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id) + cpuinfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + expert_id=expert_id, + w13_weight_ptrs=w13_weight_ptrs, + w13_scale_ptrs=w13_scale_ptrs, + w2_weight_ptrs=w2_weight_ptrs, + w2_scale_ptrs=w2_scale_ptrs, + ) + ) + cpuinfer.sync() + + begin_time = time.perf_counter_ns() + for expert_id in range(gpu_experts): + w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id) + cpuinfer.submit( + moe.write_weight_scale_to_buffer_task( + gpu_tp_count=gpu_tp_count, + expert_id=expert_id, + w13_weight_ptrs=w13_weight_ptrs, + w13_scale_ptrs=w13_scale_ptrs, + w2_weight_ptrs=w2_weight_ptrs, + w2_scale_ptrs=w2_scale_ptrs, + ) + ) + cpuinfer.sync() + end_time = time.perf_counter_ns() + elapsed_ms = (end_time - begin_time) / 1e6 + + total_bytes = ( + hidden_size * intermediate_size * gpu_experts * 3 + + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4 + ) + print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms") + print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s") + + def split_expert_tensor(tensor, chunk): + return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)] + + gate_q = weights["gate_q"] + up_q = weights["up_q"] + down_q = weights["down_q"] + gate_scale = weights["gate_scale"] + up_scale = weights["up_scale"] + down_scale = weights["down_scale"] + + gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes) + up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes) + down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes) + gate_scale_experts = 
split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up) + up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up) + down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down) + + for tp_idx in range(gpu_tp_count): + expected_w13_weights = [] + expected_w13_scales = [] + expected_w2_weights = [] + expected_w2_scales = [] + + weight13_per_tp = per_mat_weight_bytes // gpu_tp_count + scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count + + for expert_id in range(gpu_experts): + start_weight = tp_idx * weight13_per_tp + end_weight = (tp_idx + 1) * weight13_per_tp + start_scale = tp_idx * scale13_per_tp + end_scale = (tp_idx + 1) * scale13_per_tp + + gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight] + gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale] + up_weight_tp = up_q_experts[expert_id][start_weight:end_weight] + up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale] + + down_weight_tp_parts = [] + tp_slice_weight_size = intermediate_size // gpu_tp_count + + for row_idx in range(hidden_size): + row_weight_start = row_idx * intermediate_size + tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size + down_weight_tp_parts.append( + down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size] + ) + + down_weight_tp = torch.cat(down_weight_tp_parts) + down_scale_tp = down_scale_experts[expert_id] + + expected_w13_weights.append(gate_weight_tp) + expected_w13_weights.append(up_weight_tp) + expected_w13_scales.append(gate_scale_tp) + expected_w13_scales.append(up_scale_tp) + expected_w2_weights.append(down_weight_tp) + expected_w2_scales.append(down_scale_tp) + + expected_w13_weight = torch.cat(expected_w13_weights) + expected_w13_scale = torch.cat(expected_w13_scales) + expected_w2_weight = torch.cat(expected_w2_weights) + expected_w2_scale = torch.cat(expected_w2_scales) + + if not torch.equal(w13_weight_bufs[tp_idx], 
expected_w13_weight): + diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight + first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1 + raise AssertionError(f"[FP8_PERCHANNEL] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}") + + if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale): + raise AssertionError(f"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}") + + if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight): + diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight + first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1 + raise AssertionError(f"[FP8_PERCHANNEL] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}") + + if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale): + raise AssertionError(f"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}") + + print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)") + return True + + def test_bf16_write_buffer(gpu_tp_count): """Test write_weight_scale_to_buffer with BF16 weights (no scales)""" torch.manual_seed(123) @@ -478,6 +707,8 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int): """Test write_weight_scale_to_buffer with specified mode and TP count""" if quant_mode == "fp8": return test_fp8_write_buffer(gpu_tp_count) + elif quant_mode == "fp8_perchannel": + return test_fp8_perchannel_write_buffer(gpu_tp_count) elif quant_mode == "bf16": return test_bf16_write_buffer(gpu_tp_count) else: @@ -487,7 +718,7 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int): def main(quant_modes=None): """Run tests for specified quant modes""" if quant_modes is None: - quant_modes = ["fp8", "bf16"] + quant_modes = ["fp8", "fp8_perchannel", "bf16"] tp_values = [1, 2, 4] all_passed = True @@ -525,10 +756,10 @@ def main(quant_modes=None): if __name__ == "__main__": if len(sys.argv) > 1: mode = sys.argv[1].lower() - if mode in ["fp8", "bf16"]: + if mode in ["fp8", 
"fp8_perchannel", "bf16"]: main([mode]) else: - print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'") + print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'") sys.exit(1) else: main() diff --git a/kt-kernel/ext_bindings.cpp b/kt-kernel/ext_bindings.cpp index ee1cf07..7344b2d 100644 --- a/kt-kernel/ext_bindings.cpp +++ b/kt-kernel/ext_bindings.cpp @@ -37,8 +37,9 @@ static const bool _is_plain_ = false; #if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) #include "operators/amx/awq-moe.hpp" #if defined(__AVX512BF16__) -#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern -#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support +#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern +#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support +#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8 #endif #include "operators/amx/k2-moe.hpp" #include "operators/amx/la/amx_kernels.hpp" @@ -252,8 +253,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) { .def("load_weights", &MoeClass::load_weights) .def("forward", &MoeClass::forward_binding); -#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL) - if constexpr (std::is_same_v>) { + // Bind write_weight_scale_to_buffer_task for MoE types that support it + // Uses SFINAE to detect if MoeClass has write_weight_scale_to_buffer method + if constexpr (requires { &MoeClass::write_weight_scale_to_buffer; }) { struct WriteWeightScaleToBufferBindings { struct Args { CPUInfer* cpuinfer; @@ -295,99 +297,6 @@ void bind_moe_module(py::module_& moe_module, const char* name) { py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); } - -#if defined(__AVX512BF16__) - // FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num) - // Only available on CPUs with AVX512 BF16 
support - if constexpr (std::is_same_v>) { - struct WriteWeightScaleToBufferBindings { - struct Args { - CPUInfer* cpuinfer; - MoeClass* moe; - int gpu_tp_count; - int expert_id; - std::vector w13_weight_ptrs; - std::vector w13_scale_ptrs; - std::vector w2_weight_ptrs; - std::vector w2_scale_ptrs; - }; - - static void inner(void* args) { - Args* args_ = (Args*)args; - args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count, - args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs, - args_->w2_scale_ptrs); - } - - static std::pair cpuinfer_interface(std::shared_ptr moe, int gpu_tp_count, - int expert_id, py::list w13_weight_ptrs, - py::list w13_scale_ptrs, py::list w2_weight_ptrs, - py::list w2_scale_ptrs) { - // Convert Python lists to std::vector - std::vector w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec; - - for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast(item)); - for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast(item)); - for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast(item)); - for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast(item)); - - Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id, - w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec}; - return std::make_pair((intptr_t)&inner, (intptr_t)args); - } - }; - - moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface, - py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), - py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); - } - - // BF16 MoE: processes one expert at a time (expert_id instead of gpu_experts_num) - // Only available on CPUs with AVX512 BF16 support - if constexpr (std::is_same_v>) { - struct WriteWeightScaleToBufferBindings { - struct Args { - CPUInfer* cpuinfer; - MoeClass* moe; - int gpu_tp_count; - int expert_id; - 
std::vector w13_weight_ptrs; - std::vector w13_scale_ptrs; - std::vector w2_weight_ptrs; - std::vector w2_scale_ptrs; - }; - - static void inner(void* args) { - Args* args_ = (Args*)args; - args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count, - args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs, - args_->w2_scale_ptrs); - } - - static std::pair cpuinfer_interface(std::shared_ptr moe, int gpu_tp_count, - int expert_id, py::list w13_weight_ptrs, - py::list w13_scale_ptrs, py::list w2_weight_ptrs, - py::list w2_scale_ptrs) { - // Convert Python lists to std::vector - std::vector w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec; - - for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast(item)); - for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast(item)); - for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast(item)); - for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast(item)); - - Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id, - w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec}; - return std::make_pair((intptr_t)&inner, (intptr_t)args); - } - }; - - moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface, - py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"), - py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs")); - } -#endif // __AVX512BF16__ -#endif } PYBIND11_MODULE(kt_kernel_ext, m) { @@ -585,7 +494,8 @@ PYBIND11_MODULE(kt_kernel_ext, m) { .def_readwrite("quant_method", &QuantConfig::quant_method) .def_readwrite("bits", &QuantConfig::bits) .def_readwrite("group_size", &QuantConfig::group_size) - .def_readwrite("zero_point", &QuantConfig::zero_point); + .def_readwrite("zero_point", &QuantConfig::zero_point) + .def_readwrite("per_channel", &QuantConfig::per_channel); auto moe_module = 
m.def_submodule("moe"); @@ -660,6 +570,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) { #if defined(__AVX512BF16__) bind_moe_module>(moe_module, "AMXBF16_MOE"); bind_moe_module>(moe_module, "AMXFP8_MOE"); + bind_moe_module>(moe_module, "AMXFP8PerChannel_MOE"); #endif #endif #if defined(USE_MOE_KERNEL) diff --git a/kt-kernel/operators/amx/bf16-moe.hpp b/kt-kernel/operators/amx/bf16-moe.hpp index 29d38ff..d8733de 100644 --- a/kt-kernel/operators/amx/bf16-moe.hpp +++ b/kt-kernel/operators/amx/bf16-moe.hpp @@ -213,6 +213,12 @@ class AMX_BF16_MOE_TP : public AMX_MOE_BASE> { _mm512_storeu_si512(dst, data); } + // Fast 64-byte non-temporal store (bypass cache for write-only patterns) + static inline void fast_stream_64(void* __restrict dst, const void* __restrict src) { + __m512i data = _mm512_loadu_si512(src); + _mm512_stream_si512((__m512i*)dst, data); + } + // Fast memcpy for arbitrary sizes using AVX512 static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) { uint8_t* d = (uint8_t*)dst; @@ -262,19 +268,19 @@ class AMX_BF16_MOE_TP : public AMX_MOE_BASE> { amx::transpose_16x16_32bit(temp_block1); amx::transpose_16x16_32bit(temp_block2); - // Copy transposed data to destination in n-major layout - const ggml_bf16_t* temp1_bf16 = reinterpret_cast(temp_block1); - const ggml_bf16_t* temp2_bf16 = reinterpret_cast(temp_block2); - + // Copy transposed data to destination in n-major layout using non-temporal stores // First 16 rows (block 1) for (int i = 0; i < TILE_N; i++) { - std::memcpy(dst + i * dst_row_stride, temp1_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t)); + fast_stream_64(dst + i * dst_row_stride, &temp_block1[i]); } // Next 16 rows (block 2) for (int i = 0; i < TILE_N; i++) { - std::memcpy(dst + (TILE_N + i) * dst_row_stride, temp2_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t)); + fast_stream_64(dst + (TILE_N + i) * dst_row_stride, &temp_block2[i]); } + + // Ensure all stores complete before returning + _mm_sfence(); } /** diff 
--git a/kt-kernel/operators/amx/fp8-perchannel-moe.hpp b/kt-kernel/operators/amx/fp8-perchannel-moe.hpp new file mode 100644 index 0000000..e08eb4f --- /dev/null +++ b/kt-kernel/operators/amx/fp8-perchannel-moe.hpp @@ -0,0 +1,716 @@ +/** + * @Description : FP8 Per-Channel AMX MoE operator for GLM-4.7-FP8 native inference + * @Author : Claude + * @Date : 2025-01-12 + * @Version : 1.0.0 + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. + * + * This file implements FP8 MoE with per-channel quantization using CRTP pattern. + * Per-channel quantization: each output channel (row) has one scale factor. + * This is different from block-wise quantization where each 128x128 block has one scale. + **/ +#ifndef CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H +#define CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H + +#include "la/amx_raw_buffers.hpp" +#include "la/amx_raw_kernels.hpp" +#include "moe_base.hpp" + +/** + * @brief FP8 Per-Channel MoE operator using CRTP pattern + * @tparam T Kernel type, defaults to GemmKernel224FP8PerChannel + * + * This class provides FP8 per-channel specific implementations: + * - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul with per-channel scale + * - load_weights: Load FP8 weights with per-channel scales (shape: [n]) + */ +template +class AMX_FP8_PERCHANNEL_MOE_TP : public AMX_MOE_BASE> { + using Base = AMX_MOE_BASE>; + using Base::config_; + using Base::down_ba_; + using Base::down_bb_; + using Base::down_bc_; + using Base::gate_bb_; + using Base::gate_bc_; + using Base::gate_up_ba_; + using Base::m_local_num_; + using Base::tp_part_idx; + using Base::up_bb_; + using Base::up_bc_; + + public: + using typename Base::input_t; + using typename Base::output_t; + + AMX_FP8_PERCHANNEL_MOE_TP() = default; + + AMX_FP8_PERCHANNEL_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) { + // Initialization now happens in derived_init() which is called by base constructor + } + + void derived_init() 
{ + auto& quant_config = config_.quant_config; + if (!quant_config.per_channel) { + throw std::runtime_error("KT-Kernel FP8 Per-Channel MoE requires per_channel=true"); + } + printf("Created AMX_FP8_PERCHANNEL_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu())); + } + + ~AMX_FP8_PERCHANNEL_MOE_TP() = default; + + // ============================================================================ + // CRTP buffer creation - per-channel (no group_size needed) + // ============================================================================ + + size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); } + size_t buffer_b_required_size_impl(size_t n, size_t k) const { + // Per-channel: weight size + n scales (no group_size) + return T::BufferB::required_size(n, k); + } + size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); } + + std::shared_ptr make_buffer_a_impl(size_t m, size_t k, void* data) const { + return std::make_shared(m, k, data); + } + std::shared_ptr make_buffer_b_impl(size_t n, size_t k, void* data) const { + // Per-channel BufferB doesn't need group_size + return std::make_shared(n, k, data); + } + std::shared_ptr make_buffer_c_impl(size_t m, size_t n, void* data) const { + return std::make_shared(m, n, data); + } + + // ============================================================================ + // CRTP virtual points - GEMM dispatch (per-channel) + // ============================================================================ + + void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) { + int m = m_local_num_[expert_idx]; + auto& ba = gate_up_ba_[expert_idx]; + auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx]; + auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx]; + + // Per-channel: use vec_mul_perchannel instead of vec_mul_kgroup + amx::float_mat_vec_perchannel(m, config_.intermediate_size, config_.hidden_size, ba.get(), bb.get(), bc.get(), + ith, nth); + } + + void do_down_gemm(int expert_idx, int ith, int nth, int qlen) { + int m = m_local_num_[expert_idx]; + + amx::float_mat_vec_perchannel(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx].get(), + down_bb_[expert_idx].get(), down_bc_[expert_idx].get(), ith, nth); + } + + // Fast 64-byte (512-bit) memcpy using AVX512 + static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) { + __m512i data = _mm512_loadu_si512(src); + _mm512_storeu_si512(dst, data); + } + + // Fast memcpy for arbitrary sizes using AVX512 + static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) { + uint8_t* d = (uint8_t*)dst; + const uint8_t* s = (const uint8_t*)src; + size_t chunks = bytes / 64; + for (size_t i = 0; i < chunks; i++) { + fast_memcpy_64(d, s); + d += 64; + s += 64; + } + bytes -= chunks * 64; + if (bytes > 0) { + std::memcpy(d, s, bytes); + } + } + + /** + * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format + * + * This is the inverse of the packing done in BufferBFP8PerChannelImpl::from_mat. + * Optimized with AVX512 gather for efficient non-contiguous reads. 
+ * + * @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout) + * @param dst Pointer to destination in n-major layout + * @param dst_row_stride Row stride in destination buffer (number of columns in full matrix) + */ + static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) { + // row_map[packed_i] gives the base row for packed index packed_i + static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28}; + const uint64_t* src64 = reinterpret_cast(src); + + // Gather indices: src64[8*j + packed_i] for j = 0..7 + // Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group) + const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0); + + // Process each packed group (8 groups of 4 rows each = 32 rows total) + for (int packed_i = 0; packed_i < 8; packed_i++) { + const int base_row = row_map[packed_i]; + const uint64_t* base_src = src64 + packed_i; + + // Gather 8 values for j=0..7 and j=8..15 + __m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8); + __m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8); + + // Extract 4 rows from each set of 8 values + // Row 0: bits 0-15 + __m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF))); + __m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF))); + // Row 1: bits 16-31 + __m128i row1_lo = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF))); + __m128i row1_hi = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF))); + // Row 2: bits 32-47 + __m128i row2_lo = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF))); + __m128i row2_hi = + _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF))); + // Row 3: bits 48-63 + 
__m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48)); + __m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48)); + + // Store 32 bytes (16 x uint16) to each row + // Combine two 128-bit values into 256-bit for more efficient stores + uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride; + uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride; + uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride; + uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride; + + // Combine lo and hi into 256-bit and store + __m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo); + __m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo); + __m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo); + __m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo); + + _mm256_storeu_si256((__m256i*)row0_dst, row0_256); + _mm256_storeu_si256((__m256i*)row1_dst, row1_256); + _mm256_storeu_si256((__m256i*)row2_dst, row2_256); + _mm256_storeu_si256((__m256i*)row3_dst, row3_256); + } + } + + /** + * @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization + * + * Processing 4 blocks together means each row write is 128 bytes = 2 cache lines, + * which greatly improves write efficiency compared to 32 bytes per row. 
+ * + * @param src Array of 4 source pointers (each pointing to a 32x32 packed block) + * @param dst Destination pointer in n-major layout + * @param dst_row_stride Row stride in destination buffer + */ + static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) { + static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28}; + constexpr int K_STEP = T::K_STEP; // 32 + + // Reinterpret as uint64 arrays for efficient access + const uint64_t* src0 = reinterpret_cast(src[0]); + const uint64_t* src1 = reinterpret_cast(src[1]); + const uint64_t* src2 = reinterpret_cast(src[2]); + const uint64_t* src3 = reinterpret_cast(src[3]); + + // Process all 32 rows, writing 128 bytes (4 x 32) per row + for (int packed_i = 0; packed_i < 8; packed_i++) { + const int base_row = row_map[packed_i]; + + // Process 4 rows at a time + for (int r = 0; r < 4; r++) { + uint16_t* row_dst = reinterpret_cast(dst + (size_t)(base_row + r) * dst_row_stride); + const int shift = r * 16; + + // Unroll: process all 4 blocks x 16 columns = 64 uint16 values + // Block 0: columns 0-15 + for (int j = 0; j < 16; j++) { + row_dst[j] = static_cast(src0[8 * j + packed_i] >> shift); + } + // Block 1: columns 16-31 + for (int j = 0; j < 16; j++) { + row_dst[16 + j] = static_cast(src1[8 * j + packed_i] >> shift); + } + // Block 2: columns 32-47 + for (int j = 0; j < 16; j++) { + row_dst[32 + j] = static_cast(src2[8 * j + packed_i] >> shift); + } + // Block 3: columns 48-63 + for (int j = 0; j < 16; j++) { + row_dst[48 + j] = static_cast(src3[8 * j + packed_i] >> shift); + } + } + } + } + + /** + * @brief Reconstruct weights for a single expert to the output buffers (per-channel version) + * + * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage. + * Scale handling is simplified for per-channel quantization (linear copy instead of block-wise). 
+ * + * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8) + * @param cpu_tp_count Number of CPU TP parts + * @param expert_id Expert index to process + * @param full_config Full configuration (before CPU TP split) + * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP) + * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP) + * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP) + * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP) + */ + void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id, + const GeneralMOEConfig& full_config, const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) const { + auto& config = config_; + auto pool = config.pool->get_subpool(tp_part_idx); + + constexpr int N_STEP = T::N_STEP; + constexpr int K_STEP = T::K_STEP; + constexpr int N_BLOCK = T::N_BLOCK; + constexpr int K_BLOCK = T::K_BLOCK; + + // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only ========= + const int cpu_n_w13 = config.intermediate_size; + const int cpu_k_w13 = config.hidden_size; + const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count; + const int gpu_k_w13 = full_config.hidden_size; + const int global_n_offset_w13 = tp_part_idx * cpu_n_w13; + + const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13; + // Per-channel scale: shape [n] for each matrix + const size_t gpu_w13_scale_per_mat = (size_t)gpu_n_w13; + + // ========= W2 (down): Shape [hidden, intermediate], split by K ========= + const int cpu_n_w2 = config.hidden_size; + const int cpu_k_w2 = config.intermediate_size; + const int gpu_n_w2 = full_config.hidden_size; + const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count; + const int global_k_offset_w2 = tp_part_idx * cpu_k_w2; + + const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * 
gpu_k_w2; + // Per-channel scale for down: shape [hidden_size] - not split by K + const size_t gpu_w2_scale_per_mat = (size_t)gpu_n_w2; + + // ========= Optimized job layout ========= + constexpr int NUM_W13_TASKS = 32; // Per matrix (gate or up), total 64 for w13 + constexpr int NUM_W2_TASKS = 32; // For down matrix + constexpr int SCALE_TASKS = 3; // gate_scale, up_scale, down_scale + + const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS; + + // Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing) + const int w13_n_steps = div_up(cpu_n_w13, N_STEP); + const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS); + const int w2_n_steps = div_up(cpu_n_w2, N_STEP); + const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS); + + pool->do_work_stealing_job( + total_tasks, nullptr, + [=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) { + if (task_id < NUM_W13_TASKS * 2) { + // ========= W13 weight task: process chunk of rows x full K ========= + const bool is_up = task_id >= NUM_W13_TASKS; + const int chunk_idx = task_id % NUM_W13_TASKS; + const auto& bb = is_up ? 
up_bb_[expert_id] : gate_bb_[expert_id]; + + // Calculate row range for this task (N_STEP aligned) + const int step_start = chunk_idx * w13_steps_per_task; + const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps); + if (step_start >= w13_n_steps) return; + const int chunk_n_start = step_start * N_STEP; + const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13); + + // Process each N_STEP within this chunk + for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) { + // Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries) + const int global_n = global_n_offset_w13 + local_n_start; + const int target_gpu = global_n / gpu_n_w13; + const int n_in_gpu = global_n % gpu_n_w13; + + uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu]; + // Pointer already points to current expert's location, only add offset for up matrix + const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0; + + // Calculate N_BLOCK info for source addressing + const int n_block_idx = local_n_start / N_BLOCK; + const int n_block_begin = n_block_idx * N_BLOCK; + const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin); + const int n_in_block = local_n_start - n_block_begin; + + // Process all K in groups of 4 K_STEPs when possible for cache efficiency + for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) { + const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin); + + // Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row) + int k_begin = 0; + for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) { + const uint8_t* src_ptrs[4]; + for (int i = 0; i < 4; i++) { + src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size + + (size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP; + } + uint8_t* dst = + weight_base + expert_weight_off + (size_t)n_in_gpu * 
gpu_k_w13 + k_block_begin + k_begin; + unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13); + } + + // Handle remaining K_STEPs one by one + for (; k_begin < k_block_size; k_begin += K_STEP) { + const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 + + (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size + + (size_t)k_begin * N_STEP; + uint8_t* dst = + weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin; + unpack_nk_block(src, dst, gpu_k_w13); + } + } + } + + } else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) { + // ========= W2 weight task: process chunk of rows x all K slices ========= + const int chunk_idx = task_id - NUM_W13_TASKS * 2; + const auto& bb = down_bb_[expert_id]; + + // Calculate row range for this task (N_STEP aligned) + const int step_start = chunk_idx * w2_steps_per_task; + const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps); + if (step_start >= w2_n_steps) return; + const int chunk_n_start = step_start * N_STEP; + const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2); + + // Process each N_STEP within this chunk + for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) { + // Calculate N_BLOCK info for source addressing + const int n_block_idx = local_n_start / N_BLOCK; + const int n_block_begin = n_block_idx * N_BLOCK; + const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin); + const int n_in_block = local_n_start - n_block_begin; + + // Process all K slices (each slice goes to a different GPU TP) + for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) { + const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2); + + const int global_k_start = global_k_offset_w2 + k_slice_start; + const int target_gpu = global_k_start / gpu_k_w2; + const int k_in_gpu_base = global_k_start % gpu_k_w2; + + uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu]; + // Pointer 
already points to current expert's location + const size_t expert_weight_off = 0; + + // Process K within this slice, trying 4 K_STEPs at once when aligned + for (int k_abs = k_slice_start; k_abs < k_slice_end;) { + const int k_block_idx = k_abs / K_BLOCK; + const int k_block_begin = k_block_idx * K_BLOCK; + const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin); + const int k_in_block = k_abs - k_block_begin; + const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start); + + // Check if we can process 4 K_STEPs at once + const int remaining_in_block = k_block_size - k_in_block; + const int remaining_in_slice = k_slice_end - k_abs; + + if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) { + const uint8_t* src_ptrs[4]; + for (int i = 0; i < 4; i++) { + src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size + + (size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP; + } + uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu; + unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2); + k_abs += 4 * K_STEP; + } else { + const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 + + (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size + + (size_t)k_in_block * N_STEP; + uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu; + unpack_nk_block(src, dst, gpu_k_w2); + k_abs += K_STEP; + } + } + } + } + + } else { + // ========= Scale copy task: per-channel (simple linear copy) ========= + const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS; + + if (scale_task_id < 2) { + // Gate (0) or Up (1) scale copy - per-channel: [intermediate_size] + const bool is_up = scale_task_id == 1; + const auto& bb = is_up ? 
up_bb_[expert_id] : gate_bb_[expert_id]; + + // W13 per-channel scales: copy N range corresponding to this CPU TP + // Each GPU TP gets [gpu_n_w13] scales + const int n_start_global = global_n_offset_w13; + + for (int local_n = 0; local_n < cpu_n_w13;) { + const int global_n = n_start_global + local_n; + const int target_gpu = global_n / gpu_n_w13; + const int n_in_gpu = global_n % gpu_n_w13; + + // Calculate how many scales to copy to this GPU TP + const int remaining_in_gpu = gpu_n_w13 - n_in_gpu; + const int remaining_local = cpu_n_w13 - local_n; + const int copy_count = std::min(remaining_in_gpu, remaining_local); + + float* scale_dst = (float*)w13_scale_ptrs[target_gpu]; + // Pointer already points to current expert's location, only add offset for up matrix + const size_t expert_scale_off = is_up ? gpu_w13_scale_per_mat : 0; + + fast_memcpy(scale_dst + expert_scale_off + n_in_gpu, bb->d + local_n, copy_count * sizeof(float)); + + local_n += copy_count; + } + } else { + // Down scale copy (scale_task_id == 2) - per-channel: [hidden_size] + const auto& bb = down_bb_[expert_id]; + + // W2 per-channel scales: shape [hidden_size], not split by K + // All GPU TPs get the same scales (full hidden_size) + // However, since K is split, we need to write to each GPU TP + for (int gpu_idx = 0; gpu_idx < gpu_tp_count; gpu_idx++) { + // Check if this CPU TP contributes to this GPU TP's K range + const int gpu_k_start = gpu_idx * gpu_k_w2; + const int gpu_k_end = gpu_k_start + gpu_k_w2; + const int cpu_k_start = global_k_offset_w2; + const int cpu_k_end = cpu_k_start + cpu_k_w2; + + // Check for overlap + if (cpu_k_start < gpu_k_end && cpu_k_end > gpu_k_start) { + // This CPU TP contributes to this GPU TP + // Only the first CPU TP for this GPU should write scales + if (cpu_k_start == gpu_k_start || cpu_k_start % gpu_k_w2 == 0) { + float* scale_dst = (float*)w2_scale_ptrs[gpu_idx]; + // Pointer already points to current expert's location + fast_memcpy(scale_dst, bb->d, 
cpu_n_w2 * sizeof(float)); + } + } + } + } + } + }, + nullptr); + } + + /** + * @brief Load FP8 weights from contiguous memory layout with per-channel scales + * + * Loads weights from config_.gate_proj, up_proj, down_proj with scales + * from config_.gate_scale, up_scale, down_scale. + * + * Per-channel scale shape: [n] (one scale per output channel) + */ + void load_weights() { + const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map; + auto pool = config_.pool->get_subpool(tp_part_idx); + + if (config_.gate_scale == nullptr) { + throw std::runtime_error("FP8 Per-Channel MoE requires scale pointers."); + } + + // load gate and up weights + int nth = T::recommended_nth(config_.intermediate_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + + // Per-channel scale: shape [intermediate_size] for gate/up + const size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size; + const size_t scale_offset = logical_expert_id * config_.intermediate_size; + + // gate part + gate_bb_[expert_idx]->from_mat((uint8_t*)config_.gate_proj + weight_offset, + (float*)config_.gate_scale + scale_offset, ith, nth); + // up part + up_bb_[expert_idx]->from_mat((uint8_t*)config_.up_proj + weight_offset, + (float*)config_.up_scale + scale_offset, ith, nth); + }, + nullptr); + + // load down weights + nth = T::recommended_nth(config_.hidden_size); + pool->do_work_stealing_job( + nth * config_.expert_num, nullptr, + [this, nth, physical_to_logical_map](int task_id) { + uint64_t expert_idx = task_id / nth; + uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx); + int ith = task_id % nth; + + // Per-channel scale: shape [hidden_size] for down + const size_t weight_offset = 
logical_expert_id * config_.intermediate_size * config_.hidden_size; + const size_t scale_offset = logical_expert_id * config_.hidden_size; + + // down part + down_bb_[expert_idx]->from_mat((uint8_t*)config_.down_proj + weight_offset, + (float*)config_.down_scale + scale_offset, ith, nth); + }, + nullptr); + } +}; + +/** + * @brief TP_MOE specialization for FP8 Per-Channel MoE + */ +template +class TP_MOE> : public TP_MOE>> { + public: + using Base = TP_MOE>>; + using Base::Base; + + /** + * @brief Write weights and scales to GPU buffer for a single expert + * + * This method coordinates all CPU TP parts to write their portions + * of weights and scales to the GPU buffers. + * + * @param gpu_tp_count Number of GPU TP parts + * @param expert_id Expert index to write + * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP) + * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP) + * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP) + * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP) + */ + void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector& w13_weight_ptrs, + const std::vector& w13_scale_ptrs, + const std::vector& w2_weight_ptrs, + const std::vector& w2_scale_ptrs) { + if (this->weights_loaded == false) { + throw std::runtime_error("Not Loaded"); + } + if (this->tps.empty()) { + throw std::runtime_error("No TP parts initialized"); + } + if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count || + (int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) { + throw std::runtime_error("Pointer arrays size must match gpu_tp_count"); + } + + this->config.pool->dispense_backend()->do_numa_job([&, this](int i) { + this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs, + w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs); + }); + } + + void 
load_weights() override { + auto& config = this->config; + auto& tps = this->tps; + auto& tp_count = this->tp_count; + auto pool = config.pool; + const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map; + + if (!config.quant_config.per_channel) { + throw std::runtime_error("FP8 Per-Channel MoE requires per_channel=true"); + } + + if (config.gate_projs.empty() && config.gate_proj == nullptr) { + throw std::runtime_error("no weight source"); + } + const bool use_per_expert_ptrs = !config.gate_projs.empty(); + + const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size; + // Per-channel: scale count = output dimension + const size_t gate_up_scale_elems = (size_t)config.intermediate_size; + const size_t down_scale_elems = (size_t)config.hidden_size; + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size; + // Per-channel scales for TP part + const size_t tp_gate_up_scale_elems = (size_t)tpc.intermediate_size; + const size_t tp_down_scale_elems = (size_t)tpc.hidden_size; + + tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems]; + + tpc.gate_scale = new float[tpc.expert_num * tp_gate_up_scale_elems]; + tpc.up_scale = new float[tpc.expert_num * tp_gate_up_scale_elems]; + tpc.down_scale = new float[tpc.expert_num * tp_down_scale_elems]; + + const size_t tp_idx = (size_t)i; + // gate/up: split by N (intermediate_size) + const size_t gate_up_weight_src_offset = i * tp_weight_elems; + const size_t gate_up_scale_src_offset = i * tp_gate_up_scale_elems; + + // down: split by K (intermediate_size) + const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size; + + pool->get_subpool(i)->do_work_stealing_job( + tpc.expert_num, nullptr, + [&, &tpc](int 
expert_id_) { + const size_t expert_id = expert_map(physical_to_logical_map, expert_id_); + + uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems; + uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems; + uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems; + + float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_gate_up_scale_elems; + float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_gate_up_scale_elems; + float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_down_scale_elems; + + const uint8_t* gate_src; + const uint8_t* up_src; + const uint8_t* down_src; + const float* gate_scale_src; + const float* up_scale_src; + const float* down_scale_src; + + if (use_per_expert_ptrs) { + gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset; + up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset; + down_src = (const uint8_t*)config.down_projs[0][expert_id]; + + gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset; + up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset; + down_scale_src = (const float*)config.down_scales[0][expert_id]; + } else { + gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset; + up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset; + down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems; + + gate_scale_src = + (const float*)config.gate_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset; + up_scale_src = (const float*)config.up_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset; + down_scale_src = (const float*)config.down_scale + expert_id * down_scale_elems; + } + + // Copy gate/up weights and scales (N dimension split) + std::memcpy(gate_dst, gate_src, tp_weight_elems); + 
std::memcpy(up_dst, up_src, tp_weight_elems); + std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_gate_up_scale_elems); + std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_gate_up_scale_elems); + + // Copy down weights (K dimension split) - row by row + for (int row = 0; row < config.hidden_size; row++) { + const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset; + const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size; + std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size); + } + + // Copy down scales (N dimension = hidden_size, full copy for each TP) + std::memcpy(down_scale_dst, down_scale_src, sizeof(float) * tp_down_scale_elems); + }, + nullptr); + }); + + DO_TPS_LOAD_WEIGHTS(pool); + + pool->dispense_backend()->do_numa_job([&, this](int i) { + auto& tpc = tps[i]->config_; + delete[] (uint8_t*)tpc.gate_proj; + delete[] (uint8_t*)tpc.up_proj; + delete[] (uint8_t*)tpc.down_proj; + delete[] (float*)tpc.gate_scale; + delete[] (float*)tpc.up_scale; + delete[] (float*)tpc.down_scale; + }); + + this->weights_loaded = true; + } +}; + +#endif // CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H diff --git a/kt-kernel/operators/amx/la/amx_raw_buffers.hpp b/kt-kernel/operators/amx/la/amx_raw_buffers.hpp index b485966..86fee93 100644 --- a/kt-kernel/operators/amx/la/amx_raw_buffers.hpp +++ b/kt-kernel/operators/amx/la/amx_raw_buffers.hpp @@ -483,6 +483,119 @@ struct BufferCFP32ReduceImpl { } }; +// ============================================================================ +// BufferBFP8PerChannelImpl: FP8 权重缓冲区(Per Channel 量化) +// ============================================================================ + +/** + * @brief FP8 Per-Channel 权重缓冲区 + * + * 存储 FP8 格式的权重矩阵,每个输出通道(行)有一个缩放因子。 + * 这与 GLM-4.7-FP8 的 per-channel 量化格式匹配。 + * + * 与 BufferBFP8Impl (block-wise) 的区别: + * - Block-wise: scale shape = [n/128, k/128], 每 128x128 块一个 scale + * - 
Per-channel: scale shape = [n], one scale per row
+ *
+ * @tparam K Kernel type
+ */
+template <typename K>
+struct BufferBFP8PerChannelImpl {
+  uint8_t* b;  // FP8 weight [n, k]
+  float* d;    // per-channel scale [n]
+  int n, k;
+
+  static constexpr int N_STEP = K::N_STEP;
+  static constexpr int K_STEP = K::K_STEP;
+  static constexpr int N_BLOCK = K::N_BLOCK;
+  static constexpr int K_BLOCK = K::K_BLOCK;
+  static constexpr bool SCALE = true;
+  static constexpr bool PER_CHANNEL = true;
+
+  /**
+   * @brief Compute the required buffer size
+   * weight: n * k bytes (FP8)
+   * scale: n * sizeof(float) bytes
+   */
+  static size_t required_size(int n, int k) { return sizeof(uint8_t) * n * k + sizeof(float) * n; }
+
+  /**
+   * @brief Constructor
+   */
+  BufferBFP8PerChannelImpl(int n, int k, void* ptr) : n(n), k(k) { set_data(ptr); }
+
+  void set_data(void* ptr) {
+    assert(reinterpret_cast<uintptr_t>(ptr) % 64 == 0);
+    b = reinterpret_cast<uint8_t*>(ptr);
+    d = reinterpret_cast<float*>(b + (size_t)n * k);
+  }
+
+  static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7};  // fp8 matrix offset for reordering
+
+  /**
+   * @brief Load from raw FP8 weights (per-channel quantization format)
+   *
+   * @param b_src FP8 weight source data (n-major, n×k)
+   * @param d_src FP32 per-channel scale source data (shape: [n] or [n, 1])
+   */
+  void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {
+    assert(b != nullptr && d != nullptr);
+    assert(N_STEP == 32 && K_STEP == 32);
+
+    // Copy per-channel scales. Each thread copies its own n-block range.
+    if (d_src != nullptr) {
+      auto [n_start, n_end] = K::split_range_n(n, ith, nth);
+      memcpy(d + n_start, d_src + n_start, sizeof(float) * (n_end - n_start));
+    }
+
+    // Reorder FP8 weights into KT block-major layout (same as BufferBFP8Impl)
+    auto [n_start, n_end] = K::split_range_n(n, ith, nth);
+    int n_block_begin = n_start;
+    int n_block_size = n_end - n_block_begin;
+    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
+      int n_step_size = std::min(N_STEP, n_block_size - n_begin);
+      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
+        int k_block_size = std::min(K_BLOCK, k - k_block_begin);
+        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
+          int k_step_size = std::min(K_STEP, k_block_size - k_begin);
+          // [k_step_size, n_step_size] block copy
+          const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;
+          uint64_t* block_b_dst =
+              reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
+                                          (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);
+          for (int i = 0; i < 8; i++) {
+            const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);
+            for (int j = 0; j < 16; j++) {
+              uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |
+                             (((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);
+              block_b_dst[8 * j + mat_offset[i]] = val;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * @brief Get the per-channel scale pointer starting at row n_begin
+   */
+  float* get_scale(int n_begin) { return d + n_begin; }
+
+  /**
+   * @brief Get a pointer to a sub-matrix block
+   */
+  uint8_t* get_submat(int n, int k, int n_begin, int k_begin) {
+    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
+    n_begin -= n_block_begin;
+    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
+    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
+    k_begin -= k_block_begin;
+    int k_block_size = std::min(K_BLOCK, k - k_block_begin);
+    return b + (size_t)n_block_begin * k + 
(size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size + + (size_t)k_begin * N_STEP; + } +}; + } // namespace amx #endif // AMX_RAW_BUFFERS_HPP diff --git a/kt-kernel/operators/amx/la/amx_raw_kernels.hpp b/kt-kernel/operators/amx/la/amx_raw_kernels.hpp index df2faae..4170286 100644 --- a/kt-kernel/operators/amx/la/amx_raw_kernels.hpp +++ b/kt-kernel/operators/amx/la/amx_raw_kernels.hpp @@ -285,7 +285,7 @@ struct GemmKernel224FP8 { static void config() {} - private: + // FP8->BF16 conversion lookup tables (public for reuse by GemmKernel224FP8PerChannel) alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = { 0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, @@ -317,8 +317,6 @@ struct GemmKernel224FP8 { static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); } static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); } static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); } - - public: using BufferA = BufferABF16Impl; using BufferB = BufferBFP8Impl; using BufferC = BufferCFP32ReduceImpl; @@ -618,6 +616,231 @@ inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_pt float_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth); } +// ============================================================================ +// Per-Channel FP8 GEMM (for GLM-4.7-FP8 style quantization) +// ============================================================================ + +/** + * @brief FP8 Per-Channel Kernel + * + * Similar to GemmKernel224FP8 but with per-channel scaling instead of block-wise scaling. 
+ * - Block-wise: scale shape = [n/128, k/128], one scale per 128x128 block
+ * - Per-channel: scale shape = [n], one scale per output row
+ */
+struct GemmKernel224FP8PerChannel {
+  using fp8_t = uint8_t;
+  using output_t = float;
+
+  static constexpr double ELEMENT_SIZE = 1.0;
+  static const int TILE_M = 16;
+  static const int TILE_K = 32;
+  static const int TILE_N = 16;
+  static const int VNNI_BLK = 2;
+
+  static const int M_STEP = TILE_M * 2;
+  static const int N_STEP = TILE_N * 2;
+  static const int K_STEP = TILE_K;
+
+  // Use smaller N_BLOCK for per-channel to allow efficient scale application
+  static inline const int N_BLOCK = 128;
+  static inline const int K_BLOCK = 7168;
+
+  static std::string name() { return "FP8PerChannel"; }
+
+  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
+
+  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
+    int n_start = N_BLOCK * ith;
+    int n_end = std::min(n, N_BLOCK * (ith + 1));
+    return {n_start, n_end};
+  }
+
+  static void config() {}
+
+  using BufferA = BufferABF16Impl;
+  using BufferB = BufferBFP8PerChannelImpl<GemmKernel224FP8PerChannel>;
+  using BufferC = BufferCFP32Impl;
+
+  // Reuse FP8->BF16 conversion from GemmKernel224FP8
+  static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {
+    return GemmKernel224FP8::fp8x64_to_bf16x64(bfp8_512);
+  }
+
+  /**
+   * @brief Apply per-channel scale to result
+   *
+   * Unlike block-wise scaling, per-channel scaling applies a different scale to each column
+   * of the result (each output channel). 
+ * + * @param m Total rows + * @param n Total columns + * @param m_begin Starting row + * @param n_begin Starting column + * @param c Output buffer (M_STEP x N_STEP) + * @param bb BufferB containing per-channel scales + */ + static void apply_scale_perchannel(int m, [[maybe_unused]] int n, int m_begin, int n_begin, float* c, BufferB* bb) { + int to = std::min(m - m_begin, M_STEP); + + // Load N_STEP per-channel scales (32 floats) + __m512 bs_lo = _mm512_loadu_ps(bb->get_scale(n_begin)); // scale[n_begin..n_begin+15] + __m512 bs_hi = _mm512_loadu_ps(bb->get_scale(n_begin + TILE_N)); // scale[n_begin+16..n_begin+31] + + for (int i = 0; i < to; i++) { + // Each row gets multiplied by the same set of per-channel scales + __m512 c_lo = _mm512_load_ps(c + i * N_STEP); + __m512 c_hi = _mm512_load_ps(c + i * N_STEP + TILE_N); + _mm512_store_ps(c + i * N_STEP, _mm512_mul_ps(c_lo, bs_lo)); + _mm512_store_ps(c + i * N_STEP + TILE_N, _mm512_mul_ps(c_hi, bs_hi)); + } + } + + // AVX kernel for per-channel FP8 GEMM - processes entire K dimension + static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba, + BufferB* bb) { + const __m512i bf16_hi_0 = GemmKernel224FP8::bf16_hi_0_mask(); + const __m512i bf16_hi_1 = GemmKernel224FP8::bf16_hi_1_mask(); + const __m512i bf16_lo_0 = GemmKernel224FP8::bf16_lo_0_mask(); + const __m512i bf16_lo_1 = GemmKernel224FP8::bf16_lo_1_mask(); + const __m512i sign_mask_v = GemmKernel224FP8::sign_mask(); + + __m512* c512 = (__m512*)c; + int m_block_end = std::min(m - m_begin, M_STEP); + + // Zero out accumulator at start of K_BLOCK + if (k_block_begin == 0) { + for (int m_i = 0; m_i < m_block_end; m_i++) { + c512[m_i * 2] = _mm512_setzero_ps(); + c512[m_i * 2 + 1] = _mm512_setzero_ps(); + } + } + + // Process K_BLOCK + for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) { + ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_block_begin + 
k_begin); + __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin); + + // Process 4 k_i at once + for (int k_i = 0; k_i < 16; k_i += 4) { + // Load 4 B vectors + __m512i bfp8_0 = bfp8_512[k_i]; + __m512i bfp8_1 = bfp8_512[k_i + 1]; + __m512i bfp8_2 = bfp8_512[k_i + 2]; + __m512i bfp8_3 = bfp8_512[k_i + 3]; + + // Convert all 4 FP8 -> BF16 + __m512i b_hi, b_lo; + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1); + __m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1); + __m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1); + __m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3), + _mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1)); + b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1); + __m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi); + __m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi); + + // Process m rows + int m_i = 0; + for (; m_i + 1 < m_block_end; m_i += 2) { + __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]); + __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]); + __m512bh ma2_0 = 
(__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]); + __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]); + __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]); + __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]); + __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]); + __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]); + + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi); + + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi); + c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo); + c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi); + } + // Handle remaining row + for (; m_i < m_block_end; m_i++) { + __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i 
* K_STEP + k_i * 2]); + __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]); + __m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]); + __m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]); + + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi); + c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo); + c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi); + } + } + } + } +}; + +/** + * @brief Per-channel FP8 GEMM function + * + * Unlike block-wise FP8 which applies scale per 128x128 block during computation, + * per-channel FP8 processes entire K dimension first, then applies per-channel scale at the end. 
+ */ +template +void float_mat_vec_perchannel(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb, + typename K::BufferC* bc, int ith, int nth) { + assert(n % K::N_STEP == 0); + assert(k % K::K_STEP == 0); + + auto [n_start, n_end] = K::split_range_n(n, ith, nth); + + for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) { + for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) { + float* c = bc->get_submat(m, n, m_begin, n_begin); + + // Process entire K dimension with K_BLOCKs + for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) { + K::avx_kernel_4(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb); + } + + // Apply per-channel scale once after all K is processed + K::apply_scale_perchannel(m, n, m_begin, n_begin, c, bb); + } + } +} + +inline void vec_mul_perchannel(int m, int n, int k, std::shared_ptr ba, + std::shared_ptr bb, + std::shared_ptr bc, int ith, int nth) { + float_mat_vec_perchannel(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth); +} + } // namespace amx #endif // AMX_RAW_KERNELS_HPP diff --git a/kt-kernel/operators/common.hpp b/kt-kernel/operators/common.hpp index 3fa39a1..cfb8488 100644 --- a/kt-kernel/operators/common.hpp +++ b/kt-kernel/operators/common.hpp @@ -223,7 +223,8 @@ struct QuantConfig { std::string quant_method = ""; int bits = 0; int group_size = 0; - bool zero_point; + bool zero_point = false; + bool per_channel = false; // Per-channel quantization (GLM-4.7-FP8 style) }; struct GeneralMOEConfig { diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py index 46a6282..73e03df 100644 --- a/kt-kernel/python/utils/amx.py +++ b/kt-kernel/python/utils/amx.py @@ -15,6 +15,14 @@ except (ImportError, AttributeError): _HAS_AMX_SUPPORT = False AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None +try: + from kt_kernel_ext.moe import AMXFP8PerChannel_MOE + + _HAS_FP8_PERCHANNEL_SUPPORT = True +except (ImportError, 
AttributeError): + _HAS_FP8_PERCHANNEL_SUPPORT = False + AMXFP8PerChannel_MOE = None + from typing import Optional @@ -304,7 +312,7 @@ class AMXMoEWrapper(BaseMoEWrapper): class NativeMoEWrapper(BaseMoEWrapper): - """Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format.""" + """Wrapper for RAWINT4/FP8/FP8_PERCHANNEL/BF16 experts stored in compressed SafeTensor format.""" _native_loader_instance = None @@ -330,6 +338,8 @@ class NativeMoEWrapper(BaseMoEWrapper): raise RuntimeError("AMX backend with RAWINT4 support is not available.") if method == "FP8" and AMXFP8_MOE is None: raise RuntimeError("AMX backend with FP8 support is not available.") + if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT: + raise RuntimeError("AMX backend with FP8 per-channel support is not available.") if method == "BF16" and AMXBF16_MOE is None: raise RuntimeError("AMX backend with BF16 support is not available.") @@ -354,6 +364,9 @@ class NativeMoEWrapper(BaseMoEWrapper): NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path) elif method == "FP8": NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path) + elif method == "FP8_PERCHANNEL": + # Use FP8SafeTensorLoader with per-channel scale format + NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale") elif method == "BF16": NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path) else: @@ -408,6 +421,8 @@ class NativeMoEWrapper(BaseMoEWrapper): assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4" elif self.method == "FP8": assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8" + elif self.method == "FP8_PERCHANNEL": + assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL" t2 = time.time() @@ -462,6 +477,11 @@ class NativeMoEWrapper(BaseMoEWrapper): moe_config.quant_config.group_size = 
128 moe_config.quant_config.zero_point = False self.moe = AMXFP8_MOE(moe_config) + elif self.method == "FP8_PERCHANNEL": + moe_config.quant_config.bits = 8 + moe_config.quant_config.per_channel = True + moe_config.quant_config.zero_point = False + self.moe = AMXFP8PerChannel_MOE(moe_config) elif self.method == "BF16": # BF16 has no quantization config needed self.moe = AMXBF16_MOE(moe_config) diff --git a/kt-kernel/python/utils/loader.py b/kt-kernel/python/utils/loader.py index 360780c..a9e6875 100644 --- a/kt-kernel/python/utils/loader.py +++ b/kt-kernel/python/utils/loader.py @@ -244,6 +244,10 @@ class FP8SafeTensorLoader(SafeTensorLoader): - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight + Supported scale formats (auto-detected): + - Block-wise: weight_scale_inv (DeepSeek FP8) + - Per-channel: weight_scale (GLM-4.7-FP8) + The format is auto-detected during initialization. """ @@ -253,13 +257,28 @@ class FP8SafeTensorLoader(SafeTensorLoader): "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), } - def __init__(self, file_path: str): + def __init__(self, file_path: str, scale_suffix: str = None): + """Initialize FP8 loader with optional scale suffix override. + + Args: + file_path: Path to safetensor files + scale_suffix: Optional scale key suffix. If None, auto-detect between + 'weight_scale_inv' (block-wise) and 'weight_scale' (per-channel). 
+ """ super().__init__(file_path) self._detected_format = None + self._scale_suffix = scale_suffix # None means auto-detect + # Set per_channel based on explicit scale_suffix if provided + if scale_suffix == "weight_scale": + self._is_per_channel = True + elif scale_suffix == "weight_scale_inv": + self._is_per_channel = False + else: + self._is_per_channel = False # Will be updated in _detect_format if auto-detect self._detect_format() def _detect_format(self): - """Auto-detect the MoE naming format by checking tensor keys.""" + """Auto-detect the MoE naming format and scale format by checking tensor keys.""" # Sample some tensor names to detect format sample_keys = list(self.tensor_file_map.keys())[:1000] @@ -272,15 +291,42 @@ class FP8SafeTensorLoader(SafeTensorLoader): if "block_sparse_moe.experts" in key and fmt_name == "mixtral": self._detected_format = fmt_name print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") - return + break elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek": self._detected_format = fmt_name print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}") - return + break + if self._detected_format: + break # Default to deepseek if no format detected - self._detected_format = "deepseek" - print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek") + if not self._detected_format: + self._detected_format = "deepseek" + print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek") + + # Auto-detect scale suffix if not specified + if self._scale_suffix is None: + _, gate, _, _ = self.MOE_FORMATS[self._detected_format] + # Check for per-channel scale (weight_scale) vs block-wise (weight_scale_inv) + for key in sample_keys: + if f".{gate}.weight_scale_inv" in key: + self._scale_suffix = "weight_scale_inv" + self._is_per_channel = False + print("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)") + return + elif f".{gate}.weight_scale" in key and 
"weight_scale_inv" not in key: + self._scale_suffix = "weight_scale" + self._is_per_channel = True + print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)") + return + # Default to weight_scale_inv + self._scale_suffix = "weight_scale_inv" + self._is_per_channel = False + print("[FP8SafeTensorLoader] No scale format detected, defaulting to: weight_scale_inv") + else: + # Scale suffix was explicitly provided + scale_type = "per-channel" if self._is_per_channel else "block-wise" + print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})") def _get_experts_prefix(self, base_key: str) -> str: """Get the experts prefix based on detected format.""" @@ -305,7 +351,11 @@ class FP8SafeTensorLoader(SafeTensorLoader): return tensor.to(device) def load_experts(self, base_key: str, device: str = "cpu"): - """Load FP8 expert weights and their block-wise scale_inv tensors.""" + """Load FP8 expert weights and their scale tensors. + + Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats. + Per-channel scales are squeezed from [N, 1] to [N] if needed. 
+ """ experts_prefix = self._get_experts_prefix(base_key) gate_name, up_name, down_name = self._get_proj_names() @@ -327,16 +377,30 @@ class FP8SafeTensorLoader(SafeTensorLoader): gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight" up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight" down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight" - gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv" - up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv" - down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv" + gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.{self._scale_suffix}" + up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.{self._scale_suffix}" + down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.{self._scale_suffix}" gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous() up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous() down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous() - gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous() - up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous() - down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous() + + gate_scale = self.load_tensor(gate_s_key, device) + up_scale = self.load_tensor(up_s_key, device) + down_scale = self.load_tensor(down_s_key, device) + + # For per-channel scales, squeeze [N, 1] -> [N] if needed + if self._is_per_channel: + if gate_scale.dim() == 2 and gate_scale.shape[1] == 1: + gate_scale = gate_scale.squeeze(1) + if up_scale.dim() == 2 and up_scale.shape[1] == 1: + up_scale = up_scale.squeeze(1) + if down_scale.dim() == 2 and down_scale.shape[1] == 1: + down_scale = down_scale.squeeze(1) + + gate_scales[exp_id] = gate_scale.contiguous() + up_scales[exp_id] = up_scale.contiguous() + down_scales[exp_id] = down_scale.contiguous() return { "gate": gate_weights, @@ -347,6 +411,103 @@ class FP8SafeTensorLoader(SafeTensorLoader): 
"down_scale": down_scales, } + def is_per_channel(self) -> bool: + """Return True if using per-channel quantization, False for block-wise.""" + return self._is_per_channel + + +class BF16SafeTensorLoader(SafeTensorLoader): + """Loader for native BF16 expert weights (no quantization, no scales). + + Supported formats: + - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight + - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight + + The format is auto-detected during initialization. + """ + + MOE_FORMATS = { + "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"), + "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"), + } + + def __init__(self, file_path: str): + super().__init__(file_path) + self._detected_format = None + self._detect_format() + + def _detect_format(self): + """Auto-detect the MoE naming format by checking tensor keys.""" + sample_keys = list(self.tensor_file_map.keys())[:1000] + + for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items(): + for key in sample_keys: + if ".experts." 
in key and f".{gate}.weight" in key: + if "block_sparse_moe.experts" in key and fmt_name == "mixtral": + self._detected_format = fmt_name + print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") + return + elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek": + self._detected_format = fmt_name + print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}") + return + + self._detected_format = "deepseek" + print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek") + + def _get_experts_prefix(self, base_key: str) -> str: + """Get the experts prefix based on detected format.""" + path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format] + return path_tpl.format(base=base_key) + + def _get_proj_names(self): + """Get projection names (gate, up, down) based on detected format.""" + _, gate, up, down = self.MOE_FORMATS[self._detected_format] + return gate, up, down + + def load_tensor(self, key: str, device: str = "cpu"): + if key not in self.tensor_file_map: + raise KeyError(f"Key {key} not found in Safetensor files") + file = self.tensor_file_map[key] + f = self.file_handle_map.get(file) + if f is None: + raise FileNotFoundError(f"File {file} not found in Safetensor files") + tensor = f.get_tensor(key) + if device == "cpu": + return tensor + return tensor.to(device) + + def load_experts(self, base_key: str, device: str = "cpu"): + """Load BF16 expert weights (no scales needed).""" + experts_prefix = self._get_experts_prefix(base_key) + gate_name, up_name, down_name = self._get_proj_names() + + expert_count = 0 + while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"): + expert_count += 1 + + if expert_count == 0: + raise ValueError(f"No experts found for key {experts_prefix}") + + gate_weights = [None] * expert_count + up_weights = [None] * expert_count + down_weights = [None] * expert_count + + for exp_id in range(expert_count): + gate_w_key = 
f"{experts_prefix}.{exp_id}.{gate_name}.weight" + up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight" + down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight" + + gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous() + up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous() + down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous() + + return { + "gate": gate_weights, + "up": up_weights, + "down": down_weights, + } + class BF16SafeTensorLoader(SafeTensorLoader): """Loader for native BF16 expert weights (no quantization, no scales).