mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-20 14:29:22 +00:00
264
kt-kernel/bench/bench_bf16_moe.py
Normal file
264
kt-kernel/bench/bench_bf16_moe.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Performance benchmark for native BF16 MoE kernel (AMX implementation).
|
||||
|
||||
This benchmark measures the performance of the BF16 MoE operator with:
|
||||
- Native BF16 weights (no quantization)
|
||||
- BF16 activations
|
||||
- AMX BF16 DPBF16PS compute path
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import subprocess
|
||||
import platform
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
|
||||
import torch
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from tqdm import tqdm
|
||||
|
||||
# Test parameters
|
||||
expert_num = 256
|
||||
hidden_size = 7168
|
||||
intermediate_size = 2048
|
||||
num_experts_per_tok = 8
|
||||
max_len = 25600
|
||||
|
||||
layer_num = 5
|
||||
qlen = 1
|
||||
warm_up_iter = 100
|
||||
test_iter = 3000
|
||||
CPUINFER_PARAM = 80
|
||||
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
|
||||
|
||||
# Result file path
|
||||
script_path = os.path.abspath(__file__)
|
||||
script_dir = os.path.dirname(script_path)
|
||||
json_path = os.path.join(script_dir, "bench_bf16_moe.jsonl")
|
||||
|
||||
|
||||
def get_git_commit():
    """Collect git metadata (commit hash, message, dirty state) for provenance.

    Returns a dict; on any failure (e.g. not a git checkout) the dict has
    ``commit = None`` plus an ``error`` string instead of raising.
    """

    def _git(*args):
        # Run a git subcommand and return its stripped stdout as text.
        return subprocess.check_output(["git", *args]).decode("utf-8").strip()

    info = {}
    try:
        info["commit"] = _git("rev-parse", "HEAD")
        info["commit_message"] = _git("log", "-1", "--pretty=%B")
        status = _git("status", "--porcelain")
        info["dirty"] = bool(status)
        if status:
            info["dirty_files"] = status.splitlines()
    except Exception as exc:  # best-effort: benchmarks may run outside a repo
        info["commit"] = None
        info["error"] = str(exc)
    return info
|
||||
|
||||
|
||||
def get_system_info():
    """Gather host identity and CPU details for the benchmark record.

    Returns a dict with ``system_name``, ``node_name``, ``cpu_model``
    (``None`` when undeterminable) and ``cpu_core_count``.
    """

    def _read_cpu_model():
        # Linux-only: first "model name" entry of /proc/cpuinfo, else None.
        if not os.path.exists("/proc/cpuinfo"):
            return None
        try:
            with open("/proc/cpuinfo", "r") as fh:
                for entry in fh:
                    if "model name" in entry:
                        return entry.split(":", 1)[1].strip()
        except Exception:
            pass  # best-effort probe; fall through to None
        return None

    uname = platform.uname()
    return {
        "system_name": uname.system,
        "node_name": uname.node,
        "cpu_model": _read_cpu_model(),
        "cpu_core_count": os.cpu_count(),
    }
|
||||
|
||||
|
||||
def record_results(result, filename=json_path):
    """Append one benchmark record to *filename* as a JSON line (JSONL)."""
    line = json.dumps(result)
    with open(filename, "a") as fh:
        fh.write(line + "\n")
|
||||
|
||||
|
||||
def generate_bf16_weights(shape: tuple, device: str = "cuda"):
    """
    Generate random BF16 weights.

    Args:
        shape: (expert_num, n, k) - weight tensor shape
        device: device used for the random draw. Defaults to "cuda"
            (the original behaviour); pass "cpu" on GPU-less hosts.
            The result is always moved to CPU for the AMX kernel.

    Returns:
        bf16_weights: contiguous CPU bfloat16 tensor with random values
    """
    # Divide by 100 to keep magnitudes small and avoid overflow when many
    # expert contributions are accumulated in BF16.
    weights = (torch.randn(shape, dtype=torch.float32, device=device) / 100.0).to(torch.bfloat16).to("cpu").contiguous()
    return weights
|
||||
|
||||
|
||||
def bench_bf16_moe():
    """Benchmark native BF16 MoE performance.

    Builds `layer_num` AMXBF16_MOE instances sharing one set of random BF16
    weights, round-robins forward calls across them, and reports time per
    iteration, achieved bandwidth and TFLOPS. Results are appended to the
    module-level JSONL file via record_results().

    Returns:
        (tflops, bandwidth): measured compute throughput (TFLOPS) and
        effective weight-read bandwidth (GB/s).
    """
    with torch.inference_mode():
        print("=" * 70)
        print("Native BF16 MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate BF16 weights (seeded for reproducibility)
        print("\nGenerating BF16 weights...")
        torch.manual_seed(42)
        gate_proj = generate_bf16_weights((expert_num, intermediate_size, hidden_size))
        up_proj = generate_bf16_weights((expert_num, intermediate_size, hidden_size))
        down_proj = generate_bf16_weights((expert_num, hidden_size, intermediate_size))

        # Identity expert mapping: physical expert i == logical expert i.
        physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers. All layers share the same weight tensors; the
        # kernel receives raw data_ptr()s, so these tensors must stay alive
        # for the lifetime of the MOE objects.
        print("Building BF16 MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len

            # Set BF16 weight pointers (no scales needed)
            config.gate_proj = gate_proj.data_ptr()
            config.up_proj = up_proj.data_ptr()
            config.down_proj = down_proj.data_ptr()

            # No scales for BF16
            config.gate_scale = 0
            config.up_scale = 0
            config.down_scale = 0
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data: gen_iter pre-drawn routing decisions that are
        # cycled through during the timed loop so routing cost is excluded.
        print("Generating input data...")
        gen_iter = 1000
        # argsort of uniform noise yields a random top-k expert selection
        # without replacement for each token.
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup: sync after each submit so every forward completes before
        # the next is issued (the in/out buffers are reused per layer).
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        # Benchmark (timed region includes submit + sync per iteration)
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6

        # FLOPS calculation:
        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12

        # Bandwidth calculation (BF16 = 2 bytes per element)
        bytes_per_elem = 2.0
        # Weight memory: gate + up + down per expert
        # NOTE(review): the parenthesized factor appears to estimate the
        # expected number of DISTINCT experts touched by qlen tokens
        # (E * (1 - (1 - k/E)^qlen)), normalized by k — i.e. each unique
        # expert's weights are read from memory once. Confirm against the
        # kernel's actual access pattern.
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )  # GB/s

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: Native BF16 (no quantization)")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results (merged with git + system provenance, appended as JSONL)
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "bf16_native",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)

        return tflops, bandwidth
|
||||
|
||||
|
||||
# Script entry point: run the BF16 MoE benchmark when executed directly.
if __name__ == "__main__":
    bench_bf16_moe()
|
||||
@@ -1,9 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
"""
|
||||
Benchmark write_weight_scale_to_buffer for AMX_FP8_MOE_TP (FP8 weights + float32 scales).
|
||||
Benchmark write_weight_scale_to_buffer for AMX MOE operators.
|
||||
|
||||
Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.
|
||||
Supports:
|
||||
- FP8: FP8 weights (1 byte) + float32 scales
|
||||
- BF16: Native BF16 weights (2 bytes), no scales
|
||||
|
||||
Usage:
|
||||
python bench_write_buffer.py # Run all modes
|
||||
python bench_write_buffer.py fp8 # Run FP8 only
|
||||
python bench_write_buffer.py bf16 # Run BF16 only
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
@@ -17,7 +24,6 @@ from tqdm import tqdm
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext.moe import AMXFP8_MOE
|
||||
import torch
|
||||
|
||||
# Benchmark parameters
|
||||
@@ -87,10 +93,19 @@ def record_results(result, filename=json_path):
|
||||
f.write(json.dumps(result) + "\n")
|
||||
|
||||
|
||||
def allocate_weights():
|
||||
def div_up(a, b):
    """Return ceil(a / b) for positive integers using integer arithmetic."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Functions
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def allocate_weights_fp8():
|
||||
per_mat_weight_bytes = hidden_size * intermediate_size
|
||||
n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
|
||||
n_blocks_k = (hidden_size + group_size - 1) // group_size
|
||||
n_blocks_n_gate_up = div_up(intermediate_size, group_size)
|
||||
n_blocks_k = div_up(hidden_size, group_size)
|
||||
per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k
|
||||
per_mat_scale_elems_down = n_blocks_k * n_blocks_n_gate_up
|
||||
|
||||
@@ -119,32 +134,22 @@ def allocate_weights():
|
||||
torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous()
|
||||
)
|
||||
|
||||
return (
|
||||
gate_q,
|
||||
up_q,
|
||||
down_q,
|
||||
gate_scale,
|
||||
up_scale,
|
||||
down_scale,
|
||||
per_mat_weight_bytes,
|
||||
per_mat_scale_elems_gate_up,
|
||||
per_mat_scale_elems_down,
|
||||
)
|
||||
return {
|
||||
"gate_q": gate_q,
|
||||
"up_q": up_q,
|
||||
"down_q": down_q,
|
||||
"gate_scale": gate_scale,
|
||||
"up_scale": up_scale,
|
||||
"down_scale": down_scale,
|
||||
"per_mat_weight_bytes": per_mat_weight_bytes,
|
||||
"per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
|
||||
"per_mat_scale_elems_down": per_mat_scale_elems_down,
|
||||
}
|
||||
|
||||
|
||||
def build_moe(layer_idx=0):
|
||||
"""Build a single MOE instance with the given layer_idx."""
|
||||
(
|
||||
gate_q,
|
||||
up_q,
|
||||
down_q,
|
||||
gate_scale,
|
||||
up_scale,
|
||||
down_scale,
|
||||
per_mat_weight_bytes,
|
||||
per_mat_scale_elems_gate_up,
|
||||
per_mat_scale_elems_down,
|
||||
) = allocate_weights()
|
||||
def build_moe_fp8(layer_idx=0):
|
||||
"""Build a single FP8 MOE instance."""
|
||||
weights = allocate_weights_fp8()
|
||||
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.max_len = max_len
|
||||
@@ -153,37 +158,28 @@ def build_moe(layer_idx=0):
|
||||
config.quant_config.group_size = group_size
|
||||
config.quant_config.zero_point = False
|
||||
config.pool = CPUInfer.backend_
|
||||
config.gate_proj = gate_q.data_ptr()
|
||||
config.up_proj = up_q.data_ptr()
|
||||
config.down_proj = down_q.data_ptr()
|
||||
config.gate_scale = gate_scale.data_ptr()
|
||||
config.up_scale = up_scale.data_ptr()
|
||||
config.down_scale = down_scale.data_ptr()
|
||||
config.gate_proj = weights["gate_q"].data_ptr()
|
||||
config.up_proj = weights["up_q"].data_ptr()
|
||||
config.down_proj = weights["down_q"].data_ptr()
|
||||
config.gate_scale = weights["gate_scale"].data_ptr()
|
||||
config.up_scale = weights["up_scale"].data_ptr()
|
||||
config.down_scale = weights["down_scale"].data_ptr()
|
||||
|
||||
moe = AMXFP8_MOE(config)
|
||||
moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
|
||||
keep_tensors = {
|
||||
"gate_q": gate_q,
|
||||
"up_q": up_q,
|
||||
"down_q": down_q,
|
||||
"gate_scale": gate_scale,
|
||||
"up_scale": up_scale,
|
||||
"down_scale": down_scale,
|
||||
}
|
||||
|
||||
buffer_shapes = {
|
||||
"per_mat_weight_bytes": per_mat_weight_bytes,
|
||||
"per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
|
||||
"per_mat_scale_elems_down": per_mat_scale_elems_down,
|
||||
"per_mat_weight_bytes": weights["per_mat_weight_bytes"],
|
||||
"per_mat_scale_elems_gate_up": weights["per_mat_scale_elems_gate_up"],
|
||||
"per_mat_scale_elems_down": weights["per_mat_scale_elems_down"],
|
||||
}
|
||||
|
||||
return moe, buffer_shapes, keep_tensors
|
||||
return moe, buffer_shapes, weights
|
||||
|
||||
|
||||
def allocate_buffers(buffer_shapes):
|
||||
"""Allocate shared output buffers for single expert."""
|
||||
def allocate_buffers_fp8(buffer_shapes):
|
||||
"""Allocate output buffers for FP8 single expert."""
|
||||
per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
|
||||
per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"]
|
||||
per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"]
|
||||
@@ -192,7 +188,6 @@ def allocate_buffers(buffer_shapes):
|
||||
scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count
|
||||
scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down // gpu_tp_count
|
||||
|
||||
# Each buffer stores data for a single expert
|
||||
w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
|
||||
w13_scale_bufs = [
|
||||
torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
|
||||
@@ -217,23 +212,130 @@ def allocate_buffers(buffer_shapes):
|
||||
return buffer_ptrs, keep_tensors
|
||||
|
||||
|
||||
def bench_write_buffer():
|
||||
# Build two MOE instances with different layer_idx
|
||||
moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)
|
||||
moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)
|
||||
# ==============================================================================
|
||||
# BF16 Functions
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def allocate_weights_bf16():
    """Allocate random BF16 gate/up/down weights plus per-matrix size info.

    Returns a dict with the three weight tensors (flat, CPU, contiguous)
    and the per-matrix element/byte counts used to size output buffers.
    """
    per_mat_weight_elems = hidden_size * intermediate_size
    per_mat_weight_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes

    def _rand_bf16():
        # Random draw on CUDA, then copied back to host for the CPU kernel.
        return torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16, device="cuda").to("cpu").contiguous()

    weights = {name: _rand_bf16() for name in ("gate_proj", "up_proj", "down_proj")}
    weights["per_mat_weight_bytes"] = per_mat_weight_bytes
    weights["per_mat_weight_elems"] = per_mat_weight_elems
    return weights
|
||||
|
||||
|
||||
def build_moe_bf16(layer_idx=0):
    """Build a single BF16 MOE instance.

    Allocates fresh random BF16 weights, wires their raw pointers into a
    MOEConfig (no quantization scales), constructs the AMXBF16_MOE and
    synchronously loads its weights.

    Args:
        layer_idx: layer index stamped on the config (distinguishes the two
            alternating MOE instances in the benchmark).

    Returns:
        (moe, buffer_shapes, weights): the MOE object, a dict of per-matrix
        sizes for output-buffer allocation, and the weight dict — callers
        must keep `weights` alive because the kernel holds raw data_ptr()s.
    """
    weights = allocate_weights_bf16()

    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = max_len
    config.layer_idx = layer_idx
    config.pool = CPUInfer.backend_
    config.gate_proj = weights["gate_proj"].data_ptr()
    config.up_proj = weights["up_proj"].data_ptr()
    config.down_proj = weights["down_proj"].data_ptr()
    # Native BF16 needs no dequant scales; null pointers signal their absence.
    config.gate_scale = 0
    config.up_scale = 0
    config.down_scale = 0

    moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
    # Blocking weight load: submit the task then wait for completion.
    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    CPUInfer.sync()

    buffer_shapes = {
        "per_mat_weight_bytes": weights["per_mat_weight_bytes"],
        "per_mat_weight_elems": weights["per_mat_weight_elems"],
    }

    return moe, buffer_shapes, weights
|
||||
|
||||
|
||||
def allocate_buffers_bf16(buffer_shapes):
    """Allocate output buffers for BF16 single expert (no scales).

    Returns (buffer_ptrs, keep_tensors): raw data pointers per TP shard for
    the write-buffer call, and the backing tensors the caller must keep
    alive while those pointers are in use.
    """
    bytes_per_expert_per_tp = buffer_shapes["per_mat_weight_bytes"] // gpu_tp_count

    def _byte_bufs(numel):
        return [torch.empty(numel, dtype=torch.uint8) for _ in range(gpu_tp_count)]

    def _placeholder_scales():
        # BF16 has no scales; 1-element dummies satisfy the interface.
        return [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]

    keep_tensors = {
        "w13_weight_bufs": _byte_bufs(2 * bytes_per_expert_per_tp),  # gate+up fused
        "w13_scale_bufs": _placeholder_scales(),
        "w2_weight_bufs": _byte_bufs(bytes_per_expert_per_tp),  # down
        "w2_scale_bufs": _placeholder_scales(),
    }

    buffer_ptrs = {
        name.replace("_bufs", "_ptrs"): [buf.data_ptr() for buf in bufs]
        for name, bufs in keep_tensors.items()
    }

    return buffer_ptrs, keep_tensors
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Benchmark Functions
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def bench_write_buffer(quant_mode: str):
|
||||
"""Benchmark write_weight_scale_to_buffer for specified quant mode."""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"{quant_mode.upper()} write_weight_scale_to_buffer benchmark")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if quant_mode == "fp8":
|
||||
bytes_per_elem = 1.0
|
||||
moe_0, buffer_shapes, keep_tensors_0 = build_moe_fp8(layer_idx=0)
|
||||
moe_1, _, keep_tensors_1 = build_moe_fp8(layer_idx=1)
|
||||
buffer_ptrs, buffer_keep = allocate_buffers_fp8(buffer_shapes)
|
||||
|
||||
# Calculate total bytes including scales
|
||||
total_weights = hidden_size * intermediate_size * expert_num * 3
|
||||
total_scale_bytes = (
|
||||
(buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"])
|
||||
* expert_num
|
||||
* 4
|
||||
)
|
||||
bytes_per_call = total_weights + total_scale_bytes
|
||||
|
||||
elif quant_mode == "bf16":
|
||||
bytes_per_elem = 2.0
|
||||
moe_0, buffer_shapes, keep_tensors_0 = build_moe_bf16(layer_idx=0)
|
||||
moe_1, _, keep_tensors_1 = build_moe_bf16(layer_idx=1)
|
||||
buffer_ptrs, buffer_keep = allocate_buffers_bf16(buffer_shapes)
|
||||
|
||||
# BF16: only weights, no scales
|
||||
bytes_per_call = hidden_size * intermediate_size * expert_num * 3 * 2 # BF16 = 2 bytes
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported quant_mode: {quant_mode}")
|
||||
|
||||
moes = [moe_0, moe_1]
|
||||
|
||||
# Allocate shared buffers
|
||||
buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)
|
||||
|
||||
total_weights = hidden_size * intermediate_size * expert_num * 3
|
||||
total_scale_bytes = (
|
||||
(buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"]) * expert_num * 4
|
||||
)
|
||||
bytes_per_call = total_weights + total_scale_bytes
|
||||
|
||||
# Warm-up: alternate between two MOEs
|
||||
for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
|
||||
# Warm-up
|
||||
for _ in tqdm(range(warm_up_iter), desc=f"[{quant_mode.upper()}] Warm-up"):
|
||||
for moe_idx, moe in enumerate(moes):
|
||||
for expert_id in range(gpu_experts_num):
|
||||
CPUInfer.submit(
|
||||
@@ -241,10 +343,10 @@ def bench_write_buffer():
|
||||
)
|
||||
CPUInfer.sync()
|
||||
|
||||
# Benchmark
|
||||
total_time = 0
|
||||
for iter_idx in tqdm(range(test_iter), desc="Testing"):
|
||||
for iter_idx in tqdm(range(test_iter), desc=f"[{quant_mode.upper()}] Testing"):
|
||||
start = time.perf_counter()
|
||||
# Alternate between two MOEs
|
||||
for moe_idx, moe in enumerate(moes):
|
||||
for expert_id in range(gpu_experts_num):
|
||||
CPUInfer.submit(
|
||||
@@ -254,7 +356,7 @@ def bench_write_buffer():
|
||||
end = time.perf_counter()
|
||||
iter_time = end - start
|
||||
total_time += iter_time
|
||||
print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms")
|
||||
print(f" Iter {iter_idx}: {iter_time*1000:.2f} ms")
|
||||
time.sleep(0.3)
|
||||
|
||||
# bytes_per_call is for one MOE, we have 2 MOEs
|
||||
@@ -263,7 +365,7 @@ def bench_write_buffer():
|
||||
bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("FP8 write_weight_scale_to_buffer benchmark (2 MOEs alternating)")
|
||||
print(f"{quant_mode.upper()} write_weight_scale_to_buffer Results (2 MOEs alternating)")
|
||||
print(f"{'='*60}")
|
||||
print(f"Time per iteration: {time_per_iter_ms:.2f} ms")
|
||||
print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
|
||||
@@ -271,7 +373,8 @@ def bench_write_buffer():
|
||||
print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us")
|
||||
|
||||
result = {
|
||||
"op": "write_weight_scale_to_buffer_fp8",
|
||||
"op": f"write_weight_scale_to_buffer_{quant_mode}",
|
||||
"quant_mode": quant_mode,
|
||||
"time_per_iteration_ms": time_per_iter_ms,
|
||||
"bandwidth_GBs": bandwidth_gbs,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
@@ -279,16 +382,51 @@ def bench_write_buffer():
|
||||
"expert_num": expert_num,
|
||||
"hidden_size": hidden_size,
|
||||
"intermediate_size": intermediate_size,
|
||||
"group_size": group_size,
|
||||
"gpu_tp_count": gpu_tp_count,
|
||||
"bytes_per_iter": bytes_per_iter,
|
||||
"num_moes": 2,
|
||||
},
|
||||
}
|
||||
if quant_mode == "fp8":
|
||||
result["test_parameters"]["group_size"] = group_size
|
||||
|
||||
result.update(get_git_commit())
|
||||
result.update(get_system_info())
|
||||
record_results(result)
|
||||
|
||||
return bandwidth_gbs
|
||||
|
||||
|
||||
def main(quant_modes=None):
    """Run benchmarks for specified quant modes.

    Args:
        quant_modes: list of mode names to run; defaults to ["fp8", "bf16"].
    """
    import traceback

    modes = ["fp8", "bf16"] if quant_modes is None else quant_modes

    results = {}
    for mode in modes:
        try:
            bandwidth = bench_write_buffer(mode)
        except Exception as e:
            # Keep going so one failing mode doesn't hide the other's result.
            results[mode] = f"FAILED: {e}"
            traceback.print_exc()
        else:
            results[mode] = f"PASSED ({bandwidth:.2f} GB/s)"

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for mode, result in results.items():
        print(f" {mode.upper()}: {result}")
|
||||
|
||||
|
||||
# CLI entry point: optional single positional argument selects the mode.
if __name__ == "__main__":
    args = sys.argv[1:]
    if not args:
        main()
    else:
        mode = args[0].lower()
        if mode not in ("fp8", "bf16"):
            print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
            sys.exit(1)
        main([mode])
|
||||
245
kt-kernel/examples/test_bf16_moe.py
Normal file
245
kt-kernel/examples/test_bf16_moe.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Test script for AMX_BF16_MOE_TP (native BF16 MoE) kernel validation.
|
||||
|
||||
This script:
|
||||
1. Generates random BF16 weights
|
||||
2. Runs the BF16 MoE kernel
|
||||
3. Compares results with PyTorch reference
|
||||
|
||||
BF16 format notes:
|
||||
- Weight: BF16 stored as ggml_bf16_t, shape [expert_num, n, k]
|
||||
- No scales needed (native BF16 precision)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
|
||||
|
||||
import torch
|
||||
from kt_kernel import kt_kernel_ext
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Model config
|
||||
hidden_size = 3072
|
||||
intermediate_size = 1536
|
||||
max_len = 25600
|
||||
|
||||
expert_num = 16
|
||||
num_experts_per_tok = 16
|
||||
|
||||
qlen = 1
|
||||
layer_num = 5
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(3)
|
||||
validation_iter = 5
|
||||
debug_print_count = 16
|
||||
|
||||
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
|
||||
|
||||
|
||||
def act_fn(x):
    """SiLU activation function: x * sigmoid(x)."""
    return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Reference MLP: down( silu(x @ gate^T) * (x @ up^T) ).

    Weight matrices are stored [out_features, in_features], hence the
    transposes before each matmul.
    """
    gate_out = input.matmul(gate_proj.t())
    up_out = input.matmul(up_proj.t())
    hidden = act_fn(gate_out) * up_out
    return hidden.matmul(down_proj.t())
|
||||
|
||||
|
||||
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference MoE computation in PyTorch.

    Groups tokens by expert (argsort of the flattened routing), runs each
    expert's MLP on its token batch, scatters results back to the original
    order, then computes the routing-weighted sum over the top-k experts.

    Args:
        input: (qlen, hidden) activations.
        expert_ids: (qlen, k) expert indices per token.
        weights: (qlen, k) routing weights.
        gate_proj/up_proj/down_proj: (expert_num, out, in) weight stacks.
    """
    # One-hot count of how many tokens route to each expert.
    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    # Sort all (token, expert) pairs by expert id so each expert's tokens
    # are contiguous; idxs // k recovers the owning token index.
    idxs = expert_ids.view(-1).argsort()
    sorted_tokens = input[idxs // expert_ids.shape[1]]

    outputs = []
    start_idx = 0

    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            # Safe to skip without advancing: end_idx == start_idx here.
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

    # Undo the expert-sorted ordering back to (token, slot) order.
    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    # Weighted sum over the k expert slots, computed in the weights' dtype
    # (float32) for accuracy, then cast back to the activation dtype.
    t_output = (
        new_x.view(*expert_ids.shape, -1)
        .type(weights.dtype)
        .mul_(weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    return t_output
|
||||
|
||||
|
||||
def build_bf16_weights():
    """
    Generate random BF16 weights.

    Returns:
        dict with BF16 weights for gate, up, down projections
    """
    torch.manual_seed(42)

    # Draw order (gate, up, down) is fixed by dict insertion order so the
    # seeded RNG stream stays reproducible. Small magnitudes (randn / 100)
    # keep BF16 accumulation well-behaved.
    shapes = {
        "gate_proj": (expert_num, intermediate_size, hidden_size),
        "up_proj": (expert_num, intermediate_size, hidden_size),
        "down_proj": (expert_num, hidden_size, intermediate_size),
    }
    tensors = {
        name: (torch.randn(shape, dtype=torch.float32) / 100.0).to(torch.bfloat16).contiguous()
        for name, shape in shapes.items()
    }
    gate_proj = tensors["gate_proj"]
    up_proj = tensors["up_proj"]
    down_proj = tensors["down_proj"]

    print(f"BF16 weights shape: gate={gate_proj.shape}, up={up_proj.shape}, down={down_proj.shape}")

    # Debug: Print BF16 weight info for expert 0
    print("\n=== DEBUG: BF16 Weight Info (Expert 0) ===")
    print(f"gate_proj[0] first 8 values: {gate_proj[0, 0, :8]}")
    print(f"gate_proj[0] stats: min={gate_proj[0].min()}, max={gate_proj[0].max()}")
    print(f"up_proj[0] first 8 values: {up_proj[0, 0, :8]}")
    print(f"down_proj[0] first 8 values: {down_proj[0, 0, :8]}")

    return tensors
|
||||
|
||||
|
||||
def build_moes_from_bf16_data(bf16_data: dict):
    """
    Build BF16 MoE modules from BF16 weight data.

    All layers share the same weight tensors from *bf16_data*; the kernel
    receives raw data_ptr()s, so the caller must keep *bf16_data* alive for
    the lifetime of the returned MOE objects.

    Args:
        bf16_data: dict with "gate_proj", "up_proj", "down_proj" tensors.

    Returns:
        list of `layer_num` AMXBF16_MOE instances with weights loaded.
    """
    moes = []
    with torch.inference_mode(mode=True):
        for _ in range(layer_num):
            # NOTE(review): the trailing 0 is an extra MOEConfig positional
            # argument not used in the other constructor calls — confirm
            # its meaning against the MOEConfig binding.
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len

            # Set BF16 weight pointers (no scales needed)
            config.gate_proj = bf16_data["gate_proj"].data_ptr()
            config.up_proj = bf16_data["up_proj"].data_ptr()
            config.down_proj = bf16_data["down_proj"].data_ptr()

            # No scales for BF16
            config.gate_scale = 0
            config.up_scale = 0
            config.down_scale = 0
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
            # Blocking weight load before the next layer is constructed.
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)
    return moes
|
||||
|
||||
|
||||
def run_bf16_moe_test():
    """
    Run BF16 MoE validation test.

    Compares the AMX BF16 kernel output against a PyTorch reference
    (moe_torch) over `validation_iter` random routing configurations and
    reports the relative L1 error. Prints PASS/FAIL against a 5% mean-error
    threshold and returns the error statistics.

    Returns:
        dict with "mean", "max", "min" relative L1 diffs across iterations.
    """
    print("\n" + "=" * 70)
    print("BF16 MoE Kernel Validation Test")
    print("=" * 70)

    # Build BF16 weights
    print("\nGenerating BF16 weights...")
    bf16_data = build_bf16_weights()

    # Build MoE modules
    print("\nBuilding BF16 MoE modules...")
    moes = build_moes_from_bf16_data(bf16_data)

    # Get weights for reference computation
    gate_proj = bf16_data["gate_proj"]
    up_proj = bf16_data["up_proj"]
    down_proj = bf16_data["down_proj"]

    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            # Per-iteration seed so each run draws a distinct but
            # reproducible routing/input configuration.
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            # randperm guarantees k distinct experts per token.
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()

            # Round-robin over layers; all share the same weights, so any
            # layer is a valid target for the comparison.
            moe = moes[i % layer_num]
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

            assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
            assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."

            # Reference computation using BF16 weights
            t_output = moe_torch(input_tensor, expert_ids, weights, gate_proj, up_proj, down_proj)

            t_output_flat = t_output.flatten()
            output_flat = output.flatten()

            # Relative L1 error; 1e-12 guards against a zero reference mean.
            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
            diffs.append(diff.item())
            print(f"Iteration {i}: relative L1 diff = {diff:.6f}")

            if i < 3:  # Print detailed output for first few iterations
                print(f" kernel output: {output_flat[:debug_print_count]}")
                print(f" torch output: {t_output_flat[:debug_print_count]}")

    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))

    print("\n" + "=" * 70)
    print("BF16 MoE Test Results")
    print("=" * 70)
    print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
    print(f"Max relative L1 diff: {max_diff*100:.4f}%")
    print(f"Min relative L1 diff: {min_diff*100:.4f}%")

    # Pass/Fail criteria (BF16 should be very accurate, <5% error)
    threshold = 5.0
    if mean_diff * 100 < threshold:
        print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
    else:
        print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")

    return {"mean": mean_diff, "max": max_diff, "min": min_diff}
|
||||
|
||||
|
||||
# Script entry point: run the validation test when executed directly.
if __name__ == "__main__":
    run_bf16_moe_test()
|
||||
@@ -1,389 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext import CPUInfer
|
||||
from kt_kernel_ext.moe import AMXFP8_MOE
|
||||
|
||||
|
||||
def make_cpu_infer(thread_num=80):
    """Create a CPUInfer runtime backed by *thread_num* worker threads."""
    infer = CPUInfer(thread_num)
    return infer
|
||||
|
||||
|
||||
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Build an FP8 MOEConfig bound to *cpuinfer*'s thread pool.

    Weight/scale pointers are left unset; the caller fills them in before
    constructing the MoE operator.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    # FP8 quantization: 8-bit weights, block-wise scales, no zero point.
    config.quant_config.bits = 8
    config.quant_config.group_size = group_size
    config.quant_config.zero_point = False
    config.pool = cpuinfer.backend_
    return config
|
||||
|
||||
|
||||
def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
    """Allocate random FP8 weights (uint8 bytes) and float32 block scales.

    Returns a 9-tuple: gate/up/down quantized weights, gate/up/down scales,
    per-matrix weight byte count, per-matrix scale element count for
    gate/up, and per-matrix scale element count for down.
    """
    # FP8 stores one byte per weight element.
    per_mat_weight_bytes = hidden_size * intermediate_size

    # One float32 scale per (group_size x group_size) block.
    blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
    blocks_k = (hidden_size + group_size - 1) // group_size
    per_mat_scale_elems_gate_up = blocks_n_gate_up * blocks_k

    # down projection swaps dimensions: n=hidden_size, k=intermediate_size.
    per_mat_scale_elems_down = blocks_k * blocks_n_gate_up

    weight_shape = (expert_num * per_mat_weight_bytes,)
    gate_q = torch.randint(0, 256, weight_shape, dtype=torch.uint8)
    up_q = torch.randint(0, 256, weight_shape, dtype=torch.uint8)
    down_q = torch.randint(0, 256, weight_shape, dtype=torch.uint8)

    # Scales are stored as float32.
    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)

    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    )
|
||||
|
||||
|
||||
def test_with_tp(gpu_tp_count):
    """Test write_weight_scale_to_buffer with a specific gpu_tp_count.

    Loads random FP8 weights into an AMXFP8_MOE instance, exports them back
    out via write_weight_scale_to_buffer for each TP slice, then verifies
    the exported bytes/scales against a pure-Python slicing of the original
    tensors. Raises AssertionError on any mismatch.
    """
    torch.manual_seed(123)  # deterministic weights so failures reproduce

    expert_num = 256  # Reduced for debugging
    gpu_experts = expert_num  # Number of experts on GPU

    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536  # Changed from 2048 to test non-aligned case
    group_size = 128  # FP8 uses 128x128 block-wise scales

    cpuinfer = make_cpu_infer()
    cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)

    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)

    # Hand raw data pointers to the C++ side; the tensors above must stay alive
    # for the lifetime of the MoE object.
    cfg.gate_proj = gate_q.data_ptr()
    cfg.up_proj = up_q.data_ptr()
    cfg.down_proj = down_q.data_ptr()
    cfg.gate_scale = gate_scale.data_ptr()
    cfg.up_scale = up_scale.data_ptr()
    cfg.down_scale = down_scale.data_ptr()

    moe = AMXFP8_MOE(cfg)

    # Identity mapping: physical expert i == logical expert i.
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    # TP configuration
    # Calculate sizes per TP part (per expert) - must match C++ code which uses div_up
    def div_up(a, b):
        return (a + b - 1) // b

    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count

    # For W13 (gate/up): n=intermediate_size/gpu_tp, k=hidden_size
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)

    # For W2 (down): n=hidden_size, k=intermediate_size/gpu_tp
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)

    # Total sizes for all gpu_experts
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # Create buffer lists for w13 (gate+up) and w2 (down)
    # These hold all experts' data for each GPU TP
    w13_weight_bufs = []
    w13_scale_bufs = []
    w2_weight_bufs = []
    w2_scale_bufs = []

    for tp_idx in range(gpu_tp_count):
        # w13 combines gate and up, so needs 2x the size
        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32))
        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
        w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32))

    print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
    print(f"GPU TP count: {gpu_tp_count}")
    print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
    print(f"Original per matrix scale elements (gate/up): {per_mat_scale_elems_gate_up}")
    print(f"Original per matrix scale elements (down): {per_mat_scale_elems_down}")
    print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
    print(f"Scale elements per expert per TP (down): {scale_elems_per_expert_per_tp_down}")
    print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
    print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")

    # Helper function to get pointers with expert offset
    # write_weights_to_buffer writes one expert at a time, so we need to pass
    # pointers that already point to the correct location for each expert
    def get_expert_ptrs(expert_id):
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []

        for tp_idx in range(gpu_tp_count):
            # Calculate byte offsets for this expert
            # w13: gate_weight + up_weight interleaved by expert
            # Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down

            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)  # float32 = 4 bytes
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)  # float32 = 4 bytes

        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    # Warm up
    for i in range(2):
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
    cpuinfer.sync()

    # Timing
    begin_time = time.perf_counter_ns()
    for expert_id in range(gpu_experts):
        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                expert_id=expert_id,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
                w2_scale_ptrs=w2_scale_ptrs,
            )
        )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1000000

    # Calculate throughput
    total_weights = hidden_size * intermediate_size * gpu_experts * 3
    total_scale_bytes = (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4  # float32
    total_bytes = total_weights + total_scale_bytes
    print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def split_expert_tensor(tensor, chunk):
        """Split tensor by experts"""
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    # Split by experts first
    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)

    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)

    # For down matrix
    n_blocks_n = (hidden_size + group_size - 1) // group_size
    n_blocks_k = (intermediate_size + group_size - 1) // group_size
    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count

    # Verify buffers for each TP part
    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w13_scales = []
        expected_w2_weights = []
        expected_w2_scales = []

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count

        # Process each GPU expert
        for expert_id in range(gpu_experts):
            # For w13 (gate and up), the slicing is along intermediate_size (n direction)
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            # Gate
            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]

            # Up
            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]

            # Down matrix needs special handling because it's sliced column-wise
            # down is (hidden_size, intermediate_size) in n-major format
            down_weight_tp_parts = []
            down_scale_tp_parts = []

            # Iterate through each row to extract the corresponding parts
            for row_idx in range(hidden_size):
                row_weight_start = row_idx * intermediate_size

                # Direct mapping: each CPU TP corresponds to a GPU TP
                tp_slice_weight_size = intermediate_size // gpu_tp_count

                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size

                down_weight_tp_parts.append(
                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                )

            # For scale: only process at block boundaries
            for bn in range(n_blocks_n):
                row_scale_start = bn * n_blocks_k
                tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp
                down_scale_tp_parts.append(
                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]
                )

            # Concatenate all slices for this TP
            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
            expected_w13_scales.append(up_scale_tp)
            expected_w2_weights.append(down_weight_tp)
            expected_w2_scales.append(down_scale_tp)

        # Concatenate all experts for this TP part
        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w13_scale = torch.cat(expected_w13_scales)
        expected_w2_weight = torch.cat(expected_w2_weights)
        expected_w2_scale = torch.cat(expected_w2_scales)

        print(f"=== Checking TP part {tp_idx} ===")
        print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
        print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
        print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
        print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")

        # Assert all checks pass
        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            # Find first mismatch
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f" w13 weight mismatch at index {first_diff_idx}")
            print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            diff = torch.abs(w13_scale_bufs[tp_idx] - expected_w13_scale)
            max_diff_idx = diff.argmax().item()
            print(f" w13 scale mismatch, max diff at index {max_diff_idx}")
            print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
            print(f" expected: {expected_w13_scale[max_diff_idx]}")
            raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f" w2 weight mismatch at index {first_diff_idx}")
            print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            diff = torch.abs(w2_scale_bufs[tp_idx] - expected_w2_scale)
            max_diff_idx = diff.argmax().item()
            print(f" w2 scale mismatch, max diff at index {max_diff_idx}")
            print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
            print(f" expected: {expected_w2_scale[max_diff_idx]}")
            raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")

    print(
        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
    )
    return True
|
||||
|
||||
|
||||
def main():
    """Run the FP8 write_weight_scale_to_buffer test for each configured gpu_tp_count.

    Prints a per-TP summary and exits with status 1 if any TP value fails.
    """
    tp_values = [1, 2, 4]  # TP degrees exercised (TP=8 not currently enabled)
    all_passed = True
    results = {}

    print("=" * 60)
    print("Testing FP8 write_weight_scale_to_buffer for TP = ", tp_values)
    print("=" * 60)

    for tp in tp_values:
        print(f"\n{'='*60}")
        print(f"Testing with gpu_tp_count = {tp}")
        print(f"{'='*60}")
        try:
            test_with_tp(tp)
            results[tp] = "PASSED"
            print(f"✓ TP={tp} PASSED")
        except Exception as e:
            # Record the failure but keep testing the remaining TP values.
            results[tp] = f"FAILED: {e}"
            all_passed = False
            print(f"✗ TP={tp} FAILED: {e}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for tp, result in results.items():
        status = "✓" if "PASSED" in result else "✗"
        print(f" {status} TP={tp}: {result}")

    if all_passed:
        print("\n✓ ALL TESTS PASSED")
    else:
        print("\n✗ SOME TESTS FAILED")
        sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: run the FP8 write-buffer test suite.
if __name__ == "__main__":
    main()
|
||||
534
kt-kernel/examples/test_write_buffer.py
Normal file
534
kt-kernel/examples/test_write_buffer.py
Normal file
@@ -0,0 +1,534 @@
|
||||
"""
|
||||
Test write_weight_scale_to_buffer for AMX MOE operators.
|
||||
|
||||
Supports:
|
||||
- FP8: FP8 weights (1 byte) + float32 scales
|
||||
- BF16: Native BF16 weights (2 bytes), no scales
|
||||
|
||||
Usage:
|
||||
python test_write_buffer.py # Run all modes
|
||||
python test_write_buffer.py fp8 # Run FP8 only
|
||||
python test_write_buffer.py bf16 # Run BF16 only
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
|
||||
from kt_kernel import kt_kernel_ext
|
||||
from kt_kernel_ext import CPUInfer
|
||||
|
||||
|
||||
def make_cpu_infer(thread_num=80):
    """Return a CPUInfer instance using *thread_num* CPU worker threads."""
    runtime = CPUInfer(thread_num)
    return runtime
|
||||
|
||||
|
||||
def div_up(a, b):
    """Integer ceiling division: smallest integer >= a / b (for positive b)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)
|
||||
|
||||
|
||||
def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Assemble an FP8 MOEConfig wired to *cpuinfer*'s thread pool.

    Weight/scale pointers are left for the caller to assign.
    """
    moe_cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    moe_cfg.max_len = 1
    # FP8 path: 8-bit weights with (group_size x group_size) block scales, no zero point.
    moe_cfg.quant_config.bits = 8
    moe_cfg.quant_config.group_size = group_size
    moe_cfg.quant_config.zero_point = False
    moe_cfg.pool = cpuinfer.backend_
    return moe_cfg
|
||||
|
||||
|
||||
def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
    """Assemble a native-BF16 MOEConfig (no quantization fields configured)."""
    moe_cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    moe_cfg.max_len = 1
    moe_cfg.pool = cpuinfer.backend_
    return moe_cfg
|
||||
|
||||
|
||||
def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size):
    """Allocate random FP8 weights (uint8 bytes) plus float32 block scales.

    Returns a dict of flattened per-expert weight/scale tensors together with
    the per-matrix element counts needed to slice them later.
    """
    # FP8 stores one byte per weight element.
    per_mat_weight_bytes = hidden_size * intermediate_size

    # One float32 scale per (group_size x group_size) block.
    blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
    blocks_k = (hidden_size + group_size - 1) // group_size
    per_mat_scale_elems_gate_up = blocks_n_gate_up * blocks_k

    # down projection swaps dimensions: n=hidden_size, k=intermediate_size.
    per_mat_scale_elems_down = blocks_k * blocks_n_gate_up

    weight_shape = (expert_num * per_mat_weight_bytes,)
    return {
        "gate_q": torch.randint(0, 256, weight_shape, dtype=torch.uint8),
        "up_q": torch.randint(0, 256, weight_shape, dtype=torch.uint8),
        "down_q": torch.randint(0, 256, weight_shape, dtype=torch.uint8),
        "gate_scale": torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32),
        "up_scale": torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32),
        "down_scale": torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32),
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
        "per_mat_scale_elems_down": per_mat_scale_elems_down,
    }
|
||||
|
||||
|
||||
def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
    """Allocate random native BF16 weight tensors for testing (no scales).

    Returns a dict with flattened gate/up/down projection tensors plus the
    per-matrix element and byte counts.
    """
    per_mat_weight_elems = hidden_size * intermediate_size
    # BF16 occupies two bytes per element.
    per_mat_weight_bytes = 2 * per_mat_weight_elems

    total_elems = expert_num * per_mat_weight_elems
    return {
        "gate_proj": torch.randn(total_elems, dtype=torch.bfloat16),
        "up_proj": torch.randn(total_elems, dtype=torch.bfloat16),
        "down_proj": torch.randn(total_elems, dtype=torch.bfloat16),
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_weight_elems": per_mat_weight_elems,
    }
|
||||
|
||||
|
||||
def test_fp8_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with FP8 weights.

    Loads random FP8 weights into an AMXFP8_MOE, exports them per TP slice
    via write_weight_scale_to_buffer, and verifies the exported bytes/scales
    against a pure-Python slicing of the originals. Raises AssertionError on
    any mismatch.
    """
    torch.manual_seed(123)  # deterministic inputs so failures reproduce

    expert_num = 256
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536
    group_size = 128

    cpuinfer = make_cpu_infer()
    cfg = build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
    weights = allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)

    # Hand raw data pointers to the C++ side; the tensors in `weights` must
    # stay alive for the lifetime of the MoE object.
    cfg.gate_proj = weights["gate_q"].data_ptr()
    cfg.up_proj = weights["up_q"].data_ptr()
    cfg.down_proj = weights["down_q"].data_ptr()
    cfg.gate_scale = weights["gate_scale"].data_ptr()
    cfg.up_scale = weights["up_scale"].data_ptr()
    cfg.down_scale = weights["down_scale"].data_ptr()

    moe = kt_kernel_ext.moe.AMXFP8_MOE(cfg)

    # Identity mapping: physical expert i == logical expert i.
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]

    # Calculate sizes per TP part
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)

    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # Create buffer lists (w13 interleaves gate+up, hence the factor of 2)
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]

    print(f"[FP8] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[FP8] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"[FP8] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")

    # Build per-expert destination pointers; the task writes one expert at a
    # time, so each pointer is pre-offset to that expert's slot.
    def get_expert_ptrs(expert_id):
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []
        for tp_idx in range(gpu_tp_count):
            # w13 layout interleaves experts: [expert0_gate, expert0_up, expert1_gate, ...]
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down

            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    # Warm up
    for _ in range(2):
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
    cpuinfer.sync()

    # Timing
    begin_time = time.perf_counter_ns()
    for expert_id in range(gpu_experts):
        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                expert_id=expert_id,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
                w2_scale_ptrs=w2_scale_ptrs,
            )
        )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6

    # weights are 1 byte/elem across 3 matrices; scales are float32 (4 bytes)
    total_bytes = (
        hidden_size * intermediate_size * gpu_experts * 3
        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
    )
    print(f"[FP8] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"[FP8] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    # Verify correctness
    def split_expert_tensor(tensor, chunk):
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    gate_q = weights["gate_q"]
    up_q = weights["up_q"]
    down_q = weights["down_q"]
    gate_scale = weights["gate_scale"]
    up_scale = weights["up_scale"]
    down_scale = weights["down_scale"]

    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)

    # Block grid of the down matrix (n=hidden_size, k=intermediate_size)
    n_blocks_n = div_up(hidden_size, group_size)
    n_blocks_k = div_up(intermediate_size, group_size)
    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count

    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w13_scales = []
        expected_w2_weights = []
        expected_w2_scales = []

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count

        for expert_id in range(gpu_experts):
            # w13 (gate/up) is sliced contiguously along the n direction.
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]

            # down is sliced column-wise, so gather the TP slice row by row.
            down_weight_tp_parts = []
            down_scale_tp_parts = []
            tp_slice_weight_size = intermediate_size // gpu_tp_count

            for row_idx in range(hidden_size):
                row_weight_start = row_idx * intermediate_size
                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
                down_weight_tp_parts.append(
                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                )

            # Scales exist only at block granularity, one row per n-block.
            for bn in range(n_blocks_n):
                row_scale_start = bn * n_blocks_k
                tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp
                down_scale_tp_parts.append(
                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]
                )

            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Expected layout interleaves gate/up per expert: [gate0, up0, gate1, up1, ...]
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
            expected_w13_scales.append(up_scale_tp)
            expected_w2_weights.append(down_weight_tp)
            expected_w2_scales.append(down_scale_tp)

        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w13_scale = torch.cat(expected_w13_scales)
        expected_w2_weight = torch.cat(expected_w2_weights)
        expected_w2_scale = torch.cat(expected_w2_scales)

        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            raise AssertionError(f"[FP8] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}")

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            raise AssertionError(f"[FP8] w13 scale mismatch for TP {tp_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            raise AssertionError(f"[FP8] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}")

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            raise AssertionError(f"[FP8] w2 scale mismatch for TP {tp_idx}")

    print(f"[FP8] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
|
||||
|
||||
|
||||
def test_bf16_write_buffer(gpu_tp_count: int) -> bool:
    """Test write_weight_scale_to_buffer with BF16 weights (no scales).

    Loads native BF16 expert weights into an AMXBF16_MOE operator, asks the
    operator to write each expert's (packed) weights back out into per-GPU-TP
    n-major buffers, times that export, and finally verifies the exported
    bytes against a pure-Python reference slicing of the original weights.

    Args:
        gpu_tp_count: Number of GPU tensor-parallel parts the weights are
            split across (gate/up split row-wise, down split column-wise).

    Returns:
        True on success.

    Raises:
        AssertionError: If any exported weight bytes differ from the
            reference TP slicing.
    """
    torch.manual_seed(123)

    # Small model shape so the test runs quickly; all experts go to "GPU".
    expert_num = 16
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536

    cpuinfer = make_cpu_infer()
    cfg = build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    weights = allocate_weights_bf16(expert_num, hidden_size, intermediate_size)

    cfg.gate_proj = weights["gate_proj"].data_ptr()
    cfg.up_proj = weights["up_proj"].data_ptr()
    cfg.down_proj = weights["down_proj"].data_ptr()
    # BF16 is unquantized: no scale tensors, so pass null pointers.
    cfg.gate_scale = 0
    cfg.up_scale = 0
    cfg.down_scale = 0

    moe = kt_kernel_ext.moe.AMXBF16_MOE(cfg)

    # Identity expert mapping: physical expert i == logical expert i.
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_elems = weights["per_mat_weight_elems"]

    # Calculate sizes per TP part (BF16 = 2 bytes per element)
    weight_elems_per_expert_per_tp = per_mat_weight_elems // gpu_tp_count
    weight_bytes_per_expert_per_tp = weight_elems_per_expert_per_tp * 2

    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp

    # Create buffer lists (BF16: weights only, no scales).
    # w13 holds gate AND up per expert, hence the factor of 2.
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    # Empty scale buffers (not used for BF16 but needed for interface)
    w13_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]

    print(f"[BF16] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[BF16] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")

    def get_expert_ptrs(expert_id):
        # Build raw pointer lists (one per TP part) for this expert's slots
        # in the destination buffers. Experts are laid out contiguously.
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []
        for tp_idx in range(gpu_tp_count):
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp

            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr())  # Not used
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr())  # Not used
        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    # Warm up
    for _ in range(2):
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Timing: one export pass over all experts, synced once at the end.
    begin_time = time.perf_counter_ns()
    for expert_id in range(gpu_experts):
        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                expert_id=expert_id,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
                w2_scale_ptrs=w2_scale_ptrs,
            )
        )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6

    # 3 matrices (gate/up/down), each intermediate x hidden elements.
    total_bytes = hidden_size * intermediate_size * gpu_experts * 3 * 2  # BF16 = 2 bytes
    print(f"[BF16] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    # bytes / (ms * 1e6) == GB/s (1e9 bytes per GB, 1e-3 s per ms).
    print(f"[BF16] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    # Verify correctness (BF16: weights only, no scales)
    def split_expert_tensor(tensor, chunk):
        # Slice a flat byte tensor into one contiguous chunk per expert.
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    gate_proj = weights["gate_proj"]
    up_proj = weights["up_proj"]
    down_proj = weights["down_proj"]

    # View BF16 as uint8 for byte-level comparison
    gate_bytes = gate_proj.view(torch.uint8)
    up_bytes = up_proj.view(torch.uint8)
    down_bytes = down_proj.view(torch.uint8)

    per_mat_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes
    gate_experts = split_expert_tensor(gate_bytes, per_mat_bytes)
    up_experts = split_expert_tensor(up_bytes, per_mat_bytes)
    down_experts = split_expert_tensor(down_bytes, per_mat_bytes)

    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w2_weights = []

        weight_bytes_per_tp = per_mat_bytes // gpu_tp_count

        for expert_id in range(gpu_experts):
            # Gate/up matrices are split row-wise: TP part tp_idx owns a
            # contiguous byte range of each expert's matrix.
            start_weight = tp_idx * weight_bytes_per_tp
            end_weight = (tp_idx + 1) * weight_bytes_per_tp

            gate_weight_tp = gate_experts[expert_id][start_weight:end_weight]
            up_weight_tp = up_experts[expert_id][start_weight:end_weight]

            # Down matrix: sliced column-wise (BF16 = 2 bytes per element)
            down_weight_tp_parts = []
            tp_slice_elems = intermediate_size // gpu_tp_count
            tp_slice_bytes = tp_slice_elems * 2

            for row_idx in range(hidden_size):
                row_byte_start = row_idx * intermediate_size * 2
                tp_byte_offset = row_byte_start + tp_idx * tp_slice_bytes
                down_weight_tp_parts.append(down_experts[expert_id][tp_byte_offset : tp_byte_offset + tp_slice_bytes])

            down_weight_tp = torch.cat(down_weight_tp_parts)

            # Per-expert layout in w13 buffers is gate then up.
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w2_weights.append(down_weight_tp)

        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w2_weight = torch.cat(expected_w2_weights)

        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            raise AssertionError(f"[BF16] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            raise AssertionError(f"[BF16] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}")

    print(f"[BF16] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
|
||||
|
||||
|
||||
def test_with_tp(quant_mode: str, gpu_tp_count: int):
    """Run the write_weight_scale_to_buffer test for one quant mode / TP count.

    Args:
        quant_mode: Either "fp8" or "bf16".
        gpu_tp_count: Number of GPU tensor-parallel parts to test.

    Returns:
        Whatever the underlying test function returns (True on success).

    Raises:
        ValueError: If quant_mode is not a supported mode.
    """
    runners = {
        "fp8": test_fp8_write_buffer,
        "bf16": test_bf16_write_buffer,
    }
    runner = runners.get(quant_mode)
    if runner is None:
        raise ValueError(f"Unsupported quant_mode: {quant_mode}")
    return runner(gpu_tp_count)
|
||||
|
||||
|
||||
def main(quant_modes=None):
    """Run write_weight_scale_to_buffer tests for the given quant modes.

    Each mode is exercised with gpu_tp_count values of 1, 2, and 4; a
    per-combination summary is printed at the end.

    Args:
        quant_modes: Optional list of modes to test ("fp8", "bf16").
            Defaults to testing both.

    Side effects:
        Prints progress and a summary; calls sys.exit(1) if any test fails.
    """
    if quant_modes is None:
        quant_modes = ["fp8", "bf16"]

    tp_values = [1, 2, 4]
    all_passed = True
    results = {}

    for quant_mode in quant_modes:
        print("\n" + "=" * 60)
        print(f"Testing {quant_mode.upper()} write_weight_scale_to_buffer")
        print("=" * 60)

        for tp in tp_values:
            print(f"\n--- Testing {quant_mode.upper()} with gpu_tp_count = {tp} ---")
            try:
                test_with_tp(quant_mode, tp)
                results[(quant_mode, tp)] = "PASSED"
            except Exception as e:
                results[(quant_mode, tp)] = f"FAILED: {e}"
                all_passed = False
                print(f"[{quant_mode.upper()}] TP={tp} FAILED: {e}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for (mode, tp), result in results.items():
        # Exact comparison, not substring matching: a failure message
        # ("FAILED: <exception text>") could itself contain "PASSED" and
        # would previously have been mislabeled as PASS.
        status = "PASS" if result == "PASSED" else "FAIL"
        print(f"  [{status}] {mode.upper()} TP={tp}: {result}")

    if all_passed:
        print("\nALL TESTS PASSED")
    else:
        print("\nSOME TESTS FAILED")
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: an optional single argument restricts the run to one
    # quant mode; with no arguments both modes are tested.
    cli_args = sys.argv[1:]
    if not cli_args:
        main()
    else:
        mode = cli_args[0].lower()
        if mode not in ("fp8", "bf16"):
            print(f"Unknown mode: {mode}. Use 'fp8' or 'bf16'")
            sys.exit(1)
        main([mode])
|
||||
@@ -37,7 +37,8 @@ static const bool _is_plain_ = false;
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
#include "operators/amx/awq-moe.hpp"
|
||||
#if defined(__AVX512BF16__)
|
||||
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
|
||||
#include "operators/amx/bf16-moe.hpp" // Native BF16 MoE using CRTP pattern
|
||||
#include "operators/amx/fp8-moe.hpp" // FP8 MoE requires AVX512 BF16 support
|
||||
#endif
|
||||
#include "operators/amx/k2-moe.hpp"
|
||||
#include "operators/amx/la/amx_kernels.hpp"
|
||||
@@ -340,6 +341,51 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
|
||||
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
|
||||
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
|
||||
}
|
||||
|
||||
// BF16 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
|
||||
// Only available on CPUs with AVX512 BF16 support
|
||||
if constexpr (std::is_same_v<MoeTP, AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>) {
|
||||
struct WriteWeightScaleToBufferBindings {
|
||||
struct Args {
|
||||
CPUInfer* cpuinfer;
|
||||
MoeClass* moe;
|
||||
int gpu_tp_count;
|
||||
int expert_id;
|
||||
std::vector<uintptr_t> w13_weight_ptrs;
|
||||
std::vector<uintptr_t> w13_scale_ptrs;
|
||||
std::vector<uintptr_t> w2_weight_ptrs;
|
||||
std::vector<uintptr_t> w2_scale_ptrs;
|
||||
};
|
||||
|
||||
static void inner(void* args) {
|
||||
Args* args_ = (Args*)args;
|
||||
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
|
||||
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
|
||||
args_->w2_scale_ptrs);
|
||||
}
|
||||
|
||||
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
|
||||
int expert_id, py::list w13_weight_ptrs,
|
||||
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
|
||||
py::list w2_scale_ptrs) {
|
||||
// Convert Python lists to std::vector<uintptr_t>
|
||||
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
|
||||
|
||||
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
|
||||
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
|
||||
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
|
||||
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
|
||||
|
||||
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
|
||||
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
|
||||
return std::make_pair((intptr_t)&inner, (intptr_t)args);
|
||||
}
|
||||
};
|
||||
|
||||
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
|
||||
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
|
||||
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
|
||||
}
|
||||
#endif // __AVX512BF16__
|
||||
#endif
|
||||
}
|
||||
@@ -606,13 +652,13 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
|
||||
bind_moe_module<LLAMA_MOE_TP>(moe_module, "MOE");
|
||||
|
||||
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224BF>>(moe_module, "AMXBF16_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int8>>(moe_module, "AMXInt8_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, "AMXInt4_MOE");
|
||||
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
|
||||
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
|
||||
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
|
||||
#if defined(__AVX512BF16__)
|
||||
bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, "AMXBF16_MOE");
|
||||
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
|
||||
#endif
|
||||
#endif
|
||||
|
||||
536
kt-kernel/operators/amx/bf16-moe.hpp
Normal file
536
kt-kernel/operators/amx/bf16-moe.hpp
Normal file
@@ -0,0 +1,536 @@
|
||||
/**
|
||||
* @Description : BF16 AMX MoE operator for native BF16 inference
|
||||
* @Author : oql, Codex and Claude
|
||||
* @Date : 2026-01-06
|
||||
* @Version : 1.0.0
|
||||
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||
*
|
||||
* This file implements BF16 MoE using CRTP pattern, inheriting from moe_base.hpp.
|
||||
* BF16 weights are stored without quantization (no scales).
|
||||
**/
|
||||
#ifndef CPUINFER_OPERATOR_AMX_BF16_MOE_H
|
||||
#define CPUINFER_OPERATOR_AMX_BF16_MOE_H
|
||||
|
||||
// #define DEBUG_BF16_MOE
|
||||
|
||||
#include "la/amx_kernels.hpp" // For vec_mul/mat_mul
|
||||
#include "la/amx_raw_buffers.hpp"
|
||||
#include "la/amx_raw_kernels.hpp"
|
||||
#include "la/amx_utils.hpp" // For transpose_16x16_32bit
|
||||
#include "moe_base.hpp"
|
||||
|
||||
/**
|
||||
* @brief BF16 MoE operator using CRTP pattern
|
||||
* @tparam T Kernel type, defaults to GemmKernel224BF16
|
||||
*
|
||||
* This class provides BF16-specific implementations:
|
||||
* - do_gate_up_gemm, do_down_gemm: BF16 weight mat mul (no quantization)
|
||||
* - load_weights: Load native BF16 weights (no scales)
|
||||
*/
|
||||
template <class T = amx::GemmKernel224BF16>
class AMX_BF16_MOE_TP : public AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>> {
  using Base = AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>>;
  // Pull base-class state into scope: config, per-expert A/B/C GEMM buffers
  // for gate/up/down, per-expert token counts, and this shard's TP index.
  using Base::config_;
  using Base::down_ba_;
  using Base::down_bb_;
  using Base::down_bc_;
  using Base::gate_bb_;
  using Base::gate_bc_;
  using Base::gate_up_ba_;
  using Base::m_local_num_;
  using Base::tp_part_idx;
  using Base::up_bb_;
  using Base::up_bc_;

 public:
  using typename Base::input_t;
  using typename Base::output_t;

  AMX_BF16_MOE_TP() = default;

  AMX_BF16_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
    // Initialization now happens in derived_init() which is called by base constructor
  }

  // CRTP hook invoked by the base constructor; BF16 is unquantized so there
  // is no quant_config to validate — just log where this shard was created.
  void derived_init() {
    // BF16 has no quantization, no need to check quant_config
    printf("Created AMX_BF16_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
  }

  ~AMX_BF16_MOE_TP() = default;

  // ============================================================================
  // CRTP buffer creation - without group_size
  // (Quantized variants take an extra group_size parameter; BF16 does not.)
  // ============================================================================

  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }

  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k);  // 2 parameters - no group_size
  }

  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }

  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, data);  // 2 parameters - no group_size
  }

  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // ============================================================================
  // CRTP virtual points - GEMM dispatch
  // ============================================================================

  // Run the gate (do_up == false) or up (do_up == true) projection GEMM for
  // one expert. Chooses mat_mul for large batches and vec_mul for small ones
  // based on the expected tokens-per-expert ratio.
  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];

    // Use vec_mul/mat_mul (no group_size)
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    }
  }

  // Run the down projection GEMM for one expert; same mat_mul/vec_mul
  // heuristic as do_gate_up_gemm.
  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];

    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],
                   down_bc_[expert_idx], ith, nth);
    } else {
      amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],
                   down_bc_[expert_idx], ith, nth);
    }
  }

#ifdef DEBUG_BF16_MOE
  // Function to dump Buffer B data for debugging
  inline void dump_buffer_b(int expert_idx, const std::string& matrix_type, typename T::BufferB* buffer) {
    printf("[DUMP_BUFFER_B] TP%d BF16 Expert%d %s:\n", tp_part_idx, expert_idx, matrix_type.c_str());

    // Calculate dimensions based on matrix type
    int rows, cols;
    if (matrix_type == "gate" || matrix_type == "up") {
      rows = config_.intermediate_size;
      cols = config_.hidden_size;
    } else {  // down
      rows = config_.hidden_size;
      cols = config_.intermediate_size;
    }

    // Dump BF16 weights
    size_t weight_size = (size_t)rows * cols;
    ggml_bf16_t* weight_ptr = buffer->b;

    printf("  BF16 Weights[first 16]: ");
    for (int i = 0; i < std::min(16, (int)weight_size); i++) {
      printf("%.6f ", ggml_bf16_to_fp32(weight_ptr[i]));
    }
    printf("\n");

    if (weight_size > 16) {
      printf("  BF16 Weights[last 16]: ");
      int start_idx = std::max(0, (int)weight_size - 16);
      for (int i = start_idx; i < (int)weight_size; i++) {
        printf("%.6f ", ggml_bf16_to_fp32(weight_ptr[i]));
      }
      printf("\n");
    }

    printf("  Matrix dimensions: %dx%d (n x k)\n", rows, cols);
  }
#endif

  /**
   * @brief Load BF16 weights from contiguous memory layout
   *
   * Loads weights from config_.gate_proj, up_proj, down_proj (no scales).
   * Each expert's matrix is repacked into BufferB tile format via from_mat,
   * with work distributed across the NUMA subpool for this TP shard.
   * physical_to_logical_map redirects physical expert slots to logical
   * expert indices in the source weight arrays.
   */
  void load_weights() {
    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
    auto pool = config_.pool->get_subpool(tp_part_idx);

    if (config_.gate_proj == nullptr) {
      throw std::runtime_error("BF16 MOE requires native BF16 weight.");
    }

    // Load gate + up weights
    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          // Gate: from BF16 data (no scale)
          gate_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);  // 3 parameters: (bf16*, ith, nth)

          // Up: same
          up_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);
        },
        nullptr);

    // Load down weights
    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;

          // Down
          down_bb_[expert_idx]->from_mat(
              (ggml_bf16_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              ith, nth);
        },
        nullptr);

#ifdef DEBUG_BF16_MOE
    dump_buffer_b(0, "gate", gate_bb_[0].get());
    dump_buffer_b(0, "down", down_bb_[0].get());
#endif
  }

  // Fast 64-byte (512-bit) memcpy using AVX512
  static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {
    __m512i data = _mm512_loadu_si512(src);
    _mm512_storeu_si512(dst, data);
  }

  // Fast memcpy for arbitrary sizes using AVX512: copies 64-byte chunks
  // with 512-bit loads/stores, then falls back to std::memcpy for the tail.
  // NOTE(review): not called within this class as visible here — presumably
  // used by the TP_MOE wrapper or kept for future use; confirm before removing.
  static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
    uint8_t* d = (uint8_t*)dst;
    const uint8_t* s = (const uint8_t*)src;
    size_t chunks = bytes / 64;
    for (size_t i = 0; i < chunks; i++) {
      fast_memcpy_64(d, s);
      d += 64;
      s += 64;
    }
    bytes -= chunks * 64;
    if (bytes > 0) {
      std::memcpy(d, s, bytes);
    }
  }

  /**
   * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format (BF16 version)
   *
   * This is the inverse of the packing done in BufferBBF16Impl::from_mat.
   * BF16 elements are 2 bytes, and the packed format includes 16x16 32-bit transpose.
   * The 32x32 BF16 block is viewed as two 16x16 blocks of 32-bit lanes
   * (each lane holding a pair of BF16 values); transposing each block again
   * undoes the packing since the transpose is self-inverse.
   *
   * @param src Pointer to packed data (N_STEP * K_STEP * 2 bytes in packed layout)
   * @param dst Pointer to destination in n-major layout
   * @param dst_row_stride Row stride in destination buffer (number of BF16 elements per row)
   */
  static inline void unpack_nk_block_bf16(const ggml_bf16_t* src, ggml_bf16_t* dst, size_t dst_row_stride) {
    constexpr int N_STEP = T::N_STEP;  // 32
    constexpr int K_STEP = T::K_STEP;  // 32
    constexpr int TILE_N = T::TILE_N;  // 16

    // The packed format has two 16x16 blocks (32-bit view) that were transposed
    // We need to reverse the transpose first, then copy to n-major layout

    // Create aligned temporary buffers for transpose
    alignas(64) __m512i temp_block1[TILE_N];
    alignas(64) __m512i temp_block2[TILE_N];

    // Copy source data to temporary buffers
    const __m512i* src_vec = reinterpret_cast<const __m512i*>(src);
    for (int i = 0; i < TILE_N; i++) {
      temp_block1[i] = src_vec[i];
      temp_block2[i] = src_vec[TILE_N + i];
    }

    // Reverse transpose (transpose is self-inverse)
    amx::transpose_16x16_32bit(temp_block1);
    amx::transpose_16x16_32bit(temp_block2);

    // Copy transposed data to destination in n-major layout
    const ggml_bf16_t* temp1_bf16 = reinterpret_cast<const ggml_bf16_t*>(temp_block1);
    const ggml_bf16_t* temp2_bf16 = reinterpret_cast<const ggml_bf16_t*>(temp_block2);

    // First 16 rows (block 1)
    for (int i = 0; i < TILE_N; i++) {
      std::memcpy(dst + i * dst_row_stride, temp1_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t));
    }

    // Next 16 rows (block 2)
    for (int i = 0; i < TILE_N; i++) {
      std::memcpy(dst + (TILE_N + i) * dst_row_stride, temp2_bf16 + i * K_STEP, K_STEP * sizeof(ggml_bf16_t));
    }
  }

  /**
   * @brief Reconstruct weights for a single expert to the output buffers
   *
   * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.
   * BF16 version - no scales needed.
   *
   * NOTE(review): exposed to Python as write_weight_scale_to_buffer_task —
   * presumably the TP_MOE wrapper's write_weight_scale_to_buffer dispatches
   * here per TP shard; confirm against the wrapper implementation.
   *
   * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)
   * @param cpu_tp_count Number of CPU TP parts
   * @param expert_id Expert index to process
   * @param full_config Full configuration (before CPU TP split)
   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)
   * @param w13_scale_ptrs Pointers to gate+up scale buffers (unused for BF16, kept for interface compatibility)
   * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)
   * @param w2_scale_ptrs Pointers to down scale buffers (unused for BF16, kept for interface compatibility)
   */
  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,
                               const std::vector<uintptr_t>& w2_weight_ptrs,
                               [[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {
    auto& config = config_;
    auto pool = config.pool->get_subpool(tp_part_idx);

    constexpr int N_STEP = T::N_STEP;
    constexpr int K_STEP = T::K_STEP;
    constexpr int N_BLOCK = T::N_BLOCK;
    constexpr int K_BLOCK = T::K_BLOCK;

    // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========
    const int cpu_n_w13 = config.intermediate_size;
    const int cpu_k_w13 = config.hidden_size;
    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
    const int gpu_k_w13 = full_config.hidden_size;
    // This CPU shard's row offset in the full (pre-CPU-TP-split) matrix.
    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;

    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;

    // ========= W2 (down): Shape [hidden, intermediate], split by K =========
    const int cpu_n_w2 = config.hidden_size;
    const int cpu_k_w2 = config.intermediate_size;
    const int gpu_n_w2 = full_config.hidden_size;
    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;

    // ========= Optimized job layout =========
    constexpr int NUM_W13_TASKS = 32;  // Per matrix (gate or up), total 64 for w13
    constexpr int NUM_W2_TASKS = 32;   // For down matrix

    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;

    // Calculate N_STEP blocks per task
    const int w13_n_steps = div_up(cpu_n_w13, N_STEP);
    const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);
    const int w2_n_steps = div_up(cpu_n_w2, N_STEP);
    const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);

    pool->do_work_stealing_job(
        total_tasks, nullptr,
        [=, &w13_weight_ptrs, &w2_weight_ptrs, this](int task_id) {
          if (task_id < NUM_W13_TASKS * 2) {
            // ========= W13 weight task: process chunk of rows x full K =========
            // Tasks [0, NUM_W13_TASKS) handle gate, the rest handle up.
            const bool is_up = task_id >= NUM_W13_TASKS;
            const int chunk_idx = task_id % NUM_W13_TASKS;
            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];

            const int step_start = chunk_idx * w13_steps_per_task;
            const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);
            if (step_start >= w13_n_steps) return;
            const int chunk_n_start = step_start * N_STEP;
            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);

            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
              // Map this shard-local row to the owning GPU TP part and its
              // row index inside that part.
              const int global_n = global_n_offset_w13 + local_n_start;
              const int target_gpu = global_n / gpu_n_w13;
              const int n_in_gpu = global_n % gpu_n_w13;

              ggml_bf16_t* weight_base = (ggml_bf16_t*)w13_weight_ptrs[target_gpu];
              // Gate occupies the first matrix slot, up the second.
              const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;

              const int n_block_idx = local_n_start / N_BLOCK;
              const int n_block_begin = n_block_idx * N_BLOCK;
              const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);
              const int n_in_block = local_n_start - n_block_begin;

              for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {
                const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);

                for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
                  // Address the packed N_STEP x K_STEP tile inside BufferB's
                  // block-of-blocks layout, then unpack it row-major.
                  const ggml_bf16_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +
                                           (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
                                           (size_t)k_begin * N_STEP;
                  ggml_bf16_t* dst =
                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
                  unpack_nk_block_bf16(src, dst, gpu_k_w13);
                }
              }
            }

          } else {
            // ========= W2 weight task: process chunk of rows x all K slices =========
            const int chunk_idx = task_id - NUM_W13_TASKS * 2;
            const auto& bb = down_bb_[expert_id];

            const int step_start = chunk_idx * w2_steps_per_task;
            const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);
            if (step_start >= w2_n_steps) return;
            const int chunk_n_start = step_start * N_STEP;
            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);

            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
              const int n_block_idx = local_n_start / N_BLOCK;
              const int n_block_begin = n_block_idx * N_BLOCK;
              const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);
              const int n_in_block = local_n_start - n_block_begin;

              // Down is split along K across GPU parts: walk each K slice
              // this CPU shard contributes to and route it to its GPU part.
              for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {
                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);

                const int global_k_start = global_k_offset_w2 + k_slice_start;
                const int target_gpu = global_k_start / gpu_k_w2;
                const int k_in_gpu_base = global_k_start % gpu_k_w2;

                ggml_bf16_t* weight_base = (ggml_bf16_t*)w2_weight_ptrs[target_gpu];

                for (int k_abs = k_slice_start; k_abs < k_slice_end; k_abs += K_STEP) {
                  const int k_block_idx = k_abs / K_BLOCK;
                  const int k_block_begin = k_block_idx * K_BLOCK;
                  const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);
                  const int k_in_block = k_abs - k_block_begin;
                  const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);

                  const ggml_bf16_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +
                                           (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
                                           (size_t)k_in_block * N_STEP;
                  ggml_bf16_t* dst = weight_base + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
                  unpack_nk_block_bf16(src, dst, gpu_k_w2);
                }
              }
            }
          }
        },
        nullptr);
  }
};
|
||||
|
||||
// Tensor-parallel MoE specialization for the native BF16 AMX backend.
//
// Unlike the quantized specializations, load_weights() copies BF16 weights
// verbatim into per-NUMA staging buffers (no scales, no dequantization):
// gate/up get a contiguous per-part chunk of each expert matrix, down gets a
// per-part column slice copied row by row. DO_TPS_LOAD_WEIGHTS then lets each
// TP part repack from the staging buffers, which are freed afterwards.
template <typename K>
class TP_MOE<AMX_BF16_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_BF16_MOE_TP<K>>> {
 public:
  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_BF16_MOE_TP<K>>>;
  using Base::Base;

  // Load BF16 expert weights from either per-expert pointer tables
  // (config.gate_projs et al.) or one flat contiguous buffer
  // (config.gate_proj et al.), splitting them across TP/NUMA parts.
  // Throws std::runtime_error if neither weight source is provided.
  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    // NOTE(review): tp_count is bound here but never used in this method.
    auto& tp_count = this->tp_count;
    auto pool = config.pool;
    // Optional remap table: physical expert slot -> logical expert id.
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

    // BF16 has no quantization config to validate; only check that at least
    // one weight source (per-expert pointers or flat buffer) exists.
    if (config.gate_projs.empty() && config.gate_proj == nullptr) {
      throw std::runtime_error("no weight source");
    }

    const bool use_per_expert_ptrs = !config.gate_projs.empty();
    // Element count of one full (unsplit) expert matrix.
    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;

    // One job per NUMA node / TP part: allocate staging buffers and copy this
    // part's slice of every expert.
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      // Element count of one expert matrix slice owned by this TP part.
      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;

      // Allocate BF16 staging buffers (2 bytes/element); freed below after
      // DO_TPS_LOAD_WEIGHTS has consumed them.
      tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
      tpc.up_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];
      tpc.down_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];

      // NOTE(review): tp_idx is unused.
      const size_t tp_idx = (size_t)i;
      // gate/up are split along intermediate_size, so part i starts at a
      // contiguous element offset; down is split along columns instead.
      const size_t gate_up_weight_src_offset = i * tp_weight_elems;
      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;

      // Parallelize the copy across experts within this NUMA node's subpool.
      pool->get_subpool(i)->do_work_stealing_job(
          tpc.expert_num, nullptr,
          // NOTE(review): "[&, &tpc]" repeats the by-reference default
          // capture; standard C++ rejects "&name" after "&" — confirm this
          // compiles on the target toolchain.
          [&, &tpc](int expert_id_) {
            // expert_map presumably translates a physical slot to its logical
            // expert id via physical_to_logical_map (defined elsewhere).
            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);

            ggml_bf16_t* gate_dst = (ggml_bf16_t*)tpc.gate_proj + expert_id * tp_weight_elems;
            ggml_bf16_t* up_dst = (ggml_bf16_t*)tpc.up_proj + expert_id * tp_weight_elems;
            ggml_bf16_t* down_dst = (ggml_bf16_t*)tpc.down_proj + expert_id * tp_weight_elems;

            const ggml_bf16_t* gate_src;
            const ggml_bf16_t* up_src;
            const ggml_bf16_t* down_src;

            if (use_per_expert_ptrs) {
              // Per-expert pointer tables; only NUMA bucket 0 is read here.
              gate_src = (const ggml_bf16_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
              up_src = (const ggml_bf16_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
              down_src = (const ggml_bf16_t*)config.down_projs[0][expert_id];
            } else {
              // Flat contiguous buffers: experts laid out back to back.
              gate_src =
                  (const ggml_bf16_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              up_src = (const ggml_bf16_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              down_src = (const ggml_bf16_t*)config.down_proj + expert_id * full_weight_elems;
            }

            // Gate/up slices are contiguous in the source: one memcpy each.
            std::memcpy(gate_dst, gate_src, tp_weight_elems * sizeof(ggml_bf16_t));
            std::memcpy(up_dst, up_src, tp_weight_elems * sizeof(ggml_bf16_t));

            // Down is split column-wise: copy this part's column range from
            // each of the hidden_size rows.
            for (int row = 0; row < config.hidden_size; row++) {
              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset,
                          (size_t)tpc.intermediate_size * sizeof(ggml_bf16_t));
            }
          },
          nullptr);
    });

    // Let each TP part ingest/repack its staging buffers.
    DO_TPS_LOAD_WEIGHTS(pool);

    // Staging buffers are no longer needed once the parts loaded them.
    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      delete[] (ggml_bf16_t*)tpc.gate_proj;
      delete[] (ggml_bf16_t*)tpc.up_proj;
      delete[] (ggml_bf16_t*)tpc.down_proj;
    });

    this->weights_loaded = true;
  }

  /**
   * @brief Write weights to GPU buffer for all TP parts
   *
   * BF16 version - no scales needed; the scale_ptrs parameters are kept only
   * for interface compatibility with the quantized specializations and are
   * forwarded unused.
   *
   * @param gpu_tp_count  number of GPU TP partitions (must match ptr arrays)
   * @param expert_id     expert whose weights are written
   * @throws std::runtime_error if weights are not loaded, no TP parts exist,
   *         or the pointer array sizes don't match gpu_tp_count
   */
  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    if (this->weights_loaded == false) {
      throw std::runtime_error("Not Loaded");
    }
    if (this->tps.empty()) {
      throw std::runtime_error("No TP parts initialized");
    }
    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w2_weight_ptrs.size() != gpu_tp_count) {
      throw std::runtime_error("Weight pointer arrays size must match gpu_tp_count");
    }

    // Fan the write out to every TP part on its own NUMA node.
    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
    });
  }
};
|
||||
|
||||
#endif // CPUINFER_OPERATOR_AMX_BF16_MOE_H
|
||||
@@ -121,6 +121,137 @@ struct GemmKernel224BF16 {
|
||||
using BufferA = BufferABF16Impl<GemmKernel224BF16>;
|
||||
using BufferB = BufferBBF16Impl<GemmKernel224BF16>;
|
||||
using BufferC = BufferCFP32Impl<GemmKernel224BF16>;
|
||||
|
||||
// Basic AVX-512 BF16 kernel: accumulate one (M_STEP x 2x16-float) output tile
// over an entire K_BLOCK using _mm512_dpbf16_ps.
//
// The accumulator `c` holds two 512-bit float vectors per output row
// (c512[m_i*2], c512[m_i*2+1]); it is zeroed only when k_block_begin == 0, so
// successive K_BLOCK passes accumulate into the same tile.
// A-data is read as int32 lanes: each int32 packs a pair of adjacent BF16
// activations, broadcast to all lanes and dot-producted against B's packed
// BF16 pairs (this is the dpbf16ps operand layout).
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
                       BufferB* bb) {
  __m512* c512 = (__m512*)c;
  // Number of valid output rows in this M tile (handles the tail tile).
  int m_block_end = std::min(m - m_begin, M_STEP);

  // Zero out accumulator at the start of k_block
  if (k_block_begin == 0) {
    for (int m_i = 0; m_i < m_block_end; m_i++) {
      c512[m_i * 2] = _mm512_setzero_ps();
      c512[m_i * 2 + 1] = _mm512_setzero_ps();
    }
  }

  // Process entire K_BLOCK, clamped to the true K extent.
  for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {
    int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
    __m512bh* b512 = (__m512bh*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);

    for (int m_i = 0; m_i < m_block_end; m_i++) {
      // 16 int32 (= 32 BF16) of A per row per K_STEP.
      for (int k_i = 0; k_i < 16; k_i++) {
        __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);
        // b512[k_i] / b512[16+k_i]: low/high halves of the N tile.
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma, b512[k_i]);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma, b512[16 + k_i]);
      }
    }
  }
}
|
||||
|
||||
// Optimized AVX-512 BF16 kernel: same contract and results as avx_kernel, but
// processes 4 k_i per iteration (so each loaded B vector pair is reused across
// all rows) and unrolls the m loop by 2 for better instruction-level
// parallelism. Accumulator layout and the zero-on-first-K_BLOCK behavior are
// identical to avx_kernel.
static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
                         BufferB* bb) {
  __m512* c512 = (__m512*)c;
  // Valid output rows in this M tile (tail-safe).
  int m_block_end = std::min(m - m_begin, M_STEP);

  // Zero out accumulator at the start of k_block
  if (k_block_begin == 0) {
    for (int m_i = 0; m_i < m_block_end; m_i++) {
      c512[m_i * 2] = _mm512_setzero_ps();
      c512[m_i * 2 + 1] = _mm512_setzero_ps();
    }
  }

  // Process entire K_BLOCK (clamped to the true K extent)
  for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {
    int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);
    __m512bh* b512 = (__m512bh*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);

    // Process 4 k_i at once - load B vectors once and reuse across all m rows
    for (int k_i = 0; k_i < 16; k_i += 4) {
      // Load 4 B vector pairs (lo and hi halves of the N tile for each k_i)
      __m512bh b0_lo = b512[k_i];
      __m512bh b0_hi = b512[16 + k_i];
      __m512bh b1_lo = b512[k_i + 1];
      __m512bh b1_hi = b512[16 + k_i + 1];
      __m512bh b2_lo = b512[k_i + 2];
      __m512bh b2_hi = b512[16 + k_i + 2];
      __m512bh b3_lo = b512[k_i + 3];
      __m512bh b3_hi = b512[16 + k_i + 3];

      // Process m rows - unroll by 2 for better ILP
      int m_i = 0;
      for (; m_i + 1 < m_block_end; m_i += 2) {
        // Load A values for 2 rows, 4 k_i each (each int32 = one BF16 pair)
        __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);
        __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 1]);
        __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 2]);
        __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 3]);
        __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i]);
        __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 1]);
        __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 2]);
        __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 3]);

        // Process row 0
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, b0_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, b0_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, b1_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, b1_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, b2_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, b2_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, b3_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, b3_hi);

        // Process row 1
        c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, b0_lo);
        c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, b0_hi);
        c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, b1_lo);
        c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, b1_hi);
        c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, b2_lo);
        c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, b2_hi);
        c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, b3_lo);
        c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, b3_hi);
      }
      // Handle remaining row when m_block_end is odd
      for (; m_i < m_block_end; m_i++) {
        __m512bh ma0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);
        __m512bh ma1 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 1]);
        __m512bh ma2 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 2]);
        __m512bh ma3 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 3]);

        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, b0_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, b0_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, b1_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, b1_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, b2_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, b2_hi);
        c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, b3_lo);
        c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, b3_hi);
      }
    }
  }
}
|
||||
|
||||
// AMX kernel for BF16: process an entire K_BLOCK for one output tile using
// AMX tile registers (clean_c/load_a/load_b/run_tile/store_c are the tile
// helpers defined elsewhere in this struct/file).
//
// Accumulation across K_BLOCK passes goes through memory: the tile is zeroed
// on the first pass (k_block_begin == 0), reloaded from `c` on later passes,
// and stored back to `c` at the end with an N_STEP-float row stride.
static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,
                       BufferB* bb) {
  if (k_block_begin == 0) {
    // First K_BLOCK: start from a zeroed accumulator tile.
    clean_c();
  } else {
    // Subsequent K_BLOCKs: resume from the partial sums in `c`.
    load_c(c, N_STEP * sizeof(float));
  }

  // Feed K_STEP-wide A/B panels into the tile multiply, clamped to true K.
  for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {
    load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K_STEP * sizeof(ggml_bf16_t));
    load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K_STEP * sizeof(ggml_bf16_t));
    run_tile();
  }

  // Persist the (possibly partial) tile so the next K_BLOCK can reload it.
  store_c(c, N_STEP * sizeof(float));
}
|
||||
};
|
||||
|
||||
// FP8 (e4m3) AMX kernel that mirrors the GemmKernel224BF16 interface.
|
||||
@@ -427,7 +558,7 @@ void float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::Buf
|
||||
}
|
||||
} else {
|
||||
// Single call processes entire k_group
|
||||
K::avx_kernel_4(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size);
|
||||
K::avx_kernel(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size);
|
||||
}
|
||||
K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, reduce_c, ba, bb, k, k_group_size);
|
||||
}
|
||||
@@ -435,17 +566,45 @@ void float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::Buf
|
||||
}
|
||||
}
|
||||
|
||||
// inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
|
||||
// std::shared_ptr<GemmKernel224BF16::BufferB> bb,
|
||||
// std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
|
||||
// float_mat_mul_kgroup<GemmKernel224BF16, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||
// }
|
||||
// ============================================================================
|
||||
// GemmKernel224BF16 vec_mul/mat_mul
|
||||
// ============================================================================
|
||||
|
||||
// inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
|
||||
// std::shared_ptr<GemmKernel224BF16::BufferB> bb,
|
||||
// std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
|
||||
// float_mat_mul_kgroup<GemmKernel224BF16, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
|
||||
// }
|
||||
// Template function for BF16 mat_mul/vec_mul with AMX or AVX backend
|
||||
template <typename K, bool amx_or_avx = true>
|
||||
void float_mat_vec(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb, typename K::BufferC* bc,
|
||||
int ith, int nth) {
|
||||
assert(n % K::N_STEP == 0);
|
||||
assert(k % K::K_STEP == 0);
|
||||
|
||||
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
|
||||
|
||||
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {
|
||||
for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
|
||||
for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
|
||||
float* c = bc->get_submat(m, n, m_begin, n_begin);
|
||||
|
||||
if constexpr (amx_or_avx && AMX_AVAILABLE) {
|
||||
K::amx_kernel(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);
|
||||
} else {
|
||||
K::avx_kernel_4(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BF16 matrix-matrix multiply entry point: unwraps the shared_ptr buffers and
// forwards to the templated driver with the AMX path enabled.
inline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
                    std::shared_ptr<GemmKernel224BF16::BufferB> bb, std::shared_ptr<GemmKernel224BF16::BufferC> bc,
                    int ith, int nth) {
  auto* a_raw = ba.get();
  auto* b_raw = bb.get();
  auto* c_raw = bc.get();
  float_mat_vec<GemmKernel224BF16, true>(m, n, k, a_raw, b_raw, c_raw, ith, nth);
}
|
||||
|
||||
// BF16 matrix-vector multiply entry point: unwraps the shared_ptr buffers and
// forwards to the templated driver with the AMX path disabled (AVX fallback).
inline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
                    std::shared_ptr<GemmKernel224BF16::BufferB> bb, std::shared_ptr<GemmKernel224BF16::BufferC> bc,
                    int ith, int nth) {
  auto* a_raw = ba.get();
  auto* b_raw = bb.get();
  auto* c_raw = bc.get();
  float_mat_vec<GemmKernel224BF16, false>(m, n, k, a_raw, b_raw, c_raw, ith, nth);
}
|
||||
|
||||
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,
|
||||
std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,
|
||||
|
||||
@@ -77,7 +77,7 @@ class KTMoEWrapper:
|
||||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
|
||||
@@ -85,7 +85,7 @@ class KTMoEWrapper:
|
||||
# Select backend based on method
|
||||
if method in ["AMXINT4", "AMXINT8"]:
|
||||
backend_cls = AMXMoEWrapper
|
||||
elif method in ["RAWINT4", "FP8"]:
|
||||
elif method in ["RAWINT4", "FP8", "BF16"]:
|
||||
backend_cls = NativeMoEWrapper
|
||||
elif method == "LLAMAFILE":
|
||||
backend_cls = LlamafileMoEWrapper
|
||||
|
||||
@@ -4,16 +4,16 @@ import ctypes
|
||||
|
||||
# Use relative imports for package structure
|
||||
from ..experts_base import BaseMoEWrapper
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
|
||||
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader
|
||||
from kt_kernel_ext.moe import MOEConfig
|
||||
|
||||
try:
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
|
||||
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE
|
||||
|
||||
_HAS_AMX_SUPPORT = True
|
||||
except (ImportError, AttributeError):
|
||||
_HAS_AMX_SUPPORT = False
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
|
||||
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE, AMXBF16_MOE = None, None, None, None, None
|
||||
|
||||
from typing import Optional
|
||||
|
||||
@@ -304,7 +304,7 @@ class AMXMoEWrapper(BaseMoEWrapper):
|
||||
|
||||
|
||||
class NativeMoEWrapper(BaseMoEWrapper):
|
||||
"""Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
|
||||
"""Wrapper for RAWINT4/FP8/BF16 experts stored in compressed SafeTensor format."""
|
||||
|
||||
_native_loader_instance = None
|
||||
|
||||
@@ -330,6 +330,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
|
||||
if method == "FP8" and AMXFP8_MOE is None:
|
||||
raise RuntimeError("AMX backend with FP8 support is not available.")
|
||||
if method == "BF16" and AMXBF16_MOE is None:
|
||||
raise RuntimeError("AMX backend with BF16 support is not available.")
|
||||
|
||||
super().__init__(
|
||||
layer_idx=layer_idx,
|
||||
@@ -352,6 +354,8 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
|
||||
elif method == "FP8":
|
||||
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
|
||||
elif method == "BF16":
|
||||
NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
|
||||
self.loader = NativeMoEWrapper._native_loader_instance
|
||||
@@ -386,28 +390,42 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
self.up_weights = weights["up"]
|
||||
self.down_weights = weights["down"]
|
||||
|
||||
# Convert scales to bf16 individually
|
||||
# self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
|
||||
# self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
|
||||
# self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
|
||||
self.gate_scales = weights["gate_scale"]
|
||||
self.up_scales = weights["up_scale"]
|
||||
self.down_scales = weights["down_scale"]
|
||||
if self.method == "RAWINT4":
|
||||
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
|
||||
elif self.method == "FP8":
|
||||
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
|
||||
# BF16 has no scales, others have scales
|
||||
if self.method == "BF16":
|
||||
# BF16 doesn't have scales
|
||||
self.gate_scales = None
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
else:
|
||||
# Convert scales to bf16 individually
|
||||
# self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
|
||||
# self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
|
||||
# self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
|
||||
self.gate_scales = weights["gate_scale"]
|
||||
self.up_scales = weights["up_scale"]
|
||||
self.down_scales = weights["down_scale"]
|
||||
if self.method == "RAWINT4":
|
||||
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
|
||||
elif self.method == "FP8":
|
||||
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
|
||||
|
||||
t2 = time.time()
|
||||
|
||||
# Build pointer lists: [numa_id][expert_id] -> pointer
|
||||
# Since RAWINT4 has no numa sharding, numa dimension is 1
|
||||
# Since RAWINT4/FP8/BF16 has no numa sharding, numa dimension is 1
|
||||
gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]
|
||||
up_ptrs = [[t.data_ptr() for t in self.up_weights]]
|
||||
down_ptrs = [[t.data_ptr() for t in self.down_weights]]
|
||||
gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
|
||||
up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
|
||||
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
|
||||
|
||||
# BF16 has no scales, pass empty lists (will use 0/nullptr for consistency)
|
||||
if self.method == "BF16":
|
||||
gate_scale_ptrs = [[0 for _ in self.gate_weights]]
|
||||
up_scale_ptrs = [[0 for _ in self.up_weights]]
|
||||
down_scale_ptrs = [[0 for _ in self.down_weights]]
|
||||
else:
|
||||
gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]
|
||||
up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]
|
||||
down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]
|
||||
t3 = time.time()
|
||||
|
||||
moe_config = MOEConfig(
|
||||
@@ -444,6 +462,9 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
moe_config.quant_config.group_size = 128
|
||||
moe_config.quant_config.zero_point = False
|
||||
self.moe = AMXFP8_MOE(moe_config)
|
||||
elif self.method == "BF16":
|
||||
# BF16 has no quantization config needed
|
||||
self.moe = AMXBF16_MOE(moe_config)
|
||||
t4 = time.time()
|
||||
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
@@ -453,9 +474,10 @@ class NativeMoEWrapper(BaseMoEWrapper):
|
||||
del self.gate_weights
|
||||
del self.up_weights
|
||||
del self.down_weights
|
||||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
if self.gate_scales is not None:
|
||||
del self.gate_scales
|
||||
del self.up_scales
|
||||
del self.down_scales
|
||||
t6 = time.time()
|
||||
|
||||
print(
|
||||
|
||||
@@ -348,6 +348,99 @@ class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
}
|
||||
|
||||
|
||||
class BF16SafeTensorLoader(SafeTensorLoader):
    """Loader for native BF16 expert weights (no quantization, no scales).

    Supported checkpoint layouts:
    - DeepSeek style: ``{base}.mlp.experts.{id}.{gate,up,down}_proj.weight``
    - Mixtral/MiniMax style: ``{base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight``

    The layout is auto-detected once, at construction time.
    """

    # fmt_name -> (experts-prefix template, gate name, up name, down name)
    MOE_FORMATS = {
        "deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
        "mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
    }

    def __init__(self, file_path: str):
        super().__init__(file_path)
        self._detected_format = None
        self._detect_format()

    def _detect_format(self):
        """Auto-detect the MoE naming format by scanning tensor keys."""
        # Only inspect a bounded sample of keys to keep startup cheap.
        candidate_keys = list(self.tensor_file_map.keys())[:1000]

        for name, (tpl, gate, up, down) in self.MOE_FORMATS.items():
            marker = f".{gate}.weight"
            for key in candidate_keys:
                if ".experts." not in key or marker not in key:
                    continue
                if name == "mixtral" and "block_sparse_moe.experts" in key:
                    self._detected_format = name
                    print(f"[BF16SafeTensorLoader] Detected format: {name}")
                    return
                if name == "deepseek" and "mlp.experts" in key and "block_sparse_moe" not in key:
                    self._detected_format = name
                    print(f"[BF16SafeTensorLoader] Detected format: {name}")
                    return

        # Nothing matched: fall back to the DeepSeek naming scheme.
        self._detected_format = "deepseek"
        print("[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek")

    def _get_experts_prefix(self, base_key: str) -> str:
        """Return the experts key prefix for the detected format."""
        template = self.MOE_FORMATS[self._detected_format][0]
        return template.format(base=base_key)

    def _get_proj_names(self):
        """Return the (gate, up, down) projection names for the detected format."""
        return self.MOE_FORMATS[self._detected_format][1:]

    def load_tensor(self, key: str, device: str = "cpu"):
        """Load a single tensor by key, optionally moving it to *device*."""
        if key not in self.tensor_file_map:
            raise KeyError(f"Key {key} not found in Safetensor files")
        source_file = self.tensor_file_map[key]
        handle = self.file_handle_map.get(source_file)
        if handle is None:
            raise FileNotFoundError(f"File {source_file} not found in Safetensor files")
        tensor = handle.get_tensor(key)
        return tensor if device == "cpu" else tensor.to(device)

    def load_experts(self, base_key: str, device: str = "cpu"):
        """Load all BF16 expert weights under *base_key* (no scales needed).

        Returns a dict with per-expert tensor lists under the keys
        "gate", "up" and "down". Raises ValueError if no experts exist.
        """
        prefix = self._get_experts_prefix(base_key)
        gate_name, up_name, down_name = self._get_proj_names()

        # Experts are numbered densely from 0; count them by probing keys.
        expert_total = 0
        while self.has_tensor(f"{prefix}.{expert_total}.{gate_name}.weight"):
            expert_total += 1

        if expert_total == 0:
            raise ValueError(f"No experts found for key {prefix}")

        gate_tensors = []
        up_tensors = []
        down_tensors = []
        for eid in range(expert_total):
            gate_tensors.append(self.load_tensor(f"{prefix}.{eid}.{gate_name}.weight", device).contiguous())
            up_tensors.append(self.load_tensor(f"{prefix}.{eid}.{up_name}.weight", device).contiguous())
            down_tensors.append(self.load_tensor(f"{prefix}.{eid}.{down_name}.weight", device).contiguous())

        return {
            "gate": gate_tensors,
            "up": up_tensors,
            "down": down_tensors,
        }
|
||||
|
||||
|
||||
class CompressedSafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user