support GLM 4.7 (#1791)

support GLM 4.7
This commit is contained in:
Oql
2026-01-13 17:36:25 +08:00
committed by GitHub
parent 667030d6e6
commit 6277da4c2b
14 changed files with 2336 additions and 144 deletions

View File

@@ -17,7 +17,7 @@ import platform
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
import kt_kernel_ext
from kt_kernel import kt_kernel_ext
from tqdm import tqdm
# Test parameters
@@ -29,9 +29,9 @@ fp8_group_size = 128
max_len = 25600
layer_num = 2
qlen = 1024
warm_up_iter = 10
test_iter = 30
qlen = 1
warm_up_iter = 1000
test_iter = 3000
CPUINFER_PARAM = 80
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)

View File

@@ -0,0 +1,277 @@
"""
Performance benchmark for FP8 Per-Channel MoE kernel (GLM-4.7-FP8 style).
This benchmark measures the performance of the FP8 Per-Channel MoE operator with:
- FP8 (E4M3) weights with per-channel scaling (one scale per output row)
- BF16 activations
- AVX-512 DPBF16 compute path
"""
import os
import sys
import time
import json
import subprocess
import platform
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
from kt_kernel import kt_kernel_ext
from tqdm import tqdm
# Test parameters
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
num_experts_per_tok = 8
max_len = 25600
layer_num = 2
qlen = 1
warm_up_iter = 1000
test_iter = 3000
CPUINFER_PARAM = 80
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
# Result file path
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")
def get_git_commit():
"""Get current git commit info"""
result = {}
try:
commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
result["commit"] = commit
result["commit_message"] = commit_msg
dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
result["dirty"] = bool(dirty_output)
if dirty_output:
result["dirty_files"] = dirty_output.splitlines()
except Exception as e:
result["commit"] = None
result["error"] = str(e)
return result
def get_system_info():
"""Get system information"""
info = {}
uname = platform.uname()
info["system_name"] = uname.system
info["node_name"] = uname.node
cpu_model = None
if os.path.exists("/proc/cpuinfo"):
try:
with open("/proc/cpuinfo", "r") as f:
for line in f:
if "model name" in line:
cpu_model = line.split(":", 1)[1].strip()
break
except Exception:
pass
info["cpu_model"] = cpu_model
info["cpu_core_count"] = os.cpu_count()
return info
def record_results(result, filename=json_path):
"""Append result to JSON file"""
with open(filename, "a") as f:
f.write(json.dumps(result) + "\n")
def generate_fp8_perchannel_weights_direct(shape: tuple):
"""
Directly generate random FP8 weights and per-channel scales.
Args:
shape: (expert_num, n, k) - weight tensor shape
Returns:
fp8_weights: uint8 tensor with random FP8 E4M3 values
scales: fp32 tensor with per-channel scales, shape [expert_num, n]
"""
e, n, k = shape
# Directly generate random FP8 weights as uint8
# FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa
fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()
# Generate random per-channel scales (one per output row)
# Use reasonable scale range (e.g., 2^-8 to 2^8)
exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device="cuda").to("cpu").contiguous()
scales = (2.0 ** exponents.float()).to(torch.float32).contiguous()
return fp8_weights, scales
def bench_fp8_perchannel_moe():
"""Benchmark FP8 Per-Channel MoE performance"""
with torch.inference_mode():
print("=" * 70)
print("FP8 Per-Channel MoE Kernel Performance Benchmark")
print("=" * 70)
# Generate FP8 weights with per-channel scales
print("\nGenerating FP8 weights with per-channel scales...")
torch.manual_seed(42)
gate_fp8, gate_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
up_fp8, up_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))
down_fp8, down_scales = generate_fp8_perchannel_weights_direct((expert_num, hidden_size, intermediate_size))
physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()
# Build MoE layers
print("Building FP8 Per-Channel MoE layers...")
moes = []
for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.quant_config.bits = 8
config.quant_config.group_size = 0 # Not used for per-channel
config.quant_config.zero_point = False
config.quant_config.per_channel = True # Enable per-channel mode
config.gate_proj = gate_fp8.data_ptr()
config.up_proj = up_fp8.data_ptr()
config.down_proj = down_fp8.data_ptr()
config.gate_scale = gate_scales.data_ptr()
config.up_scale = up_scales.data_ptr()
config.down_scale = down_scales.data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
moes.append(moe)
# Generate input data
print("Generating input data...")
gen_iter = 1000
expert_ids = (
torch.rand(gen_iter * qlen, expert_num, device="cpu")
.argsort(dim=-1)[:, :num_experts_per_tok]
.reshape(gen_iter, qlen * num_experts_per_tok)
.contiguous()
)
weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
qlen_tensor = torch.tensor([qlen], dtype=torch.int32)
# Warmup
print(f"Warming up ({warm_up_iter} iterations)...")
for i in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moes[i % layer_num].forward_task(
qlen_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor[i % layer_num].data_ptr(),
output_tensor[i % layer_num].data_ptr(),
False,
)
)
CPUInfer.sync()
# Benchmark
print(f"Running benchmark ({test_iter} iterations)...")
start = time.perf_counter()
for i in tqdm(range(test_iter), desc="Testing"):
CPUInfer.submit(
moes[i % layer_num].forward_task(
qlen_tensor.data_ptr(),
num_experts_per_tok,
expert_ids[i % gen_iter].data_ptr(),
weights[i % gen_iter].data_ptr(),
input_tensor[i % layer_num].data_ptr(),
output_tensor[i % layer_num].data_ptr(),
False,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time = end - start
# Calculate metrics
time_per_iter_us = total_time / test_iter * 1e6
# FLOPS calculation:
# Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
# GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
# For vector-matrix multiply (qlen=1): 2 * n * k per matrix
flops_per_expert = (
2 * intermediate_size * hidden_size # gate
+ 2 * intermediate_size * hidden_size # up
+ 2 * hidden_size * intermediate_size # down
)
total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
tflops = total_flops / total_time / 1e12
# Bandwidth calculation (FP8 = 1 byte per element)
bytes_per_elem = 1.0
# Weight memory: gate + up + down per expert
bandwidth = (
hidden_size
* intermediate_size
* 3
* num_experts_per_tok
* (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
* bytes_per_elem
* test_iter
/ total_time
/ 1e9
)
# Print results
print("\n" + "=" * 70)
print("Benchmark Results")
print("=" * 70)
print(f"Quant mode: FP8 (E4M3) with per-channel scaling")
print(f"Total time: {total_time:.4f} s")
print(f"Iterations: {test_iter}")
print(f"Time per iteration: {time_per_iter_us:.2f} us")
print(f"Bandwidth: {bandwidth:.2f} GB/s")
print(f"TFLOPS: {tflops:.4f}")
print("")
# Record results
result = {
"test_name": os.path.basename(__file__),
"quant_mode": "fp8_e4m3_perchannel",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"bandwidth_GBs": bandwidth,
"flops_TFLOPS": tflops,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_experts_per_tok": num_experts_per_tok,
"quant_type": "per_channel",
"layer_num": layer_num,
"qlen": qlen,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"CPUInfer_parameter": CPUINFER_PARAM,
},
}
result.update(get_git_commit())
result.update(get_system_info())
record_results(result)
return tflops, bandwidth
if __name__ == "__main__":
bench_fp8_perchannel_moe()

View File

@@ -4,12 +4,14 @@
Benchmark write_weight_scale_to_buffer for AMX MOE operators.
Supports:
- FP8: FP8 weights (1 byte) + float32 scales
- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
- BF16: Native BF16 weights (2 bytes), no scales
Usage:
python bench_write_buffer.py # Run all modes
python bench_write_buffer.py fp8 # Run FP8 only
python bench_write_buffer.py fp8_perchannel # Run FP8 per-channel only
python bench_write_buffer.py bf16 # Run BF16 only
"""
import json
@@ -31,8 +33,8 @@ expert_num = 256
num_experts_per_tok = 8
gpu_tp_count = 2
warm_up_iter = 3
test_iter = 7
warm_up_iter = 30
test_iter = 70
gpu_experts_num = expert_num
@@ -147,6 +149,49 @@ def allocate_weights_fp8():
}
def allocate_weights_fp8_perchannel():
per_mat_weight_bytes = hidden_size * intermediate_size
per_mat_scale_elems_gate_up = intermediate_size
per_mat_scale_elems_down = hidden_size
gate_q = (
torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
.to("cpu")
.contiguous()
)
up_q = (
torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
.to("cpu")
.contiguous()
)
down_q = (
torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
.to("cpu")
.contiguous()
)
gate_scale = (
torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
)
up_scale = (
torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
)
down_scale = (
torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous()
)
return {
"gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"per_mat_weight_bytes": per_mat_weight_bytes,
"per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
"per_mat_scale_elems_down": per_mat_scale_elems_down,
}
def build_moe_fp8(layer_idx=0):
"""Build a single FP8 MOE instance."""
weights = allocate_weights_fp8()
@@ -178,6 +223,38 @@ def build_moe_fp8(layer_idx=0):
return moe, buffer_shapes, weights
def build_moe_fp8_perchannel(layer_idx=0):
"""Build a single FP8 per-channel MOE instance."""
weights = allocate_weights_fp8_perchannel()
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
config.max_len = max_len
config.layer_idx = layer_idx
config.quant_config.bits = 8
config.quant_config.group_size = 0
config.quant_config.zero_point = False
config.quant_config.per_channel = True
config.pool = CPUInfer.backend_
config.gate_proj = weights["gate_q"].data_ptr()
config.up_proj = weights["up_q"].data_ptr()
config.down_proj = weights["down_q"].data_ptr()
config.gate_scale = weights["gate_scale"].data_ptr()
config.up_scale = weights["up_scale"].data_ptr()
config.down_scale = weights["down_scale"].data_ptr()
moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
buffer_shapes = {
"per_mat_weight_bytes": weights["per_mat_weight_bytes"],
"per_mat_scale_elems_gate_up": weights["per_mat_scale_elems_gate_up"],
"per_mat_scale_elems_down": weights["per_mat_scale_elems_down"],
}
return moe, buffer_shapes, weights
def allocate_buffers_fp8(buffer_shapes):
"""Allocate output buffers for FP8 single expert."""
per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
@@ -212,6 +289,40 @@ def allocate_buffers_fp8(buffer_shapes):
return buffer_ptrs, keep_tensors
def allocate_buffers_fp8_perchannel(buffer_shapes):
"""Allocate output buffers for FP8 per-channel single expert."""
per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"]
per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"]
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count
scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down
w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w13_scale_bufs = [
torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
buffer_ptrs = {
"w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
"w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
"w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
"w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
}
keep_tensors = {
"w13_weight_bufs": w13_weight_bufs,
"w13_scale_bufs": w13_scale_bufs,
"w2_weight_bufs": w2_weight_bufs,
"w2_scale_bufs": w2_scale_bufs,
}
return buffer_ptrs, keep_tensors
# ==============================================================================
# BF16 Functions
# ==============================================================================
@@ -320,6 +431,20 @@ def bench_write_buffer(quant_mode: str):
)
bytes_per_call = total_weights + total_scale_bytes
elif quant_mode == "fp8_perchannel":
bytes_per_elem = 1.0
moe_0, buffer_shapes, keep_tensors_0 = build_moe_fp8_perchannel(layer_idx=0)
moe_1, _, keep_tensors_1 = build_moe_fp8_perchannel(layer_idx=1)
buffer_ptrs, buffer_keep = allocate_buffers_fp8_perchannel(buffer_shapes)
total_weights = hidden_size * intermediate_size * expert_num * 3
total_scale_bytes = (
(buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"])
* expert_num
* 4
)
bytes_per_call = total_weights + total_scale_bytes
elif quant_mode == "bf16":
bytes_per_elem = 2.0
moe_0, buffer_shapes, keep_tensors_0 = build_moe_bf16(layer_idx=0)
@@ -356,7 +481,7 @@ def bench_write_buffer(quant_mode: str):
end = time.perf_counter()
iter_time = end - start
total_time += iter_time
print(f" Iter {iter_idx}: {iter_time*1000:.2f} ms")
# print(f" Iter {iter_idx}: {iter_time*1000:.2f} ms")
time.sleep(0.3)
# bytes_per_call is for one MOE, we have 2 MOEs
@@ -400,7 +525,7 @@ def bench_write_buffer(quant_mode: str):
def main(quant_modes=None):
"""Run benchmarks for specified quant modes."""
if quant_modes is None:
quant_modes = ["fp8", "bf16"]
quant_modes = ["fp8", "fp8_perchannel", "bf16"]
results = {}
for mode in quant_modes:
@@ -423,10 +548,10 @@ def main(quant_modes=None):
if __name__ == "__main__":
if len(sys.argv) > 1:
mode = sys.argv[1].lower()
if mode in ["fp8", "bf16"]:
if mode in ["fp8", "fp8_perchannel", "bf16"]:
main([mode])
else:
print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
sys.exit(1)
else:
main()

View File

@@ -22,12 +22,12 @@ from kt_kernel import kt_kernel_ext
torch.manual_seed(42)
# Model config
hidden_size = 3072
intermediate_size = 1536
hidden_size = 2048
intermediate_size = 768
max_len = 25600
expert_num = 16
num_experts_per_tok = 16
expert_num = 128
num_experts_per_tok = 8
qlen = 1
layer_num = 5
@@ -180,13 +180,13 @@ def run_bf16_moe_test():
diffs = []
with torch.inference_mode(mode=True):
for i in range(validation_iter):
torch.manual_seed(100 + i)
torch.manual_seed(114514 + i)
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 10
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 3
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
moe = moes[i % layer_num]

View File

@@ -0,0 +1,408 @@
"""
Test script for FP8 Per-Channel MoE kernel validation (GLM-4.7-FP8 style).
This script:
1. Generates random BF16 weights
2. Quantizes them to FP8 format with per-channel scales (one scale per output channel)
3. Runs the FP8 Per-Channel MoE kernel
4. Compares results with PyTorch reference using dequantized BF16 weights
FP8 Per-Channel format notes:
- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]
- Scale: FP32, shape [expert_num, n] (one scale per output row)
"""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
from kt_kernel import kt_kernel_ext
torch.manual_seed(42)
# Model config
hidden_size = 3072
intermediate_size = 1536
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 100
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 1
debug_print_count = 16
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def act_fn(x):
"""SiLU activation function"""
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
"""Reference MLP computation in PyTorch"""
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
"""Reference MoE computation in PyTorch"""
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
# FP8 E4M3 constants
FP8_E4M3_MAX = 448.0 # Maximum representable value in FP8 E4M3
def fp8_e4m3_to_float(fp8_val: int) -> float:
"""
Convert FP8 E4M3 value to float.
FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits
"""
sign = (fp8_val >> 7) & 1
exp = (fp8_val >> 3) & 0xF
mant = fp8_val & 0x7
if exp == 0:
# Subnormal or zero
if mant == 0:
return -0.0 if sign else 0.0
# Subnormal: value = (-1)^sign * 2^(-6) * (0.mant)
return ((-1) ** sign) * (2**-6) * (mant / 8.0)
elif exp == 15:
# NaN (FP8 E4M3 doesn't have Inf, all exp=15 are NaN)
return float("nan")
else:
# Normal: value = (-1)^sign * 2^(exp-7) * (1.mant)
return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)
def float_to_fp8_e4m3(val: float) -> int:
"""
Convert float to FP8 E4M3 value.
"""
if val != val: # NaN
return 0x7F # NaN representation
sign = 1 if val < 0 else 0
val = abs(val)
if val == 0:
return sign << 7
# Clamp to max representable value
val = min(val, FP8_E4M3_MAX)
# Find exponent
import math
if val < 2**-9: # Subnormal threshold
# Subnormal
mant = int(round(val / (2**-9)))
mant = min(mant, 7)
return (sign << 7) | mant
exp = int(math.floor(math.log2(val))) + 7
exp = max(1, min(exp, 14)) # Clamp exponent to valid range
# Calculate mantissa
mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))
mant = max(0, min(mant, 7))
# Handle overflow to next exponent
if mant > 7:
mant = 0
exp += 1
if exp > 14:
exp = 14
mant = 7
return (sign << 7) | (exp << 3) | mant
def quantize_to_fp8_perchannel(weights: torch.Tensor):
"""
Quantize BF16/FP32 weights to FP8 with per-channel scaling.
Args:
weights: [expert_num, n, k] tensor in BF16/FP32
Returns:
fp8_weights: [expert_num, n, k] uint8 tensor
scales: [expert_num, n] FP32 tensor (one scale per output row)
"""
weights_f32 = weights.to(torch.float32)
e, n, k = weights_f32.shape
# Calculate max abs per row (per output channel)
max_abs = weights_f32.abs().amax(dim=-1, keepdim=True) # [e, n, 1]
max_abs = torch.clamp(max_abs, min=1e-12)
# Scale to FP8 range: scale = max_abs / FP8_MAX
scales = (max_abs / FP8_E4M3_MAX).squeeze(-1) # [e, n]
# Quantize: q = round(val / scale)
scaled = weights_f32 / (scales.unsqueeze(-1) + 1e-12)
# Clamp to FP8 representable range
scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
# Vectorized FP8 quantization
fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)
sign_mask = (scaled < 0).to(torch.uint8) << 7
abs_scaled = scaled.abs()
# Handle different ranges
# Subnormal: 0 < |x| < 2^-6
subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)
subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)
# Normal values
normal_mask = abs_scaled >= 2**-6
log2_val = torch.log2(abs_scaled.clamp(min=2**-9))
exp = (log2_val.floor() + 7).clamp(1, 14).to(torch.int32)
mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8)
# Combine
fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)
fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)
return fp8_q.contiguous(), scales.to(torch.float32).contiguous()
def dequantize_fp8_perchannel(fp8_weights: torch.Tensor, scales: torch.Tensor):
"""
Dequantize FP8 weights back to BF16 for reference computation.
Args:
fp8_weights: [expert_num, n, k] uint8 tensor
scales: [expert_num, n] FP32 tensor
Returns:
dequantized: [expert_num, n, k] BF16 tensor
"""
# Build lookup table for FP8 E4M3 -> float
fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)
# Use lookup table
fp8_float = fp8_lut[fp8_weights.to(torch.int64)]
# Apply per-channel scales
scales_expanded = scales.unsqueeze(-1) # [e, n, 1]
dequantized = fp8_float * scales_expanded
return dequantized.to(torch.bfloat16).contiguous()
def build_random_fp8_perchannel_weights():
"""
Generate random BF16 weights and quantize to FP8 with per-channel scales.
Returns:
dict with fp8 weights, scales, and original bf16 for reference
"""
torch.manual_seed(42)
# Generate random BF16 weights with small values
gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
# Quantize to FP8 with per-channel scales
gate_fp8, gate_scales = quantize_to_fp8_perchannel(gate_proj)
up_fp8, up_scales = quantize_to_fp8_perchannel(up_proj)
down_fp8, down_scales = quantize_to_fp8_perchannel(down_proj)
# Dequantize for reference computation
gate_deq = dequantize_fp8_perchannel(gate_fp8, gate_scales)
up_deq = dequantize_fp8_perchannel(up_fp8, up_scales)
down_deq = dequantize_fp8_perchannel(down_fp8, down_scales)
print(f"FP8 Per-Channel weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}")
print(f"Per-Channel scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}")
# Debug: Print FP8 weight and scale info for expert 0
print("\n=== DEBUG: FP8 Per-Channel Weight and Scale Info (Expert 0) ===")
print(f"gate_fp8[0] first 8x8 block:")
for i in range(8):
print(f" row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
print(f"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}")
print(f"gate_scales[0] first 8 channels: {gate_scales[0, :8]}")
print(f"gate_scales[0] stats: min={gate_scales[0].min():.6f}, max={gate_scales[0].max():.6f}")
return {
"gate_fp8": gate_fp8.contiguous(),
"up_fp8": up_fp8.contiguous(),
"down_fp8": down_fp8.contiguous(),
"gate_scales": gate_scales.contiguous(),
"up_scales": up_scales.contiguous(),
"down_scales": down_scales.contiguous(),
"gate_deq": gate_deq.contiguous(),
"up_deq": up_deq.contiguous(),
"down_deq": down_deq.contiguous(),
}
def build_moes_from_fp8_perchannel_data(fp8_data: dict):
"""
Build FP8 Per-Channel MoE modules from quantized data.
"""
moes = []
with torch.inference_mode(mode=True):
for _ in range(layer_num):
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.quant_config.bits = 8
config.quant_config.group_size = 0 # Not used for per-channel
config.quant_config.zero_point = False
config.quant_config.per_channel = True # Enable per-channel mode
# Set FP8 weight pointers
config.gate_proj = fp8_data["gate_fp8"].data_ptr()
config.up_proj = fp8_data["up_fp8"].data_ptr()
config.down_proj = fp8_data["down_fp8"].data_ptr()
# Set per-channel scale pointers
config.gate_scale = fp8_data["gate_scales"].data_ptr()
config.up_scale = fp8_data["up_scales"].data_ptr()
config.down_scale = fp8_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
moes.append(moe)
return moes
def run_fp8_perchannel_moe_test():
"""
Run FP8 Per-Channel MoE validation test.
"""
print("\n" + "=" * 70)
print("FP8 Per-Channel MoE Kernel Validation Test")
print("=" * 70)
# Build FP8 per-channel weights
print("\nGenerating and quantizing weights with per-channel scales...")
fp8_data = build_random_fp8_perchannel_weights()
# Build MoE modules
print("\nBuilding FP8 Per-Channel MoE modules...")
moes = build_moes_from_fp8_perchannel_data(fp8_data)
# Get dequantized weights for reference
gate_deq = fp8_data["gate_deq"]
up_deq = fp8_data["up_deq"]
down_deq = fp8_data["down_deq"]
diffs = []
with torch.inference_mode(mode=True):
for i in range(validation_iter):
torch.manual_seed(100 + i)
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
moe = moes[i % layer_num]
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input_tensor.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."
# Reference computation using dequantized weights
t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)
t_output_flat = t_output.flatten()
output_flat = output.flatten()
diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
diffs.append(diff.item())
print(f"Iteration {i}: relative L1 diff = {diff:.6f}")
if i < 3: # Print detailed output for first few iterations
print(f" kernel output: {output_flat[:debug_print_count]}")
print(f" torch output: {t_output_flat[:debug_print_count]}")
mean_diff = float(sum(diffs) / len(diffs))
max_diff = float(max(diffs))
min_diff = float(min(diffs))
print("\n" + "=" * 70)
print("FP8 Per-Channel MoE Test Results")
print("=" * 70)
print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
print(f"Max relative L1 diff: {max_diff*100:.4f}%")
print(f"Min relative L1 diff: {min_diff*100:.4f}%")
# Pass/Fail criteria
threshold = 15.0 # 15% relative error threshold for FP8
if mean_diff * 100 < threshold:
print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
else:
print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")
return {"mean": mean_diff, "max": max_diff, "min": min_diff}
if __name__ == "__main__":
run_fp8_perchannel_moe_test()

View File

@@ -2,12 +2,14 @@
Test write_weight_scale_to_buffer for AMX MOE operators.
Supports:
- FP8: FP8 weights (1 byte) + float32 scales
- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
- BF16: Native BF16 weights (2 bytes), no scales
Usage:
python test_write_buffer.py # Run all modes
python test_write_buffer.py fp8 # Run FP8 only
python test_write_buffer.py fp8_perchannel # Run FP8 per-channel only
python test_write_buffer.py bf16 # Run BF16 only
"""
@@ -41,6 +43,17 @@ def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, int
return cfg
def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
cfg.quant_config.bits = 8 # FP8
cfg.quant_config.group_size = 0 # Not used for per-channel
cfg.quant_config.zero_point = False
cfg.quant_config.per_channel = True
cfg.pool = cpuinfer.backend_
return cfg
def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
@@ -83,6 +96,33 @@ def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)
}
def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):
"""Allocate FP8 per-channel weights and scales for testing"""
per_mat_weight_bytes = hidden_size * intermediate_size
per_mat_scale_elems_gate_up = intermediate_size # one scale per output channel
per_mat_scale_elems_down = hidden_size
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)
return {
"gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"per_mat_weight_bytes": per_mat_weight_bytes,
"per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
"per_mat_scale_elems_down": per_mat_scale_elems_down,
}
def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
"""Allocate BF16 weights for testing (no scales)"""
# BF16 weights: 2 bytes per element
@@ -312,6 +352,195 @@ def test_fp8_write_buffer(gpu_tp_count):
return True
def test_fp8_perchannel_write_buffer(gpu_tp_count):
"""Test write_weight_scale_to_buffer with FP8 per-channel weights"""
torch.manual_seed(123)
expert_num = 256
gpu_experts = expert_num
num_experts_per_tok = 8
hidden_size = 3072
intermediate_size = 1536
cpuinfer = make_cpu_infer()
cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)
cfg.gate_proj = weights["gate_q"].data_ptr()
cfg.up_proj = weights["up_q"].data_ptr()
cfg.down_proj = weights["down_q"].data_ptr()
cfg.gate_scale = weights["gate_scale"].data_ptr()
cfg.up_scale = weights["up_scale"].data_ptr()
cfg.down_scale = weights["down_scale"].data_ptr()
moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
per_mat_weight_bytes = weights["per_mat_weight_bytes"]
per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
gpu_n_w13 = intermediate_size // gpu_tp_count
scale_elems_per_expert_per_tp_gate_up = gpu_n_w13
scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down
total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down
w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w13_scale_bufs = [
torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
]
w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
def get_expert_ptrs(expert_id):
w13_weight_ptrs = []
w13_scale_ptrs = []
w2_weight_ptrs = []
w2_scale_ptrs = []
for tp_idx in range(gpu_tp_count):
w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
for _ in range(2):
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
begin_time = time.perf_counter_ns()
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1e6
total_bytes = (
hidden_size * intermediate_size * gpu_experts * 3
+ (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
)
print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
gate_q = weights["gate_q"]
up_q = weights["up_q"]
down_q = weights["down_q"]
gate_scale = weights["gate_scale"]
up_scale = weights["up_scale"]
down_scale = weights["down_scale"]
gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
expected_w13_scales = []
expected_w2_weights = []
expected_w2_scales = []
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
for expert_id in range(gpu_experts):
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
down_weight_tp_parts = []
tp_slice_weight_size = intermediate_size // gpu_tp_count
for row_idx in range(hidden_size):
row_weight_start = row_idx * intermediate_size
tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
down_weight_tp_parts.append(
down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
)
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = down_scale_experts[expert_id]
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
expected_w13_scales.append(up_scale_tp)
expected_w2_weights.append(down_weight_tp)
expected_w2_scales.append(down_scale_tp)
expected_w13_weight = torch.cat(expected_w13_weights)
expected_w13_scale = torch.cat(expected_w13_scales)
expected_w2_weight = torch.cat(expected_w2_weights)
expected_w2_scale = torch.cat(expected_w2_scales)
if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
raise AssertionError(f"[FP8_PERCHANNEL] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
raise AssertionError(f"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}")
if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
raise AssertionError(f"[FP8_PERCHANNEL] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}")
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
raise AssertionError(f"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}")
print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
return True
def test_bf16_write_buffer(gpu_tp_count):
"""Test write_weight_scale_to_buffer with BF16 weights (no scales)"""
torch.manual_seed(123)
@@ -478,6 +707,8 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
"""Test write_weight_scale_to_buffer with specified mode and TP count"""
if quant_mode == "fp8":
return test_fp8_write_buffer(gpu_tp_count)
elif quant_mode == "fp8_perchannel":
return test_fp8_perchannel_write_buffer(gpu_tp_count)
elif quant_mode == "bf16":
return test_bf16_write_buffer(gpu_tp_count)
else:
@@ -487,7 +718,7 @@ def test_with_tp(quant_mode: str, gpu_tp_count: int):
def main(quant_modes=None):
"""Run tests for specified quant modes"""
if quant_modes is None:
quant_modes = ["fp8", "bf16"]
quant_modes = ["fp8", "fp8_perchannel", "bf16"]
tp_values = [1, 2, 4]
all_passed = True
@@ -525,10 +756,10 @@ def main(quant_modes=None):
if __name__ == "__main__":
if len(sys.argv) > 1:
mode = sys.argv[1].lower()
if mode in ["fp8", "bf16"]:
if mode in ["fp8", "fp8_perchannel", "bf16"]:
main([mode])
else:
print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
sys.exit(1)
else:
main()

View File

@@ -37,8 +37,9 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#if defined(__AVX512BF16__)
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
#include "operators/amx/fp8-perchannel-moe.hpp" // FP8 Per-Channel MoE for GLM-4.7-FP8
#endif
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
@@ -252,8 +253,9 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
.def("load_weights", &MoeClass::load_weights)
.def("forward", &MoeClass::forward_binding);
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
if constexpr (std::is_same_v<MoeTP, AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>) {
// Bind write_weight_scale_to_buffer_task for MoE types that support it
// Uses SFINAE to detect if MoeClass has write_weight_scale_to_buffer method
if constexpr (requires { &MoeClass::write_weight_scale_to_buffer; }) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
@@ -295,99 +297,6 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#if defined(__AVX512BF16__)
// FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
// Only available on CPUs with AVX512 BF16 support
if constexpr (std::is_same_v<MoeTP, AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int expert_id;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
int expert_id, py::list w13_weight_ptrs,
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
// BF16 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
// Only available on CPUs with AVX512 BF16 support
if constexpr (std::is_same_v<MoeTP, AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int expert_id;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
int expert_id, py::list w13_weight_ptrs,
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif // __AVX512BF16__
#endif
}
PYBIND11_MODULE(kt_kernel_ext, m) {
@@ -585,7 +494,8 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
.def_readwrite("quant_method", &QuantConfig::quant_method)
.def_readwrite("bits", &QuantConfig::bits)
.def_readwrite("group_size", &QuantConfig::group_size)
.def_readwrite("zero_point", &QuantConfig::zero_point);
.def_readwrite("zero_point", &QuantConfig::zero_point)
.def_readwrite("per_channel", &QuantConfig::per_channel);
auto moe_module = m.def_submodule("moe");
@@ -660,6 +570,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
#if defined(__AVX512BF16__)
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, "AMXFP8PerChannel_MOE");
#endif
#endif
#if defined(USE_MOE_KERNEL)

View File

@@ -213,6 +213,12 @@ class AMX_BF16_MOE_TP : public AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>> {
_mm512_storeu_si512(dst, data);
}
// Fast 64-byte non-temporal store (bypass cache for write-only patterns)
static inline void fast_stream_64(void* __restrict dst, const void* __restrict src) {
__m512i data = _mm512_loadu_si512(src);
_mm512_stream_si512((__m512i*)dst, data);
}
// Fast memcpy for arbitrary sizes using AVX512
static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
uint8_t* d = (uint8_t*)dst;
@@ -262,19 +268,19 @@ class AMX_BF16_MOE_TP : public AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>> {
amx::transpose_16x16_32bit(temp_block1);
amx::transpose_16x16_32bit(temp_block2);
// Copy transposed data to destination in n-major layout
const ggml_bf16_t* temp1_bf16 = reinterpret_cast<const ggml_bf16_t*>(temp_block1);
const ggml_bf16_t* temp2_bf16 = reinterpret_cast<const ggml_bf16_t*>(temp_block2);
// Copy transposed data to destination in n-major layout using non-temporal stores
// First 16 rows (block 1)
for (int i = 0; i < TILE_N; i++) {
std::memcpy(dst + i * dst_row_stride, temp1_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t));
fast_stream_64(dst + i * dst_row_stride, &temp_block1[i]);
}
// Next 16 rows (block 2)
for (int i = 0; i < TILE_N; i++) {
std::memcpy(dst + (TILE_N + i) * dst_row_stride, temp2_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t));
fast_stream_64(dst + (TILE_N + i) * dst_row_stride, &temp_block2[i]);
}
// Ensure all stores complete before returning
_mm_sfence();
}
/**

View File

@@ -0,0 +1,716 @@
/**
* @Description : FP8 Per-Channel AMX MoE operator for GLM-4.7-FP8 native inference
* @Author : Claude
* @Date : 2025-01-12
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* This file implements FP8 MoE with per-channel quantization using CRTP pattern.
* Per-channel quantization: each output channel (row) has one scale factor.
* This is different from block-wise quantization where each 128x128 block has one scale.
**/
#ifndef CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H
#define CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H
#include "la/amx_raw_buffers.hpp"
#include "la/amx_raw_kernels.hpp"
#include "moe_base.hpp"
/**
* @brief FP8 Per-Channel MoE operator using CRTP pattern
* @tparam T Kernel type, defaults to GemmKernel224FP8PerChannel
*
* This class provides FP8 per-channel specific implementations:
* - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul with per-channel scale
* - load_weights: Load FP8 weights with per-channel scales (shape: [n])
*/
template <class T = amx::GemmKernel224FP8PerChannel>
class AMX_FP8_PERCHANNEL_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_PERCHANNEL_MOE_TP<T>> {
using Base = AMX_MOE_BASE<T, AMX_FP8_PERCHANNEL_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AMX_FP8_PERCHANNEL_MOE_TP() = default;
AMX_FP8_PERCHANNEL_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
// Initialization now happens in derived_init() which is called by base constructor
}
void derived_init() {
auto& quant_config = config_.quant_config;
if (!quant_config.per_channel) {
throw std::runtime_error("KT-Kernel FP8 Per-Channel MoE requires per_channel=true");
}
printf("Created AMX_FP8_PERCHANNEL_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
}
~AMX_FP8_PERCHANNEL_MOE_TP() = default;
// ============================================================================
// CRTP buffer creation - per-channel (no group_size needed)
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
// Per-channel: weight size + n scales (no group_size)
return T::BufferB::required_size(n, k);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
// Per-channel BufferB doesn't need group_size
return std::make_shared<typename T::BufferB>(n, k, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// ============================================================================
// CRTP virtual points - GEMM dispatch (per-channel)
// ============================================================================
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
// Per-channel: use vec_mul_perchannel instead of vec_mul_kgroup
amx::float_mat_vec_perchannel<T>(m, config_.intermediate_size, config_.hidden_size, ba.get(), bb.get(), bc.get(),
ith, nth);
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
amx::float_mat_vec_perchannel<T>(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx].get(),
down_bb_[expert_idx].get(), down_bc_[expert_idx].get(), ith, nth);
}
// Fast 64-byte (512-bit) memcpy using AVX512
static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {
__m512i data = _mm512_loadu_si512(src);
_mm512_storeu_si512(dst, data);
}
// Fast memcpy for arbitrary sizes using AVX512
static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
uint8_t* d = (uint8_t*)dst;
const uint8_t* s = (const uint8_t*)src;
size_t chunks = bytes / 64;
for (size_t i = 0; i < chunks; i++) {
fast_memcpy_64(d, s);
d += 64;
s += 64;
}
bytes -= chunks * 64;
if (bytes > 0) {
std::memcpy(d, s, bytes);
}
}
/**
* @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format
*
* This is the inverse of the packing done in BufferBFP8PerChannelImpl::from_mat.
* Optimized with AVX512 gather for efficient non-contiguous reads.
*
* @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout)
* @param dst Pointer to destination in n-major layout
* @param dst_row_stride Row stride in destination buffer (number of columns in full matrix)
*/
static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) {
// row_map[packed_i] gives the base row for packed index packed_i
static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
const uint64_t* src64 = reinterpret_cast<const uint64_t*>(src);
// Gather indices: src64[8*j + packed_i] for j = 0..7
// Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group)
const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0);
// Process each packed group (8 groups of 4 rows each = 32 rows total)
for (int packed_i = 0; packed_i < 8; packed_i++) {
const int base_row = row_map[packed_i];
const uint64_t* base_src = src64 + packed_i;
// Gather 8 values for j=0..7 and j=8..15
__m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8);
__m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8);
// Extract 4 rows from each set of 8 values
// Row 0: bits 0-15
__m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF)));
__m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF)));
// Row 1: bits 16-31
__m128i row1_lo =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF)));
__m128i row1_hi =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF)));
// Row 2: bits 32-47
__m128i row2_lo =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF)));
__m128i row2_hi =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF)));
// Row 3: bits 48-63
__m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48));
__m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48));
// Store 32 bytes (16 x uint16) to each row
// Combine two 128-bit values into 256-bit for more efficient stores
uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride;
uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride;
uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride;
uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride;
// Combine lo and hi into 256-bit and store
__m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo);
__m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo);
__m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo);
__m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo);
_mm256_storeu_si256((__m256i*)row0_dst, row0_256);
_mm256_storeu_si256((__m256i*)row1_dst, row1_256);
_mm256_storeu_si256((__m256i*)row2_dst, row2_256);
_mm256_storeu_si256((__m256i*)row3_dst, row3_256);
}
}
/**
* @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization
*
* Processing 4 blocks together means each row write is 128 bytes = 2 cache lines,
* which greatly improves write efficiency compared to 32 bytes per row.
*
* @param src Array of 4 source pointers (each pointing to a 32x32 packed block)
* @param dst Destination pointer in n-major layout
* @param dst_row_stride Row stride in destination buffer
*/
static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {
static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
constexpr int K_STEP = T::K_STEP; // 32
// Reinterpret as uint64 arrays for efficient access
const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);
const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);
const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);
const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);
// Process all 32 rows, writing 128 bytes (4 x 32) per row
for (int packed_i = 0; packed_i < 8; packed_i++) {
const int base_row = row_map[packed_i];
// Process 4 rows at a time
for (int r = 0; r < 4; r++) {
uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);
const int shift = r * 16;
// Unroll: process all 4 blocks x 16 columns = 64 uint16 values
// Block 0: columns 0-15
for (int j = 0; j < 16; j++) {
row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);
}
// Block 1: columns 16-31
for (int j = 0; j < 16; j++) {
row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);
}
// Block 2: columns 32-47
for (int j = 0; j < 16; j++) {
row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);
}
// Block 3: columns 48-63
for (int j = 0; j < 16; j++) {
row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);
}
}
}
}
/**
* @brief Reconstruct weights for a single expert to the output buffers (per-channel version)
*
* Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.
* Scale handling is simplified for per-channel quantization (linear copy instead of block-wise).
*
* @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)
* @param cpu_tp_count Number of CPU TP parts
* @param expert_id Expert index to process
* @param full_config Full configuration (before CPU TP split)
* @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)
* @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)
* @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)
* @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)
*/
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
auto& config = config_;
auto pool = config.pool->get_subpool(tp_part_idx);
constexpr int N_STEP = T::N_STEP;
constexpr int K_STEP = T::K_STEP;
constexpr int N_BLOCK = T::N_BLOCK;
constexpr int K_BLOCK = T::K_BLOCK;
// ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========
const int cpu_n_w13 = config.intermediate_size;
const int cpu_k_w13 = config.hidden_size;
const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
const int gpu_k_w13 = full_config.hidden_size;
const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
// Per-channel scale: shape [n] for each matrix
const size_t gpu_w13_scale_per_mat = (size_t)gpu_n_w13;
// ========= W2 (down): Shape [hidden, intermediate], split by K =========
const int cpu_n_w2 = config.hidden_size;
const int cpu_k_w2 = config.intermediate_size;
const int gpu_n_w2 = full_config.hidden_size;
const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;
const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2;
// Per-channel scale for down: shape [hidden_size] - not split by K
const size_t gpu_w2_scale_per_mat = (size_t)gpu_n_w2;
// ========= Optimized job layout =========
constexpr int NUM_W13_TASKS = 32; // Per matrix (gate or up), total 64 for w13
constexpr int NUM_W2_TASKS = 32; // For down matrix
constexpr int SCALE_TASKS = 3; // gate_scale, up_scale, down_scale
const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS;
// Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing)
const int w13_n_steps = div_up(cpu_n_w13, N_STEP);
const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);
const int w2_n_steps = div_up(cpu_n_w2, N_STEP);
const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);
pool->do_work_stealing_job(
total_tasks, nullptr,
[=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {
if (task_id < NUM_W13_TASKS * 2) {
// ========= W13 weight task: process chunk of rows x full K =========
const bool is_up = task_id >= NUM_W13_TASKS;
const int chunk_idx = task_id % NUM_W13_TASKS;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
// Calculate row range for this task (N_STEP aligned)
const int step_start = chunk_idx * w13_steps_per_task;
const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);
if (step_start >= w13_n_steps) return;
const int chunk_n_start = step_start * N_STEP;
const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);
// Process each N_STEP within this chunk
for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
// Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries)
const int global_n = global_n_offset_w13 + local_n_start;
const int target_gpu = global_n / gpu_n_w13;
const int n_in_gpu = global_n % gpu_n_w13;
uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu];
// Pointer already points to current expert's location, only add offset for up matrix
const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;
// Calculate N_BLOCK info for source addressing
const int n_block_idx = local_n_start / N_BLOCK;
const int n_block_begin = n_block_idx * N_BLOCK;
const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);
const int n_in_block = local_n_start - n_block_begin;
// Process all K in groups of 4 K_STEPs when possible for cache efficiency
for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {
const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);
// Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row)
int k_begin = 0;
for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) {
const uint8_t* src_ptrs[4];
for (int i = 0; i < 4; i++) {
src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size +
(size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP;
}
uint8_t* dst =
weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13);
}
// Handle remaining K_STEPs one by one
for (; k_begin < k_block_size; k_begin += K_STEP) {
const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +
(size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
(size_t)k_begin * N_STEP;
uint8_t* dst =
weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
unpack_nk_block(src, dst, gpu_k_w13);
}
}
}
} else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) {
// ========= W2 weight task: process chunk of rows x all K slices =========
const int chunk_idx = task_id - NUM_W13_TASKS * 2;
const auto& bb = down_bb_[expert_id];
// Calculate row range for this task (N_STEP aligned)
const int step_start = chunk_idx * w2_steps_per_task;
const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);
if (step_start >= w2_n_steps) return;
const int chunk_n_start = step_start * N_STEP;
const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);
// Process each N_STEP within this chunk
for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
// Calculate N_BLOCK info for source addressing
const int n_block_idx = local_n_start / N_BLOCK;
const int n_block_begin = n_block_idx * N_BLOCK;
const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);
const int n_in_block = local_n_start - n_block_begin;
// Process all K slices (each slice goes to a different GPU TP)
for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {
const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);
const int global_k_start = global_k_offset_w2 + k_slice_start;
const int target_gpu = global_k_start / gpu_k_w2;
const int k_in_gpu_base = global_k_start % gpu_k_w2;
uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu];
// Pointer already points to current expert's location
const size_t expert_weight_off = 0;
// Process K within this slice, trying 4 K_STEPs at once when aligned
for (int k_abs = k_slice_start; k_abs < k_slice_end;) {
const int k_block_idx = k_abs / K_BLOCK;
const int k_block_begin = k_block_idx * K_BLOCK;
const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);
const int k_in_block = k_abs - k_block_begin;
const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);
// Check if we can process 4 K_STEPs at once
const int remaining_in_block = k_block_size - k_in_block;
const int remaining_in_slice = k_slice_end - k_abs;
if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) {
const uint8_t* src_ptrs[4];
for (int i = 0; i < 4; i++) {
src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size +
(size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP;
}
uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2);
k_abs += 4 * K_STEP;
} else {
const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +
(size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
(size_t)k_in_block * N_STEP;
uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
unpack_nk_block(src, dst, gpu_k_w2);
k_abs += K_STEP;
}
}
}
}
} else {
// ========= Scale copy task: per-channel (simple linear copy) =========
const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS;
if (scale_task_id < 2) {
// Gate (0) or Up (1) scale copy - per-channel: [intermediate_size]
const bool is_up = scale_task_id == 1;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
// W13 per-channel scales: copy N range corresponding to this CPU TP
// Each GPU TP gets [gpu_n_w13] scales
const int n_start_global = global_n_offset_w13;
for (int local_n = 0; local_n < cpu_n_w13;) {
const int global_n = n_start_global + local_n;
const int target_gpu = global_n / gpu_n_w13;
const int n_in_gpu = global_n % gpu_n_w13;
// Calculate how many scales to copy to this GPU TP
const int remaining_in_gpu = gpu_n_w13 - n_in_gpu;
const int remaining_local = cpu_n_w13 - local_n;
const int copy_count = std::min(remaining_in_gpu, remaining_local);
float* scale_dst = (float*)w13_scale_ptrs[target_gpu];
// Pointer already points to current expert's location, only add offset for up matrix
const size_t expert_scale_off = is_up ? gpu_w13_scale_per_mat : 0;
fast_memcpy(scale_dst + expert_scale_off + n_in_gpu, bb->d + local_n, copy_count * sizeof(float));
local_n += copy_count;
}
} else {
// Down scale copy (scale_task_id == 2) - per-channel: [hidden_size]
const auto& bb = down_bb_[expert_id];
// W2 per-channel scales: shape [hidden_size], not split by K
// All GPU TPs get the same scales (full hidden_size)
// However, since K is split, we need to write to each GPU TP
for (int gpu_idx = 0; gpu_idx < gpu_tp_count; gpu_idx++) {
// Check if this CPU TP contributes to this GPU TP's K range
const int gpu_k_start = gpu_idx * gpu_k_w2;
const int gpu_k_end = gpu_k_start + gpu_k_w2;
const int cpu_k_start = global_k_offset_w2;
const int cpu_k_end = cpu_k_start + cpu_k_w2;
// Check for overlap
if (cpu_k_start < gpu_k_end && cpu_k_end > gpu_k_start) {
// This CPU TP contributes to this GPU TP
// Only the first CPU TP for this GPU should write scales
if (cpu_k_start == gpu_k_start || cpu_k_start % gpu_k_w2 == 0) {
float* scale_dst = (float*)w2_scale_ptrs[gpu_idx];
// Pointer already points to current expert's location
fast_memcpy(scale_dst, bb->d, cpu_n_w2 * sizeof(float));
}
}
}
}
}
},
nullptr);
}
/**
* @brief Load FP8 weights from contiguous memory layout with per-channel scales
*
* Loads weights from config_.gate_proj, up_proj, down_proj with scales
* from config_.gate_scale, up_scale, down_scale.
*
* Per-channel scale shape: [n] (one scale per output channel)
*/
void load_weights() {
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_scale == nullptr) {
throw std::runtime_error("FP8 Per-Channel MoE requires scale pointers.");
}
// load gate and up weights
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// Per-channel scale: shape [intermediate_size] for gate/up
const size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
const size_t scale_offset = logical_expert_id * config_.intermediate_size;
// gate part
gate_bb_[expert_idx]->from_mat((uint8_t*)config_.gate_proj + weight_offset,
(float*)config_.gate_scale + scale_offset, ith, nth);
// up part
up_bb_[expert_idx]->from_mat((uint8_t*)config_.up_proj + weight_offset,
(float*)config_.up_scale + scale_offset, ith, nth);
},
nullptr);
// load down weights
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// Per-channel scale: shape [hidden_size] for down
const size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;
const size_t scale_offset = logical_expert_id * config_.hidden_size;
// down part
down_bb_[expert_idx]->from_mat((uint8_t*)config_.down_proj + weight_offset,
(float*)config_.down_scale + scale_offset, ith, nth);
},
nullptr);
}
};
/**
* @brief TP_MOE specialization for FP8 Per-Channel MoE
*/
template <typename K>
class TP_MOE<AMX_FP8_PERCHANNEL_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_FP8_PERCHANNEL_MOE_TP<K>>> {
public:
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_FP8_PERCHANNEL_MOE_TP<K>>>;
using Base::Base;
/**
* @brief Write weights and scales to GPU buffer for a single expert
*
* This method coordinates all CPU TP parts to write their portions
* of weights and scales to the GPU buffers.
*
* @param gpu_tp_count Number of GPU TP parts
* @param expert_id Expert index to write
* @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)
* @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)
* @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)
* @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)
*/
void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) {
throw std::runtime_error("Not Loaded");
}
if (this->tps.empty()) {
throw std::runtime_error("No TP parts initialized");
}
if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count ||
(int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) {
throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
}
this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
});
}
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (!config.quant_config.per_channel) {
throw std::runtime_error("FP8 Per-Channel MoE requires per_channel=true");
}
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
throw std::runtime_error("no weight source");
}
const bool use_per_expert_ptrs = !config.gate_projs.empty();
const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
// Per-channel: scale count = output dimension
const size_t gate_up_scale_elems = (size_t)config.intermediate_size;
const size_t down_scale_elems = (size_t)config.hidden_size;
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
// Per-channel scales for TP part
const size_t tp_gate_up_scale_elems = (size_t)tpc.intermediate_size;
const size_t tp_down_scale_elems = (size_t)tpc.hidden_size;
tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.gate_scale = new float[tpc.expert_num * tp_gate_up_scale_elems];
tpc.up_scale = new float[tpc.expert_num * tp_gate_up_scale_elems];
tpc.down_scale = new float[tpc.expert_num * tp_down_scale_elems];
const size_t tp_idx = (size_t)i;
// gate/up: split by N (intermediate_size)
const size_t gate_up_weight_src_offset = i * tp_weight_elems;
const size_t gate_up_scale_src_offset = i * tp_gate_up_scale_elems;
// down: split by K (intermediate_size)
const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&, &tpc](int expert_id_) {
const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;
uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;
uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;
float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_gate_up_scale_elems;
float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_gate_up_scale_elems;
float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_down_scale_elems;
const uint8_t* gate_src;
const uint8_t* up_src;
const uint8_t* down_src;
const float* gate_scale_src;
const float* up_scale_src;
const float* down_scale_src;
if (use_per_expert_ptrs) {
gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_projs[0][expert_id];
gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scales[0][expert_id];
} else {
gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;
gate_scale_src =
(const float*)config.gate_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scale + expert_id * down_scale_elems;
}
// Copy gate/up weights and scales (N dimension split)
std::memcpy(gate_dst, gate_src, tp_weight_elems);
std::memcpy(up_dst, up_src, tp_weight_elems);
std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_gate_up_scale_elems);
std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_gate_up_scale_elems);
// Copy down weights (K dimension split) - row by row
for (int row = 0; row < config.hidden_size; row++) {
const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);
}
// Copy down scales (N dimension = hidden_size, full copy for each TP)
std::memcpy(down_scale_dst, down_scale_src, sizeof(float) * tp_down_scale_elems);
},
nullptr);
});
DO_TPS_LOAD_WEIGHTS(pool);
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)tpc.gate_proj;
delete[] (uint8_t*)tpc.up_proj;
delete[] (uint8_t*)tpc.down_proj;
delete[] (float*)tpc.gate_scale;
delete[] (float*)tpc.up_scale;
delete[] (float*)tpc.down_scale;
});
this->weights_loaded = true;
}
};
#endif // CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H

View File

@@ -483,6 +483,119 @@ struct BufferCFP32ReduceImpl {
}
};
// ============================================================================
// BufferBFP8PerChannelImpl: FP8 权重缓冲区Per Channel 量化)
// ============================================================================
/**
* @brief FP8 Per-Channel 权重缓冲区
*
* 存储 FP8 格式的权重矩阵,每个输出通道(行)有一个缩放因子。
* 这与 GLM-4.7-FP8 的 per-channel 量化格式匹配。
*
* 与 BufferBFP8Impl (block-wise) 的区别:
* - Block-wise: scale shape = [n/128, k/128], 每 128x128 块一个 scale
* - Per-channel: scale shape = [n], 每行一个 scale
*
* @tparam K Kernel 类型
*/
template <typename K>
struct BufferBFP8PerChannelImpl {
uint8_t* b; // FP8 weight [n, k]
float* d; // per-channel scale [n]
int n, k;
static constexpr int N_STEP = K::N_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
static constexpr int K_BLOCK = K::K_BLOCK;
static constexpr bool SCALE = true;
static constexpr bool PER_CHANNEL = true;
/**
* @brief 计算所需内存大小
* weight: n * k bytes (FP8)
* scale: n * sizeof(float) bytes
*/
static size_t required_size(int n, int k) { return sizeof(uint8_t) * n * k + sizeof(float) * n; }
/**
* @brief 构造函数
*/
BufferBFP8PerChannelImpl(int n, int k, void* ptr) : n(n), k(k) { set_data(ptr); }
void set_data(void* ptr) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
b = reinterpret_cast<uint8_t*>(ptr);
d = reinterpret_cast<float*>(b + (size_t)n * k);
}
static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7}; // fp8 matrix offset for reordering
/**
* @brief 从原始 FP8 权重加载per-channel 量化格式)
*
* @param b_src FP8 权重源数据 (n-major, n×k)
* @param d_src FP32 per-channel scale 源数据 (shape: [n] or [n, 1])
*/
void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {
assert(b != nullptr && d != nullptr);
assert(N_STEP == 32 && K_STEP == 32);
// Copy per-channel scales. Each thread copies its own n-block range.
if (d_src != nullptr) {
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
memcpy(d + n_start, d_src + n_start, sizeof(float) * (n_end - n_start));
}
// Reorder FP8 weights into KT block-major layout (same as BufferBFP8Impl)
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
int n_step_size = std::min(N_STEP, n_block_size - n_begin);
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
int k_step_size = std::min(K_STEP, k_block_size - k_begin);
// [k_step_size, n_step_size] block copy
const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;
uint64_t* block_b_dst =
reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
(size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);
for (int i = 0; i < 8; i++) {
const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);
for (int j = 0; j < 16; j++) {
uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |
(((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);
block_b_dst[8 * j + mat_offset[i]] = val;
}
}
}
}
}
}
/**
* @brief 获取行 n_begin 开始的 per-channel scale 指针
*/
float* get_scale(int n_begin) { return d + n_begin; }
/**
* @brief 获取子矩阵指针
*/
uint8_t* get_submat(int n, int k, int n_begin, int k_begin) {
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
n_begin -= n_block_begin;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
k_begin -= k_block_begin;
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size +
(size_t)k_begin * N_STEP;
}
};
} // namespace amx
#endif // AMX_RAW_BUFFERS_HPP

View File

@@ -285,7 +285,7 @@ struct GemmKernel224FP8 {
static void config() {}
private:
// FP8->BF16 conversion lookup tables (public for reuse by GemmKernel224FP8PerChannel)
alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = {
0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c,
0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d,
@@ -317,8 +317,6 @@ struct GemmKernel224FP8 {
static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); }
static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); }
static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); }
public:
using BufferA = BufferABF16Impl<GemmKernel224FP8>;
using BufferB = BufferBFP8Impl<GemmKernel224FP8>;
using BufferC = BufferCFP32ReduceImpl<GemmKernel224FP8>;
@@ -618,6 +616,231 @@ inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_pt
float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
// ============================================================================
// Per-Channel FP8 GEMM (for GLM-4.7-FP8 style quantization)
// ============================================================================
/**
* @brief FP8 Per-Channel Kernel
*
* Similar to GemmKernel224FP8 but with per-channel scaling instead of block-wise scaling.
* - Block-wise: scale shape = [n/128, k/128], one scale per 128x128 block
* - Per-channel: scale shape = [n], one scale per output row
*/
struct GemmKernel224FP8PerChannel {
using fp8_t = uint8_t;
using output_t = float;
static constexpr double ELEMENT_SIZE = 1.0;
static const int TILE_M = 16;
static const int TILE_K = 32;
static const int TILE_N = 16;
static const int VNNI_BLK = 2;
static const int M_STEP = TILE_M * 2;
static const int N_STEP = TILE_N * 2;
static const int K_STEP = TILE_K;
// Use smaller N_BLOCK for per-channel to allow efficient scale application
static inline const int N_BLOCK = 128;
static inline const int K_BLOCK = 7168;
static std::string name() { return "FP8PerChannel"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {}
using BufferA = BufferABF16Impl<GemmKernel224FP8PerChannel>;
using BufferB = BufferBFP8PerChannelImpl<GemmKernel224FP8PerChannel>;
using BufferC = BufferCFP32Impl<GemmKernel224FP8PerChannel>;
// Reuse FP8->BF16 conversion from GemmKernel224FP8
static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {
return GemmKernel224FP8::fp8x64_to_bf16x64(bfp8_512);
}
/**
* @brief Apply per-channel scale to result
*
* Unlike block-wise scaling, per-channel scaling applies a different scale to each column
* of the result (each output channel).
*
* @param m Total rows
* @param n Total columns
* @param m_begin Starting row
* @param n_begin Starting column
* @param c Output buffer (M_STEP x N_STEP)
* @param bb BufferB containing per-channel scales
*/
static void apply_scale_perchannel(int m, [[maybe_unused]] int n, int m_begin, int n_begin, float* c, BufferB* bb) {
int to = std::min(m - m_begin, M_STEP);
// Load N_STEP per-channel scales (32 floats)
__m512 bs_lo = _mm512_loadu_ps(bb->get_scale(n_begin)); // scale[n_begin..n_begin+15]
__m512 bs_hi = _mm512_loadu_ps(bb->get_scale(n_begin + TILE_N)); // scale[n_begin+16..n_begin+31]
for (int i = 0; i < to; i++) {
// Each row gets multiplied by the same set of per-channel scales
__m512 c_lo = _mm512_load_ps(c + i * N_STEP);
__m512 c_hi = _mm512_load_ps(c + i * N_STEP + TILE_N);
_mm512_store_ps(c + i * N_STEP, _mm512_mul_ps(c_lo, bs_lo));
_mm512_store_ps(c + i * N_STEP + TILE_N, _mm512_mul_ps(c_hi, bs_hi));
}
}
// AVX kernel for per-channel FP8 GEMM - processes entire K dimension
static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
BufferB* bb) {
const __m512i bf16_hi_0 = GemmKernel224FP8::bf16_hi_0_mask();
const __m512i bf16_hi_1 = GemmKernel224FP8::bf16_hi_1_mask();
const __m512i bf16_lo_0 = GemmKernel224FP8::bf16_lo_0_mask();
const __m512i bf16_lo_1 = GemmKernel224FP8::bf16_lo_1_mask();
const __m512i sign_mask_v = GemmKernel224FP8::sign_mask();
__m512* c512 = (__m512*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
// Zero out accumulator at start of K_BLOCK
if (k_block_begin == 0) {
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_ps();
c512[m_i * 2 + 1] = _mm512_setzero_ps();
}
}
// Process K_BLOCK
for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {
ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
__m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);
// Process 4 k_i at once
for (int k_i = 0; k_i < 16; k_i += 4) {
// Load 4 B vectors
__m512i bfp8_0 = bfp8_512[k_i];
__m512i bfp8_1 = bfp8_512[k_i + 1];
__m512i bfp8_2 = bfp8_512[k_i + 2];
__m512i bfp8_3 = bfp8_512[k_i + 3];
// Convert all 4 FP8 -> BF16
__m512i b_hi, b_lo;
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1);
__m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1);
__m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1);
__m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1);
__m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
// Process m rows
int m_i = 0;
for (; m_i + 1 < m_block_end; m_i += 2) {
__m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
__m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
__m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
__m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);
__m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]);
__m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]);
__m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]);
__m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi);
}
// Handle remaining row
for (; m_i < m_block_end; m_i++) {
__m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
__m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
__m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
__m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi);
}
}
}
}
};
/**
* @brief Per-channel FP8 GEMM function
*
* Unlike block-wise FP8 which applies scale per 128x128 block during computation,
* per-channel FP8 processes entire K dimension first, then applies per-channel scale at the end.
*/
template <typename K>
void float_mat_vec_perchannel(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb,
typename K::BufferC* bc, int ith, int nth) {
assert(n % K::N_STEP == 0);
assert(k % K::K_STEP == 0);
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
float* c = bc->get_submat(m, n, m_begin, n_begin);
// Process entire K dimension with K_BLOCKs
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {
K::avx_kernel_4(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);
}
// Apply per-channel scale once after all K is processed
K::apply_scale_perchannel(m, n, m_begin, n_begin, c, bb);
}
}
}
inline void vec_mul_perchannel(int m, int n, int k, std::shared_ptr<GemmKernel224FP8PerChannel::BufferA> ba,
std::shared_ptr<GemmKernel224FP8PerChannel::BufferB> bb,
std::shared_ptr<GemmKernel224FP8PerChannel::BufferC> bc, int ith, int nth) {
float_mat_vec_perchannel<GemmKernel224FP8PerChannel>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);
}
} // namespace amx
#endif // AMX_RAW_KERNELS_HPP

View File

@@ -223,7 +223,8 @@ struct QuantConfig {
std::string quant_method = "";
int bits = 0;
int group_size = 0;
bool zero_point;
bool zero_point = false;
bool per_channel = false; // Per-channel quantization (GLM-4.7-FP8 style)
};
struct GeneralMOEConfig {

View File

@@ -15,6 +15,14 @@ except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None
try:
from kt_kernel_ext.moe import AMXFP8PerChannel_MOE
_HAS_FP8_PERCHANNEL_SUPPORT = True
except (ImportError, AttributeError):
_HAS_FP8_PERCHANNEL_SUPPORT = False
AMXFP8PerChannel_MOE = None
from typing import Optional
@@ -304,7 +312,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
class NativeMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format."""
"""Wrapper for RAWINT4/FP8/FP8_PERCHANNEL/BF16 experts stored in compressed SafeTensor format."""
_native_loader_instance = None
@@ -330,6 +338,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
if method == "FP8" and AMXFP8_MOE is None:
raise RuntimeError("AMX backend with FP8 support is not available.")
if method == "FP8_PERCHANNEL" and not _HAS_FP8_PERCHANNEL_SUPPORT:
raise RuntimeError("AMX backend with FP8 per-channel support is not available.")
if method == "BF16" and AMXBF16_MOE is None:
raise RuntimeError("AMX backend with BF16 support is not available.")
@@ -354,6 +364,9 @@ class NativeMoEWrapper(BaseMoEWrapper):
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
elif method == "FP8":
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
elif method == "FP8_PERCHANNEL":
# Use FP8SafeTensorLoader with per-channel scale format
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix="weight_scale")
elif method == "BF16":
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
else:
@@ -408,6 +421,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
elif self.method == "FP8":
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
elif self.method == "FP8_PERCHANNEL":
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8_PERCHANNEL"
t2 = time.time()
@@ -462,6 +477,11 @@ class NativeMoEWrapper(BaseMoEWrapper):
moe_config.quant_config.group_size = 128
moe_config.quant_config.zero_point = False
self.moe = AMXFP8_MOE(moe_config)
elif self.method == "FP8_PERCHANNEL":
moe_config.quant_config.bits = 8
moe_config.quant_config.per_channel = True
moe_config.quant_config.zero_point = False
self.moe = AMXFP8PerChannel_MOE(moe_config)
elif self.method == "BF16":
# BF16 has no quantization config needed
self.moe = AMXBF16_MOE(moe_config)

View File

@@ -244,6 +244,10 @@ class FP8SafeTensorLoader(SafeTensorLoader):
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
Supported scale formats (auto-detected):
- Block-wise: weight_scale_inv (DeepSeek FP8)
- Per-channel: weight_scale (GLM-4.7-FP8)
The format is auto-detected during initialization.
"""
@@ -253,13 +257,28 @@ class FP8SafeTensorLoader(SafeTensorLoader):
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
def __init__(self, file_path: str, scale_suffix: str = None):
"""Initialize FP8 loader with optional scale suffix override.
Args:
file_path: Path to safetensor files
scale_suffix: Optional scale key suffix. If None, auto-detect between
'weight_scale_inv' (block-wise) and 'weight_scale' (per-channel).
"""
super().__init__(file_path)
self._detected_format = None
self._scale_suffix = scale_suffix # None means auto-detect
# Set per_channel based on explicit scale_suffix if provided
if scale_suffix == "weight_scale":
self._is_per_channel = True
elif scale_suffix == "weight_scale_inv":
self._is_per_channel = False
else:
self._is_per_channel = False # Will be updated in _detect_format if auto-detect
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
"""Auto-detect the MoE naming format and scale format by checking tensor keys."""
# Sample some tensor names to detect format
sample_keys = list(self.tensor_file_map.keys())[:1000]
@@ -272,15 +291,42 @@ class FP8SafeTensorLoader(SafeTensorLoader):
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
break
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
break
if self._detected_format:
break
# Default to deepseek if no format detected
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
if not self._detected_format:
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
# Auto-detect scale suffix if not specified
if self._scale_suffix is None:
_, gate, _, _ = self.MOE_FORMATS[self._detected_format]
# Check for per-channel scale (weight_scale) vs block-wise (weight_scale_inv)
for key in sample_keys:
if f".{gate}.weight_scale_inv" in key:
self._scale_suffix = "weight_scale_inv"
self._is_per_channel = False
print("[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)")
return
elif f".{gate}.weight_scale" in key and "weight_scale_inv" not in key:
self._scale_suffix = "weight_scale"
self._is_per_channel = True
print("[FP8SafeTensorLoader] Detected scale format: per-channel (weight_scale)")
return
# Default to weight_scale_inv
self._scale_suffix = "weight_scale_inv"
self._is_per_channel = False
print("[FP8SafeTensorLoader] No scale format detected, defaulting to: weight_scale_inv")
else:
# Scale suffix was explicitly provided
scale_type = "per-channel" if self._is_per_channel else "block-wise"
print(f"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
@@ -305,7 +351,11 @@ class FP8SafeTensorLoader(SafeTensorLoader):
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load FP8 expert weights and their block-wise scale_inv tensors."""
"""Load FP8 expert weights and their scale tensors.
Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.
Per-channel scales are squeezed from [N, 1] to [N] if needed.
"""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
@@ -327,16 +377,30 @@ class FP8SafeTensorLoader(SafeTensorLoader):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.{self._scale_suffix}"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.{self._scale_suffix}"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.{self._scale_suffix}"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()
gate_scale = self.load_tensor(gate_s_key, device)
up_scale = self.load_tensor(up_s_key, device)
down_scale = self.load_tensor(down_s_key, device)
# For per-channel scales, squeeze [N, 1] -> [N] if needed
if self._is_per_channel:
if gate_scale.dim() == 2 and gate_scale.shape[1] == 1:
gate_scale = gate_scale.squeeze(1)
if up_scale.dim() == 2 and up_scale.shape[1] == 1:
up_scale = up_scale.squeeze(1)
if down_scale.dim() == 2 and down_scale.shape[1] == 1:
down_scale = down_scale.squeeze(1)
gate_scales[exp_id] = gate_scale.contiguous()
up_scales[exp_id] = up_scale.contiguous()
down_scales[exp_id] = down_scale.contiguous()
return {
"gate": gate_weights,
@@ -347,6 +411,103 @@ class FP8SafeTensorLoader(SafeTensorLoader):
"down_scale": down_scales,
}
def is_per_channel(self) -> bool:
"""Return True if using per-channel quantization, False for block-wise."""
return self._is_per_channel
class BF16SafeTensorLoader(SafeTensorLoader):
"""Loader for native BF16 expert weights (no quantization, no scales).
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
The format is auto-detected during initialization.
"""
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
super().__init__(file_path)
self._detected_format = None
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
sample_keys = list(self.tensor_file_map.keys())[:1000]
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
for key in sample_keys:
if ".experts." in key and f".{gate}.weight" in key:
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[BF16SafeTensorLoader] Detected format: {fmt_name}")
return
self._detected_format = "deepseek"
print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
return gate, up, down
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if device == "cpu":
return tensor
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load BF16 expert weights (no scales needed)."""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
for exp_id in range(expert_count):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
}
class BF16SafeTensorLoader(SafeTensorLoader):
"""Loader for native BF16 expert weights (no quantization, no scales).