# Files
# ktransformers/kt-kernel/examples/test_fp8_moe.py
# ErvinXie d8046e1bb4 Kt minimax (#1742)
# [feat]: fp8 kernel and kt-cli support
# 2025-12-24 15:39:44 +08:00
#
# 458 lines
# 16 KiB
# Python

"""
Test script for GemmKernel224FP8 (FP8 MoE) kernel validation.
This script:
1. Generates random BF16 weights
2. Quantizes them to FP8 format with 128x128 block-wise scales
3. Runs the FP8 MoE kernel
4. Compares results with PyTorch reference using dequantized BF16 weights
FP8 format notes:
- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]
- Scale: FP32, shape [expert_num, n // group_size, k // group_size], group_size=128
"""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
import kt_kernel
# The script refers to the compiled extension as ``kt_kernel_ext`` (CPUInfer,
# moe.MOEConfig, moe.AMXFP8_MOE). A bare ``import kt_kernel`` does not bind
# that name, so import it explicitly to avoid a NameError at start-up.
from kt_kernel import kt_kernel_ext

torch.manual_seed(42)

# Model config
hidden_size = 3072
intermediate_size = 1536
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 100
layer_num = 1
# NOTE(review): 40 is presumably the worker-thread count of the CPU backend -- confirm.
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 1
fp8_group_size = 128  # FP8 uses 128x128 block quantization
debug_print_count = 16
# Identity mapping 0..expert_num-1, handed to load_weights_task() by address.
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def act_fn(x):
    """Element-wise SiLU, evaluated as ``x / (1 + exp(-x))``."""
    denominator = torch.exp(-x) + 1.0
    return x / denominator
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Reference gated MLP for a single expert.

    Computes ``(act_fn(input @ gate_proj.T) * (input @ up_proj.T)) @ down_proj.T``
    using 2-D matrix products, so every argument must be a 2-D tensor.
    """
    gate_out = input.mm(gate_proj.t())
    up_out = input.mm(up_proj.t())
    hidden = act_fn(gate_out) * up_out
    return hidden.mm(down_proj.t())
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference MoE forward pass in PyTorch.

    Tokens are grouped by the expert they were routed to, each group is pushed
    through that expert's MLP, and the per-slot results are scattered back and
    combined with the routing ``weights``.

    Args:
        input: token activations, indexed by row.
        expert_ids: 2-D integer tensor; row ``t`` lists the experts chosen for token ``t``.
        weights: routing weights with the same shape as ``expert_ids``.
        gate_proj, up_proj, down_proj: per-expert weight stacks indexed by expert id.

    Returns:
        One combined output row per token, in the dtype produced by the expert MLPs.
    """
    token_count, slots_per_token = expert_ids.shape

    # Mark which experts each token uses, then count tokens per expert.
    routing_mask = expert_ids.new_zeros((token_count, expert_num))
    routing_mask.scatter_(1, expert_ids, 1)
    tokens_per_expert = routing_mask.sum(dim=0)

    # Order the flattened (token, slot) pairs by expert id; integer division by
    # the slot count recovers the token row each pair came from.
    order = expert_ids.view(-1).argsort()
    sorted_tokens = input[order // slots_per_token]

    expert_outputs = []
    cursor = 0
    for expert_idx, count in enumerate(tokens_per_expert):
        if count == 0:
            continue
        stop = cursor + count
        chunk = sorted_tokens[cursor:stop]
        expert_outputs.append(mlp_torch(chunk, gate_proj[expert_idx], up_proj[expert_idx], down_proj[expert_idx]))
        cursor = stop

    if len(expert_outputs):
        outs = torch.cat(expert_outputs, dim=0)
    else:
        outs = sorted_tokens.new_empty(0)

    # Undo the sort so row j again belongs to flattened (token, slot) pair j.
    unsorted = torch.empty_like(outs)
    unsorted[order] = outs

    weighted = unsorted.view(*expert_ids.shape, -1).type(weights.dtype)
    weighted.mul_(weights.unsqueeze(dim=-1))
    return weighted.sum(dim=1).type(unsorted.dtype)
# FP8 E4M3 constants.
# Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits. The format has
# no infinities and only the all-ones exponent+mantissa pattern is NaN, so
# exponent field 15 with mantissa 0..6 encodes the finite values 256..448.
FP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3 (code 0x7E)
_FP8_E4M3_MAX_CODE = 0x7E  # Magnitude bits of FP8_E4M3_MAX
_FP8_E4M3_MIN_EXP = -6  # Unbiased exponent of the smallest normal; subnormals share it
_FP8_E4M3_MAX_EXP = 8  # Unbiased exponent of the largest finite values


def fp8_e4m3_to_float(fp8_val: int) -> float:
    """
    Convert an FP8 E4M3 code (0..255) to a Python float.

    Only the codes 0x7F and 0xFF decode to NaN; every other code, including
    those with exponent field 15, is a finite value.
    """
    sign = -1.0 if (fp8_val >> 7) & 1 else 1.0
    exp = (fp8_val >> 3) & 0xF
    mant = fp8_val & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        # Zero / subnormal: (-1)^sign * 2^-6 * (0.mant); mant == 0 yields +/-0.0
        return sign * (mant / 8.0) * 2.0 ** -6
    # Normal: (-1)^sign * 2^(exp-7) * (1.mant)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def float_to_fp8_e4m3(val: float) -> int:
    """
    Convert a float to the nearest FP8 E4M3 code (ties to even).

    Magnitudes above FP8_E4M3_MAX saturate to +/-448 (0x7E / 0xFE); NaN maps to
    0x7F. A negative value that rounds to zero keeps its sign bit, while an
    input of exactly -0.0 is encoded as +0 (the sign test is ``val < 0``).
    """
    import math

    if val != val:  # NaN
        return 0x7F
    sign_bit = 0x80 if val < 0 else 0
    magnitude = min(abs(val), FP8_E4M3_MAX)
    if magnitude == 0:
        return sign_bit
    # Subnormals use the exponent of the smallest normal, hence the lower clamp.
    exp = min(max(math.floor(math.log2(magnitude)), _FP8_E4M3_MIN_EXP), _FP8_E4M3_MAX_EXP)
    # Number of mantissa steps (2^(exp-3) each): 0..8 for subnormals, 8..16 for normals.
    steps = round(magnitude / 2.0 ** (exp - 3))
    # Adding (not OR-ing) lets a mantissa that rounded up to the next power of
    # two carry into the exponent field, and turns subnormal step 8 into 0x08.
    code = ((exp - _FP8_E4M3_MIN_EXP) << 3) + steps
    return sign_bit | min(code, _FP8_E4M3_MAX_CODE)


def _encode_fp8_e4m3(values: torch.Tensor) -> torch.Tensor:
    """Vectorized float_to_fp8_e4m3 for a tensor of finite floats; returns uint8 codes."""
    magnitude = values.abs().clamp(max=FP8_E4M3_MAX)
    # log2(0) is -inf; the clamp folds zeros and subnormals onto exponent -6.
    exp = torch.log2(magnitude).floor().clamp(_FP8_E4M3_MIN_EXP, _FP8_E4M3_MAX_EXP)
    steps = (magnitude / 2.0 ** (exp - 3)).round()
    # Same carry-by-addition scheme as the scalar encoder.
    code = ((exp - _FP8_E4M3_MIN_EXP) * 8 + steps).clamp(max=_FP8_E4M3_MAX_CODE).to(torch.uint8)
    sign_bit = (values < 0).to(torch.uint8) << 7
    return sign_bit | code


def quantize_to_fp8_blockwise(weights: torch.Tensor, group_size: int = 128):
    """
    Quantize BF16/FP32 weights to FP8 E4M3 with block-wise scaling.

    Each group_size x group_size block is scaled so that its largest magnitude
    maps to FP8_E4M3_MAX and is then rounded to the nearest FP8 code.

    Args:
        weights: [expert_num, n, k] tensor in BF16/FP32 (finite values)
        group_size: Block size for quantization (default 128)
    Returns:
        fp8_weights: [expert_num, n, k] uint8 tensor of FP8 E4M3 codes
        scales: [expert_num, n // group_size, k // group_size] FP32 tensor;
            dequantization is ``decode(fp8) * scale``
    """
    weights_f32 = weights.to(torch.float32)
    e, n, k = weights_f32.shape
    assert n % group_size == 0, f"n ({n}) must be divisible by group_size ({group_size})"
    assert k % group_size == 0, f"k ({k}) must be divisible by group_size ({group_size})"
    n_blocks = n // group_size
    k_blocks = k // group_size

    # [e, n_blocks, group_size, k_blocks, group_size] -> [e, n_blocks, k_blocks, group_size, group_size]
    blocks = weights_f32.reshape(e, n_blocks, group_size, k_blocks, group_size).permute(0, 1, 3, 2, 4)

    # Per-block scale; the floor keeps all-zero blocks from dividing by zero.
    max_abs = blocks.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=1e-12)
    block_scale = max_abs / FP8_E4M3_MAX  # [e, n_blocks, k_blocks, 1, 1]

    # Divide by exactly the scale that is returned, so quantization and
    # dequantization use the same factor.
    scaled = (blocks / block_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    codes = _encode_fp8_e4m3(scaled)

    # Reshape back to [e, n, k]
    fp8_q = codes.permute(0, 1, 3, 2, 4).reshape(e, n, k)
    scales_fp32 = block_scale.squeeze(-1).squeeze(-1).to(torch.float32).contiguous()
    return fp8_q.contiguous(), scales_fp32
def dequantize_fp8_blockwise(fp8_weights: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
    """
    Expand block-scaled FP8 codes back to BF16 for the reference computation.

    Args:
        fp8_weights: [expert_num, n, k] uint8 tensor of FP8 E4M3 codes
        scales: [expert_num, n // group_size, k // group_size] tensor of per-block
            scales (cast to FP32 before use)
        group_size: Block size
    Returns:
        dequantized: [expert_num, n, k] contiguous BF16 tensor
    """
    e, n, k = fp8_weights.shape
    n_blocks, k_blocks = n // group_size, k // group_size

    # 256-entry decode table: code -> float value
    decode_table = torch.tensor(list(map(fp8_e4m3_to_float, range(256))), dtype=torch.float32)
    decoded = decode_table[fp8_weights.to(torch.int64)]

    # Split n and k into (block, offset) pairs and let the per-block scale
    # broadcast across both offset axes.
    tiled = decoded.view(e, n_blocks, group_size, k_blocks, group_size)
    block_scales = scales.to(torch.float32)[:, :, None, :, None]
    restored = (tiled * block_scales).reshape(e, n, k)
    return restored.to(torch.bfloat16).contiguous()
def build_random_fp8_weights():
    """
    Create seeded random expert weights, quantize them to FP8 and dequantize them again.

    Prints the tensor shapes plus a debug dump of expert 0 for each projection.

    Returns:
        dict with keys ``{gate,up,down}_fp8`` (uint8 codes), ``{gate,up,down}_scales``
        (per-block scales) and ``{gate,up,down}_deq`` (BF16 dequantized reference
        weights); every tensor is contiguous.
    """
    torch.manual_seed(42)

    # The draw order gate -> up -> down determines the values produced by the seed.
    shapes = {
        "gate": (expert_num, intermediate_size, hidden_size),
        "up": (expert_num, intermediate_size, hidden_size),
        "down": (expert_num, hidden_size, intermediate_size),
    }

    fp8, scales, deq = {}, {}, {}
    for name, shape in shapes.items():
        # Small-magnitude BF16 weights
        bf16_weight = (torch.randn(shape, dtype=torch.float32) / 100.0).to(torch.bfloat16)
        fp8[name], scales[name] = quantize_to_fp8_blockwise(bf16_weight, fp8_group_size)
        # Dequantized copy for the reference computation
        deq[name] = dequantize_fp8_blockwise(fp8[name], scales[name], fp8_group_size)

    print(f"FP8 weights shape: gate={fp8['gate'].shape}, up={fp8['up'].shape}, down={fp8['down'].shape}")
    print(f"Scales shape: gate={scales['gate'].shape}, up={scales['up'].shape}, down={scales['down'].shape}")

    # Debug dump of expert 0: raw code bytes and scale values per projection
    print("\n=== DEBUG: FP8 Weight and Scale Info (Expert 0) ===")
    for position, name in enumerate(shapes):
        separator = "\n" if position else ""
        codes0 = fp8[name][0]
        scales0 = scales[name][0]
        print(f"{separator}{name}_fp8[0] first 8x8 block:")
        for row in range(8):
            print(f" row {row}: {codes0[row, :8].numpy().tobytes().hex(' ')}")
        print(f"{name}_fp8[0] stats: min={codes0.min()}, max={codes0.max()}")
        print(f"{name}_scales[0] first 4x4 block:\n{scales0[:4, :4]}")
        print(f"{name}_scales[0] stats: min={scales0.min()}, max={scales0.max()}")

    packed = {}
    for suffix, group in (("fp8", fp8), ("scales", scales), ("deq", deq)):
        for name in shapes:
            packed[f"{name}_{suffix}"] = group[name].contiguous()
    return packed
def build_moes_from_fp8_data(fp8_data: dict):
    """
    Create ``layer_num`` AMXFP8_MOE modules from quantized weights and load them.

    Args:
        fp8_data: dict providing ``{gate,up,down}_fp8`` and ``{gate,up,down}_scales`` tensors.
    Returns:
        list of MoE modules whose weight-loading task has completed.

    NOTE(review): the config receives raw ``data_ptr()`` addresses; presumably the
    tensors in ``fp8_data`` must outlive the modules -- confirm against the extension.
    """
    layers = []
    with torch.inference_mode(mode=True):
        for _layer in range(layer_num):
            # NOTE(review): meaning of the trailing 0 argument is not visible here -- confirm.
            cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            cfg.max_len = max_len
            cfg.quant_config.bits = 8
            cfg.quant_config.group_size = fp8_group_size
            cfg.quant_config.zero_point = False

            # Weight and scale buffers are passed by address.
            for proj in ("gate", "up", "down"):
                setattr(cfg, f"{proj}_proj", fp8_data[f"{proj}_fp8"].data_ptr())
                setattr(cfg, f"{proj}_scale", fp8_data[f"{proj}_scales"].data_ptr())

            cfg.pool = CPUInfer.backend_
            layer = kt_kernel_ext.moe.AMXFP8_MOE(cfg)

            # Submit the weight-loading task and wait for it before moving on.
            CPUInfer.submit(layer.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            layers.append(layer)
    return layers
def run_fp8_moe_test():
    """
    Run the FP8 MoE validation test.

    Builds quantized weights and MoE modules, runs ``validation_iter`` forward
    passes through the kernel, and compares each output with the PyTorch
    reference computed from the dequantized weights.

    Returns:
        dict with the ``"mean"``, ``"max"`` and ``"min"`` relative L1 differences
        (fractions, not percent) and ``"passed"``: True when the mean difference
        is below the 15% threshold.
    Raises:
        AssertionError: if the kernel output contains NaN or Inf.
    """
    print("\n" + "=" * 70)
    print("FP8 MoE Kernel Validation Test")
    print("=" * 70)
    # Build FP8 weights
    print("\nGenerating and quantizing weights...")
    fp8_data = build_random_fp8_weights()
    # Build MoE modules
    print("\nBuilding FP8 MoE modules...")
    moes = build_moes_from_fp8_data(fp8_data)
    # Get dequantized weights for reference
    gate_deq = fp8_data["gate_deq"]
    up_deq = fp8_data["up_deq"]
    down_deq = fp8_data["down_deq"]
    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            # Each token picks num_experts_per_tok distinct experts.
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
            moe = moes[i % layer_num]
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
            assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
            assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."
            # Reference computation using dequantized weights
            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)
            t_output_flat = t_output.flatten()
            output_flat = output.flatten()
            # Reduce in FP32: a BF16 result keeps only ~3 significant digits,
            # which would quantize the reported error.
            out_f32 = output_flat.to(torch.float32)
            ref_f32 = t_output_flat.to(torch.float32)
            diff = torch.mean(torch.abs(out_f32 - ref_f32)) / (torch.mean(torch.abs(ref_f32)) + 1e-12)
            diffs.append(diff.item())
            print(f"Iteration {i}: relative L1 diff = {diff:.6f}")
            if i < 3:  # Print detailed output for first few iterations
                print(f" kernel output: {output_flat[:debug_print_count]}")
                print(f" torch output: {t_output_flat[:debug_print_count]}")
    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))
    print("\n" + "=" * 70)
    print("FP8 MoE Test Results")
    print("=" * 70)
    print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
    print(f"Max relative L1 diff: {max_diff*100:.4f}%")
    print(f"Min relative L1 diff: {min_diff*100:.4f}%")
    # Pass/Fail criteria
    threshold = 15.0  # 15% relative error threshold for FP8
    passed = mean_diff * 100 < threshold
    if passed:
        print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
    else:
        print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")
    # Expose the verdict so callers do not have to parse the printed output.
    return {"mean": mean_diff, "max": max_diff, "min": min_diff, "passed": passed}
if __name__ == "__main__":
    results = run_fp8_moe_test()
    # Report a failed verdict through the process exit status so shells and CI
    # can detect it. A result dict without a "passed" entry counts as success,
    # which keeps the exit status at 0.
    sys.exit(0 if results.get("passed", True) else 1)