"""
|
||
Performance benchmark for FP8 MoE kernel (AVX implementation).
|
||
|
||
This benchmark measures the performance of the FP8 MoE operator with:
|
||
- FP8 (E4M3) weights with 128x128 block-wise scaling
|
||
- BF16 activations
|
||
- AVX-512 DPBF16 compute path
|
||
"""
|
||
|
||

import os
import sys
import time
import json
import subprocess
import platform

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

import torch
from kt_kernel import kt_kernel_ext
from tqdm import tqdm

# Test parameters
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
num_experts_per_tok = 8
fp8_group_size = 128
max_len = 25600

layer_num = 2
qlen = 1
warm_up_iter = 1000
test_iter = 3000
CPUINFER_PARAM = 80
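
# Sanity note (arithmetic on the parameters above): each expert stores three
# FP8 matrices of 2048 x 7168 bytes (~44 MB), so one layer of 256 experts
# holds ~11.3 GB of weights and the 2 benchmarked layers ~22.5 GB, plus
# negligible block scales (16 x 56 fp32 values per matrix).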

CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)

# Result file path
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")


def get_git_commit():
    """Get current git commit info."""
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        result["dirty"] = bool(dirty_output)
        if dirty_output:
            result["dirty_files"] = dirty_output.splitlines()
    except Exception as e:
        result["commit"] = None
        result["error"] = str(e)
    return result


def get_system_info():
    """Get system information."""
    info = {}
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node

    cpu_model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_model = line.split(":", 1)[1].strip()
                        break
        except Exception:
            pass
    info["cpu_model"] = cpu_model
    info["cpu_core_count"] = os.cpu_count()
    return info


def record_results(result, filename=json_path):
    """Append one result record as a line of JSON (JSONL)."""
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")


def generate_fp8_weights_direct(shape: tuple, group_size: int = 128):
    """
    Directly generate random FP8 weights and e8m0-format scale_inv.

    Args:
        shape: (expert_num, n, k) - weight tensor shape
        group_size: block size for scaling (128x128 blocks)

    Returns:
        fp8_weights: uint8 tensor with random FP8 E4M3 values
        scale_inv: fp32 tensor with e8m0 format (powers of 2)
    """
    e, n, k = shape
    n_blocks = n // group_size
    k_blocks = k // group_size

    # Directly generate random FP8 weight bytes as uint8. Note: randint runs on
    # the CUDA device and the result is copied to CPU, so a GPU is required
    # even though the kernel under test runs on CPU.
    # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa bits; exp 0 encodes
    # subnormals, and exp 15 with mantissa 111 encodes NaN (no infinities).
    # An illustrative decoder follows this function.
    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()

    # Generate e8m0-format scale_inv (powers of 2)
    # e8m0: 8-bit exponent only, no mantissa, bias = 127
    # Generate random exponents in a reasonable range (e.g., -8 to 8)
    exponents = torch.randint(-8, 9, (e, n_blocks, k_blocks), dtype=torch.int32, device="cuda").to("cpu").contiguous()
    scale_inv = (2.0 ** exponents.float()).to(torch.float32).contiguous()

    return fp8_weights, scale_inv
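

# Illustrative only, not used by the benchmark: a minimal pure-Python decoder
# for one FP8 E4M3FN byte, to make the bit layout above concrete. Layout:
# 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits; exponent 0 is
# subnormal, 0x7F/0xFF (exp 15, mantissa 7) is NaN, and there are no infs.
def _decode_e4m3_example(byte: int) -> float:
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exp == 0xF and mantissa == 0x7:
        return float("nan")  # the only special encoding in E4M3FN
    if exp == 0:
        # Subnormal: no implicit leading 1, fixed scale of 2**-6
        return sign * (mantissa / 8.0) * 2.0 ** -6
    # Normal: implicit leading 1, exponent biased by 7
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exp - 7)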


def bench_fp8_moe():
    """Benchmark FP8 MoE performance."""
    with torch.inference_mode():
        print("=" * 70)
        print("FP8 MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate FP8 weights directly (no quantization from fp32)
        print("\nGenerating FP8 weights directly...")
        torch.manual_seed(42)
        gate_fp8, gate_scales = generate_fp8_weights_direct(
            (expert_num, intermediate_size, hidden_size), fp8_group_size
        )
        up_fp8, up_scales = generate_fp8_weights_direct((expert_num, intermediate_size, hidden_size), fp8_group_size)
        down_fp8, down_scales = generate_fp8_weights_direct(
            (expert_num, hidden_size, intermediate_size), fp8_group_size
        )

        # Identity mapping: physical expert i holds logical expert i
        physical_to_logical_map = torch.arange(expert_num, device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers
        print("Building FP8 MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = fp8_group_size
            config.quant_config.zero_point = False

            config.gate_proj = gate_fp8.data_ptr()
            config.up_proj = up_fp8.data_ptr()
            config.down_proj = down_fp8.data_ptr()
            config.gate_scale = gate_scales.data_ptr()
            config.up_scale = up_scales.data_ptr()
            config.down_scale = down_scales.data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data
        print("Generating input data...")
        gen_iter = 1000
        # rand + argsort yields a random permutation of expert ids per row, so
        # slicing the first num_experts_per_tok gives k distinct experts per token.
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        # Benchmark
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6

        # FLOPS calculation:
        # Each expert performs: gate (intermediate x hidden) + up (intermediate x hidden) + down (hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        # For vector-matrix multiply (qlen = 1): 2 * n * k per matrix
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12
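
        # Worked example with the defaults above: flops_per_expert
        # = 3 * 2 * 2048 * 7168 ~= 88.1 MFLOPs, so each iteration
        # (qlen = 1, 8 experts per token) issues ~0.70 GFLOPs.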

        # Bandwidth calculation (FP8 = 1 byte per element)
        bytes_per_elem = 1.0
        # Weight memory: gate + up + down per expert.
        # num_experts_per_tok times the parenthesized factor simplifies to
        # expert_num * (1 - (1 - k/E) ** qlen), the expected number of distinct
        # experts hit by qlen tokens (each picking k of E experts), so experts
        # shared between tokens are not double-counted; for qlen = 1 it is
        # exactly num_experts_per_tok.
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )  # unit: GB/s
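
        # Worked example: with qlen = 1 each iteration touches
        # 3 * 7168 * 2048 * 8 bytes ~= 352 MB of unique FP8 weights, so the
        # reported GB/s is ~0.352 GB divided by the per-iteration time.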

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: FP8 (E4M3) with {fp8_group_size}x{fp8_group_size} block scaling")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "fp8_e4m3",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "fp8_group_size": fp8_group_size,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)

        return tflops, bandwidth


if __name__ == "__main__":
    bench_fp8_moe()