# File: ktransformers/kt-kernel/examples/test_fp8_perchannel_moe.py
# Commit 6277da4c2b: support GLM 4.7 (#1791), 2026-01-13 17:36:25 +08:00
# 409 lines, 14 KiB, Python
"""
Test script for FP8 Per-Channel MoE kernel validation (GLM-4.7-FP8 style).
This script:
1. Generates random BF16 weights
2. Quantizes them to FP8 format with per-channel scales (one scale per output channel)
3. Runs the FP8 Per-Channel MoE kernel
4. Compares results with PyTorch reference using dequantized BF16 weights
FP8 Per-Channel format notes:
- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]
- Scale: FP32, shape [expert_num, n] (one scale per output row)
"""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
from kt_kernel import kt_kernel_ext
torch.manual_seed(42)
# Model config
hidden_size = 3072
intermediate_size = 1536
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 100
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 1
debug_print_count = 16
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def act_fn(x):
    """SiLU activation function: x * sigmoid(x) == x / (1 + exp(-x)).

    Uses torch's built-in SiLU rather than the hand-rolled
    ``x / (1 + exp(-x))`` form; the sigmoid-based implementation avoids
    overflow in ``exp(-x)`` for large-magnitude negative inputs and is the
    idiomatic PyTorch spelling.
    """
    return torch.nn.functional.silu(x)
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Reference gated-MLP computation in PyTorch.

    Computes down_proj @ (act(input @ gate_proj^T) * (input @ up_proj^T)),
    i.e. the standard gate/up/down expert FFN, entirely with dense matmuls.
    """
    gate_out = input @ gate_proj.t()
    up_out = input @ up_proj.t()
    hidden = act_fn(gate_out) * up_out
    return hidden @ down_proj.t()
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference MoE computation in PyTorch.

    Args:
        input: [qlen, hidden_size] token activations.
        expert_ids: [qlen, top_k] int tensor of routed expert indices.
        weights: [qlen, top_k] routing weight per selected expert.
        gate_proj/up_proj/down_proj: [expert_num, ...] per-expert weights.
    Returns:
        [qlen, hidden_size] weighted sum over each token's top_k experts.
    """
    # Count how many (token, slot) assignments go to each expert.
    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    # Sort flattened assignments by expert id so that every expert's tokens
    # are contiguous; idxs // top_k recovers the source token index.
    idxs = expert_ids.view(-1).argsort()
    sorted_tokens = input[idxs // expert_ids.shape[1]]
    outputs = []
    start_idx = 0
    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            # No tokens routed here; end_idx == start_idx so the running
            # offset is still correct without updating it.
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx
    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
    # Invert the sort: scatter each expert output back to its original slot.
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    # Weight each slot's output by its routing weight and sum over top_k.
    t_output = (
        new_x.view(*expert_ids.shape, -1)
        .type(weights.dtype)
        .mul_(weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    return t_output
# FP8 E4M3 constants
FP8_E4M3_MAX = 448.0  # Maximum finite magnitude in FP8 E4M3 (E4M3FN code 0x7E)
def fp8_e4m3_to_float(fp8_val: int) -> float:
    """
    Convert an FP8 E4M3FN byte to a Python float.

    FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits (bias 7).
    This follows the "FN" (finite) variant used by torch.float8_e4m3fn and
    FP8 checkpoints: there is no Inf, and only the all-ones encoding
    (S.1111.111, i.e. 0x7F / 0xFF) is NaN.  Codes with exp == 15 and
    mant < 7 are ordinary finite values (256 .. 448).  The previous version
    wrongly decoded every exp == 15 code as NaN, which mis-handles 0x78-0x7E.
    """
    sign = (fp8_val >> 7) & 1
    exp = (fp8_val >> 3) & 0xF
    mant = fp8_val & 0x7
    if exp == 0:
        # Subnormal or zero: value = (-1)^sign * 2^(-6) * (mant / 8)
        if mant == 0:
            return -0.0 if sign else 0.0
        return ((-1) ** sign) * (2**-6) * (mant / 8.0)
    elif exp == 15 and mant == 7:
        # E4M3FN reserves only S.1111.111 for NaN (there is no Inf encoding).
        return float("nan")
    else:
        # Normal: value = (-1)^sign * 2^(exp-7) * (1 + mant/8)
        return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)
def float_to_fp8_e4m3(val: float) -> int:
    """
    Convert a float to the nearest FP8 E4M3FN byte.

    Follows torch.float8_e4m3fn semantics: no Inf, only 0x7F/0xFF are NaN,
    largest finite magnitude 448 (code 0x7E).  Values above 448 saturate;
    NaN maps to 0x7F.

    Fixes two defects of the previous version:
    - the biased exponent was clamped to <= 14, capping the output at 240
      even though the input was clamped to 448 (E4M3FN allows exp=15 with
      mant <= 6);
    - the mantissa was clamped to 7 BEFORE the overflow-carry check, making
      the carry branch dead code, so e.g. 1.99 encoded as 1.875 instead of
      rounding up to 2.0.

    Args:
        val: input value.
    Returns:
        8-bit E4M3FN encoding as an int in [0, 255].
    """
    import math

    if val != val:  # NaN
        return 0x7F  # canonical NaN encoding
    sign = 1 if val < 0 else 0
    mag = min(abs(val), 448.0)  # saturate at FP8_E4M3_MAX
    if mag == 0:
        return sign << 7
    if mag < 2**-6:
        # Subnormal range: value = mant * 2^-9, mant in [1, 7].
        mant = int(round(mag / (2**-9)))
        if mant > 7:
            # Rounded up into the smallest normal (exp=1, mant=0).
            return (sign << 7) | (1 << 3)
        return (sign << 7) | mant
    # Normal range: biased exponent (bias 7) and 3-bit mantissa.
    exp = int(math.floor(math.log2(mag))) + 7
    exp = max(1, min(exp, 15))
    mant = int(round((mag / (2.0 ** (exp - 7)) - 1.0) * 8))
    if mant > 7:
        # Mantissa rounded past 1.875: carry into the next exponent.
        mant = 0
        exp += 1
    if exp > 15 or (exp == 15 and mant > 6):
        # Clamp to the largest finite value 448 (0x7E); 0x7F is NaN.
        exp, mant = 15, 6
    return (sign << 7) | (exp << 3) | mant
def quantize_to_fp8_perchannel(weights: torch.Tensor):
    """
    Quantize BF16/FP32 weights to FP8 with per-channel scaling.
    Args:
        weights: [expert_num, n, k] tensor in BF16/FP32
    Returns:
        fp8_weights: [expert_num, n, k] uint8 tensor of raw FP8 E4M3 bytes
        scales: [expert_num, n] FP32 tensor (one scale per output row)
    """
    weights_f32 = weights.to(torch.float32)
    e, n, k = weights_f32.shape
    # Calculate max abs per row (per output channel)
    max_abs = weights_f32.abs().amax(dim=-1, keepdim=True)  # [e, n, 1]
    # Floor at 1e-12 so all-zero rows don't divide by zero below.
    max_abs = torch.clamp(max_abs, min=1e-12)
    # Scale to FP8 range: scale = max_abs / FP8_MAX, so each row's largest
    # magnitude maps to exactly FP8_E4M3_MAX after dividing by the scale.
    scales = (max_abs / FP8_E4M3_MAX).squeeze(-1)  # [e, n]
    # Quantize: q = round(val / scale)
    scaled = weights_f32 / (scales.unsqueeze(-1) + 1e-12)
    # Clamp to FP8 representable range
    scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    # Vectorized FP8 quantization (manual E4M3 bit packing).
    fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)
    sign_mask = (scaled < 0).to(torch.uint8) << 7
    abs_scaled = scaled.abs()
    # Handle different ranges
    # Subnormal: 0 < |x| < 2^-6, encoded as exp=0 with value mant * 2^-9.
    subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)
    subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)
    # Normal values: biased exponent (bias 7) plus a 3-bit mantissa.
    normal_mask = abs_scaled >= 2**-6
    log2_val = torch.log2(abs_scaled.clamp(min=2**-9))
    # NOTE(review): clamping exp to <= 14 caps encodable magnitudes at 240
    # (exp=14, mant=7), yet `scaled` reaches 448 — so each row's largest
    # element quantizes to 240 instead of 448.  E4M3FN allows exp=15 with
    # mant <= 6 (max 448, code 0x7E).  Fixing this must be coordinated with
    # fp8_e4m3_to_float / the dequantize LUT so that exp=15 codes are not
    # decoded as NaN — TODO confirm against the kernel's decoder.
    exp = (log2_val.floor() + 7).clamp(1, 14).to(torch.int32)
    mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8)
    # Combine subnormal and normal encodings via their masks.
    fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)
    fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)
    return fp8_q.contiguous(), scales.to(torch.float32).contiguous()
def dequantize_fp8_perchannel(fp8_weights: torch.Tensor, scales: torch.Tensor):
"""
Dequantize FP8 weights back to BF16 for reference computation.
Args:
fp8_weights: [expert_num, n, k] uint8 tensor
scales: [expert_num, n] FP32 tensor
Returns:
dequantized: [expert_num, n, k] BF16 tensor
"""
# Build lookup table for FP8 E4M3 -> float
fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)
# Use lookup table
fp8_float = fp8_lut[fp8_weights.to(torch.int64)]
# Apply per-channel scales
scales_expanded = scales.unsqueeze(-1) # [e, n, 1]
dequantized = fp8_float * scales_expanded
return dequantized.to(torch.bfloat16).contiguous()
def build_random_fp8_perchannel_weights():
    """
    Generate random BF16 weights and quantize to FP8 with per-channel scales.
    Returns:
        dict with fp8 weights, scales, and dequantized bf16 for reference
    """
    # Fixed seed so the generated weights (and the whole test) are reproducible.
    torch.manual_seed(42)
    # Generate random BF16 weights with small values (randn / 100) to keep
    # activations in a numerically benign range.
    gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )
    up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )
    down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )
    # Quantize to FP8 with per-channel scales
    gate_fp8, gate_scales = quantize_to_fp8_perchannel(gate_proj)
    up_fp8, up_scales = quantize_to_fp8_perchannel(up_proj)
    down_fp8, down_scales = quantize_to_fp8_perchannel(down_proj)
    # Dequantize for reference computation: the PyTorch reference MoE uses
    # these BF16 weights so the comparison isolates kernel error from
    # quantization error.
    gate_deq = dequantize_fp8_perchannel(gate_fp8, gate_scales)
    up_deq = dequantize_fp8_perchannel(up_fp8, up_scales)
    down_deq = dequantize_fp8_perchannel(down_fp8, down_scales)
    print(f"FP8 Per-Channel weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}")
    print(f"Per-Channel scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}")
    # Debug: Print FP8 weight and scale info for expert 0
    print("\n=== DEBUG: FP8 Per-Channel Weight and Scale Info (Expert 0) ===")
    print(f"gate_fp8[0] first 8x8 block:")
    for i in range(8):
        # Raw FP8 bytes shown as hex, one row per output channel.
        print(f" row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
    print(f"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}")
    print(f"gate_scales[0] first 8 channels: {gate_scales[0, :8]}")
    print(f"gate_scales[0] stats: min={gate_scales[0].min():.6f}, max={gate_scales[0].max():.6f}")
    # All tensors are made contiguous because the kernel consumes raw
    # data_ptr() buffers.
    return {
        "gate_fp8": gate_fp8.contiguous(),
        "up_fp8": up_fp8.contiguous(),
        "down_fp8": down_fp8.contiguous(),
        "gate_scales": gate_scales.contiguous(),
        "up_scales": up_scales.contiguous(),
        "down_scales": down_scales.contiguous(),
        "gate_deq": gate_deq.contiguous(),
        "up_deq": up_deq.contiguous(),
        "down_deq": down_deq.contiguous(),
    }
def build_moes_from_fp8_perchannel_data(fp8_data: dict):
    """
    Build FP8 Per-Channel MoE modules from quantized data.

    Args:
        fp8_data: dict from build_random_fp8_perchannel_weights() holding
            contiguous FP8 weight tensors and per-channel scale tensors.
    Returns:
        List of layer_num AMXFP8PerChannel_MOE modules with weights loaded.
    """
    modules = []
    with torch.inference_mode(mode=True):
        for _layer in range(layer_num):
            cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            cfg.max_len = max_len
            # Per-channel FP8: 8 bits, no group quantization, symmetric
            # (no zero point), per-output-channel scales.
            cfg.quant_config.bits = 8
            cfg.quant_config.group_size = 0  # Not used for per-channel
            cfg.quant_config.zero_point = False
            cfg.quant_config.per_channel = True  # Enable per-channel mode
            # Raw FP8 (uint8 E4M3) weight buffers.
            cfg.gate_proj = fp8_data["gate_fp8"].data_ptr()
            cfg.up_proj = fp8_data["up_fp8"].data_ptr()
            cfg.down_proj = fp8_data["down_fp8"].data_ptr()
            # Per-channel FP32 scale buffers.
            cfg.gate_scale = fp8_data["gate_scales"].data_ptr()
            cfg.up_scale = fp8_data["up_scales"].data_ptr()
            cfg.down_scale = fp8_data["down_scales"].data_ptr()
            cfg.pool = CPUInfer.backend_
            module = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)
            # Load weights synchronously before building the next layer.
            CPUInfer.submit(module.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            modules.append(module)
    return modules
def run_fp8_perchannel_moe_test():
    """
    Run FP8 Per-Channel MoE validation test.

    Quantizes random weights, runs the FP8 per-channel kernel, and compares
    against a PyTorch reference built from the dequantized BF16 weights so
    the measured error reflects the kernel, not quantization.  Prints
    per-iteration relative L1 error and a PASS/FAIL verdict.

    Returns:
        dict with "mean"/"max"/"min" relative L1 diffs over all iterations.
    """
    print("\n" + "=" * 70)
    print("FP8 Per-Channel MoE Kernel Validation Test")
    print("=" * 70)
    # Build FP8 per-channel weights
    print("\nGenerating and quantizing weights with per-channel scales...")
    fp8_data = build_random_fp8_perchannel_weights()
    # Build MoE modules
    print("\nBuilding FP8 Per-Channel MoE modules...")
    moes = build_moes_from_fp8_perchannel_data(fp8_data)
    # Get dequantized weights for reference
    gate_deq = fp8_data["gate_deq"]
    up_deq = fp8_data["up_deq"]
    down_deq = fp8_data["down_deq"]
    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            # Per-iteration seed so each run exercises different routing/input.
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            # randperm guarantees each token gets distinct expert ids.
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            moe = moes[i % layer_num]
            # Submit the forward task (kernel consumes raw pointers) and wait.
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
            assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
            assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."
            # Reference computation using dequantized weights
            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)
            t_output_flat = t_output.flatten()
            output_flat = output.flatten()
            # Relative L1 error: mean |kernel - ref| / mean |ref| (epsilon
            # guards against a zero-magnitude reference).
            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
            diffs.append(diff.item())
            print(f"Iteration {i}: relative L1 diff = {diff:.6f}")
            if i < 3:  # Print detailed output for first few iterations
                print(f" kernel output: {output_flat[:debug_print_count]}")
                print(f" torch output: {t_output_flat[:debug_print_count]}")
    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))
    print("\n" + "=" * 70)
    print("FP8 Per-Channel MoE Test Results")
    print("=" * 70)
    print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
    print(f"Max relative L1 diff: {max_diff*100:.4f}%")
    print(f"Min relative L1 diff: {min_diff*100:.4f}%")
    # Pass/Fail criteria
    threshold = 15.0  # 15% relative error threshold for FP8
    if mean_diff * 100 < threshold:
        print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
    else:
        print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")
    return {"mean": mean_diff, "max": max_diff, "min": min_diff}
if __name__ == "__main__":
run_fp8_perchannel_moe_test()