Files
ktransformers/kt-kernel/examples/test_write_buffer.py
Oql 6277da4c2b support GLM 4.7 (#1791)
support GLM 4.7
2026-01-13 17:36:25 +08:00

766 lines
33 KiB
Python

"""
Test write_weight_scale_to_buffer for AMX MOE operators.
Supports:
- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
- BF16: Native BF16 weights (2 bytes), no scales
Usage:
python test_write_buffer.py # Run all modes
python test_write_buffer.py fp8 # Run FP8 only
python test_write_buffer.py fp8_perchannel # Run FP8 per-channel only
python test_write_buffer.py bf16 # Run BF16 only
"""
import os
import sys
import time
import torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
from kt_kernel import kt_kernel_ext
from kt_kernel_ext import CPUInfer
def make_cpu_infer(thread_num=80):
    """Build and return a ``CPUInfer`` constructed with ``thread_num``.

    ``thread_num`` is presumably the worker-thread count of the CPU backend
    (TODO confirm against the kt_kernel_ext binding).
    """
    cpu_infer = CPUInfer(thread_num)
    return cpu_infer
def div_up(a, b):
    """Return ``a`` divided by ``b``, rounded up (for positive ``b``)."""
    quotient, _remainder = divmod(a + b - 1, b)
    return quotient
def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Return a ``MOEConfig`` set up for block-wise FP8 weights.

    The config uses ``max_len = 1``, 8-bit quantisation with ``group_size``
    blocks and no zero point, and is attached to ``cpuinfer.backend_``.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    quant = config.quant_config
    quant.bits = 8  # FP8
    quant.group_size = group_size
    quant.zero_point = False
    config.pool = cpuinfer.backend_
    return config
def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
    """Return a ``MOEConfig`` set up for FP8 weights with per-channel scales.

    Same as the block-wise FP8 config, except ``group_size`` is 0 (unused)
    and ``per_channel`` is enabled.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    quant = config.quant_config
    quant.bits = 8  # FP8
    quant.group_size = 0  # Not used for per-channel
    quant.zero_point = False
    quant.per_channel = True
    config.pool = cpuinfer.backend_
    return config
def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
    """Return a ``MOEConfig`` for native BF16 weights.

    No quantisation fields are set; only ``max_len = 1`` and the backend
    pool from ``cpuinfer`` are configured.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    config.pool = cpuinfer.backend_
    return config
def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size):
    """Create random block-wise FP8 expert weights and float32 scales for testing.

    Each weight matrix holds ``hidden_size * intermediate_size`` one-byte FP8
    codes per expert, stored flat with experts back to back. Scales have one
    float32 per ``group_size x group_size`` block (partial blocks round up).
    gate/up are blocked as (n=intermediate, k=hidden) and down as
    (n=hidden, k=intermediate).

    Returns:
        dict with flat ``gate_q``/``up_q``/``down_q`` (uint8),
        ``gate_scale``/``up_scale``/``down_scale`` (float32), and the
        per-expert sizes ``per_mat_weight_bytes``,
        ``per_mat_scale_elems_gate_up`` and ``per_mat_scale_elems_down``.
    """

    def blocks(extent):
        # Number of group_size-wide blocks covering ``extent`` (ceil division).
        return (extent + group_size - 1) // group_size

    per_mat_weight_bytes = hidden_size * intermediate_size
    blocks_intermediate = blocks(intermediate_size)
    blocks_hidden = blocks(hidden_size)
    scale_elems_gate_up = blocks_intermediate * blocks_hidden
    # down is (n=hidden, k=intermediate): the transpose of gate/up's blocking.
    scale_elems_down = blocks_hidden * blocks_intermediate

    weight_numel = expert_num * per_mat_weight_bytes
    # Random tensors are drawn in a fixed order (gate, up, down; weights first,
    # then scales) so a given seed always reproduces the same data.
    result = {
        name: torch.randint(0, 256, (weight_numel,), dtype=torch.uint8)
        for name in ("gate_q", "up_q", "down_q")
    }
    scale_plan = (
        ("gate_scale", scale_elems_gate_up),
        ("up_scale", scale_elems_gate_up),
        ("down_scale", scale_elems_down),
    )
    for name, elems in scale_plan:
        result[name] = torch.randn(expert_num * elems, dtype=torch.float32)
    result["per_mat_weight_bytes"] = per_mat_weight_bytes
    result["per_mat_scale_elems_gate_up"] = scale_elems_gate_up
    result["per_mat_scale_elems_down"] = scale_elems_down
    return result
def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):
    """Create random FP8 expert weights with per-channel float32 scales for testing.

    Weights are one byte per element (``hidden_size * intermediate_size`` per
    expert). gate/up get one scale per intermediate channel; down gets one
    per hidden channel.

    Returns:
        dict with the same keys as the block-wise FP8 allocator: flat weight
        and scale tensors plus per-expert size entries.
    """
    per_mat_weight_bytes = hidden_size * intermediate_size
    weight_numel = expert_num * per_mat_weight_bytes
    scale_elems_gate_up = intermediate_size  # one scale per output channel
    scale_elems_down = hidden_size

    # Fixed draw order (weights gate/up/down, then scales gate/up/down) keeps
    # seeded runs reproducible.
    result = {}
    for name in ("gate_q", "up_q", "down_q"):
        result[name] = torch.randint(0, 256, (weight_numel,), dtype=torch.uint8)
    for name, elems in (
        ("gate_scale", scale_elems_gate_up),
        ("up_scale", scale_elems_gate_up),
        ("down_scale", scale_elems_down),
    ):
        result[name] = torch.randn(expert_num * elems, dtype=torch.float32)
    result.update(
        per_mat_weight_bytes=per_mat_weight_bytes,
        per_mat_scale_elems_gate_up=scale_elems_gate_up,
        per_mat_scale_elems_down=scale_elems_down,
    )
    return result
def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
    """Create random BF16 expert weights (no scales) for testing.

    Returns:
        dict with flat ``gate_proj``/``up_proj``/``down_proj`` bfloat16
        tensors (experts back to back), plus ``per_mat_weight_bytes``
        (2 bytes per element) and ``per_mat_weight_elems`` for one expert.
    """
    elems_per_matrix = hidden_size * intermediate_size
    numel = expert_num * elems_per_matrix
    # Draw gate, up, down in that order so seeded runs are reproducible.
    gate_proj, up_proj, down_proj = (torch.randn(numel, dtype=torch.bfloat16) for _ in range(3))
    return {
        "gate_proj": gate_proj,
        "up_proj": up_proj,
        "down_proj": down_proj,
        "per_mat_weight_bytes": elems_per_matrix * 2,  # BF16 = 2 bytes
        "per_mat_weight_elems": elems_per_matrix,
    }
def test_fp8_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with block-wise FP8 weights.

    Loads random FP8 expert weights into an ``AMXFP8_MOE`` operator, has it
    write every expert's weights and scales into per-TP destination buffers,
    reports the time/throughput of one full pass, then checks every buffer
    bit-exactly against the layout the test expects:

    * w13: gate/up are split into ``gpu_tp_count`` contiguous parts along the
      intermediate (n) dimension; each expert slot holds the gate part
      followed by the up part (weights and block scales alike).
    * w2: down is split along its intermediate (k) columns, so each of the
      ``hidden_size`` rows (and each scale-block row) contributes its TP
      column slice.

    Args:
        gpu_tp_count: number of GPU tensor-parallel parts to split into.

    Returns:
        True on success.

    Raises:
        ValueError: if ``intermediate_size`` does not split into whole
            ``group_size`` blocks per TP part.
        AssertionError: if any destination buffer differs from the expected layout.
    """
    torch.manual_seed(123)
    expert_num = 256
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536
    group_size = 128
    # Buffer sizing and the expected layouts below assume each TP part owns a
    # whole number of group_size scale blocks along the intermediate dimension.
    if intermediate_size % (gpu_tp_count * group_size) != 0:
        raise ValueError(
            f"[FP8] intermediate_size={intermediate_size} cannot be split into "
            f"{gpu_tp_count} TP parts of whole {group_size}-wide blocks"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
    weights = allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)
    cfg.gate_proj = weights["gate_q"].data_ptr()
    cfg.up_proj = weights["up_q"].data_ptr()
    cfg.down_proj = weights["down_q"].data_ptr()
    cfg.gate_scale = weights["gate_scale"].data_ptr()
    cfg.up_scale = weights["up_scale"].data_ptr()
    cfg.down_scale = weights["down_scale"].data_ptr()
    moe = kt_kernel_ext.moe.AMXFP8_MOE(cfg)
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]

    # Size of one expert's slice inside each TP part.
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # One destination buffer per TP part; w13 buffers hold gate and up per expert.
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
    print(f"[FP8] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[FP8] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"[FP8] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")

    def get_expert_ptrs(expert_id):
        """Return per-TP destination addresses for ``expert_id``'s w13/w2 weights and scales."""
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []
        for tp_idx in range(gpu_tp_count):
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            # Scale offsets are in float32 elements; data_ptr arithmetic is in bytes.
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    def write_all_experts():
        """Submit one write task per expert, waiting for each before the next."""
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
            cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing: one full pass over all experts.
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6
    # Bytes moved: three 1-byte FP8 matrices plus their float32 (4-byte) scales.
    total_bytes = (
        hidden_size * intermediate_size * gpu_experts * 3
        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
    )
    print(f"[FP8] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    # bytes / ns == GB/s
    print(f"[FP8] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def assert_same(label, actual, expected):
        """Raise AssertionError unless ``actual`` is bit-identical to ``expected``."""
        if torch.equal(actual, expected):
            return
        if actual.shape != expected.shape:
            raise AssertionError(
                f"[FP8] {label}: shape {tuple(actual.shape)} != expected {tuple(expected.shape)}"
            )
        diff_positions = (actual != expected).nonzero()
        first_diff_idx = diff_positions[0].item() if diff_positions.numel() else -1
        raise AssertionError(f"[FP8] {label} at index {first_diff_idx}")

    # Re-view the flat source tensors so each TP slice is a plain index
    # (views only, no copies):
    #   gate/up : (expert, tp, per-TP chunk)
    #   down    : (expert, row, tp, per-TP column chunk)
    gate_q = weights["gate_q"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    up_q = weights["up_q"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    gate_scale = weights["gate_scale"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    up_scale = weights["up_scale"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    down_q = weights["down_q"].view(expert_num, hidden_size, gpu_tp_count, -1)[:gpu_experts]
    n_blocks_n = div_up(hidden_size, group_size)
    down_scale = weights["down_scale"].view(expert_num, n_blocks_n, gpu_tp_count, -1)[:gpu_experts]

    for tp_idx in range(gpu_tp_count):
        # Per expert: gate TP slice, then up TP slice.
        expected_w13_weight = torch.stack((gate_q[:, tp_idx], up_q[:, tp_idx]), dim=1).reshape(-1)
        expected_w13_scale = torch.stack((gate_scale[:, tp_idx], up_scale[:, tp_idx]), dim=1).reshape(-1)
        # Per expert: every row's TP column slice, rows in order.
        expected_w2_weight = down_q[:, :, tp_idx].reshape(-1)
        expected_w2_scale = down_scale[:, :, tp_idx].reshape(-1)
        # The kernel only copies data, so scales must match bit-exactly too.
        assert_same(f"w13 weight mismatch for TP {tp_idx}", w13_weight_bufs[tp_idx], expected_w13_weight)
        assert_same(f"w13 scale mismatch for TP {tp_idx}", w13_scale_bufs[tp_idx], expected_w13_scale)
        assert_same(f"w2 weight mismatch for TP {tp_idx}", w2_weight_bufs[tp_idx], expected_w2_weight)
        assert_same(f"w2 scale mismatch for TP {tp_idx}", w2_scale_bufs[tp_idx], expected_w2_scale)
    print(f"[FP8] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
def test_fp8_perchannel_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with FP8 weights and per-channel scales.

    Loads random FP8 expert weights into an ``AMXFP8PerChannel_MOE`` operator,
    writes every expert into per-TP destination buffers, reports the
    time/throughput of one full pass, and checks every buffer bit-exactly:

    * w13: gate/up weights and their per-channel scales are split into
      ``gpu_tp_count`` contiguous parts along the intermediate dimension;
      each expert slot holds the gate part followed by the up part.
    * w2: down weights are split along their intermediate columns (each
      ``hidden_size`` row contributes its TP slice); down's per-channel
      scales (one per hidden channel) are copied whole into every TP part.

    Args:
        gpu_tp_count: number of GPU tensor-parallel parts to split into.

    Returns:
        True on success.

    Raises:
        ValueError: if ``intermediate_size`` is not divisible by ``gpu_tp_count``.
        AssertionError: if any destination buffer differs from the expected layout.
    """
    torch.manual_seed(123)
    expert_num = 256
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536
    # Buffer sizing and the expected layouts assume an even split across TP parts.
    if intermediate_size % gpu_tp_count != 0:
        raise ValueError(
            f"[FP8_PERCHANNEL] intermediate_size={intermediate_size} is not divisible by "
            f"gpu_tp_count={gpu_tp_count}"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)
    cfg.gate_proj = weights["gate_q"].data_ptr()
    cfg.up_proj = weights["up_q"].data_ptr()
    cfg.down_proj = weights["down_q"].data_ptr()
    cfg.gate_scale = weights["gate_scale"].data_ptr()
    cfg.up_scale = weights["up_scale"].data_ptr()
    cfg.down_scale = weights["down_scale"].data_ptr()
    moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]

    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    gpu_n_w13 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_gate_up = gpu_n_w13
    # Down scales are indexed by hidden channel, which TP does not split.
    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
    print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")

    def get_expert_ptrs(expert_id):
        """Return per-TP destination addresses for ``expert_id``'s w13/w2 weights and scales."""
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []
        for tp_idx in range(gpu_tp_count):
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            # Scale offsets are in float32 elements; data_ptr arithmetic is in bytes.
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)
        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    def write_all_experts():
        """Submit one write task per expert, waiting for each before the next."""
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
            cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing: one full pass over all experts.
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6
    total_bytes = (
        hidden_size * intermediate_size * gpu_experts * 3
        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
    )
    print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def assert_same(label, actual, expected):
        """Raise AssertionError unless ``actual`` is bit-identical to ``expected``."""
        if torch.equal(actual, expected):
            return
        if actual.shape != expected.shape:
            raise AssertionError(
                f"[FP8_PERCHANNEL] {label}: shape {tuple(actual.shape)} != expected {tuple(expected.shape)}"
            )
        diff_positions = (actual != expected).nonzero()
        first_diff_idx = diff_positions[0].item() if diff_positions.numel() else -1
        raise AssertionError(f"[FP8_PERCHANNEL] {label} at index {first_diff_idx}")

    # Views over the flat sources so each TP slice is a plain index.
    gate_q = weights["gate_q"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    up_q = weights["up_q"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    gate_scale = weights["gate_scale"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    up_scale = weights["up_scale"].view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    down_q = weights["down_q"].view(expert_num, hidden_size, gpu_tp_count, -1)[:gpu_experts]
    # Every TP part receives each expert's full set of down scales.
    expected_w2_scale = weights["down_scale"].view(expert_num, -1)[:gpu_experts].reshape(-1)

    for tp_idx in range(gpu_tp_count):
        expected_w13_weight = torch.stack((gate_q[:, tp_idx], up_q[:, tp_idx]), dim=1).reshape(-1)
        expected_w13_scale = torch.stack((gate_scale[:, tp_idx], up_scale[:, tp_idx]), dim=1).reshape(-1)
        expected_w2_weight = down_q[:, :, tp_idx].reshape(-1)
        # The kernel only copies data, so scales must match bit-exactly too.
        assert_same(f"w13 weight mismatch for TP {tp_idx}", w13_weight_bufs[tp_idx], expected_w13_weight)
        assert_same(f"w13 scale mismatch for TP {tp_idx}", w13_scale_bufs[tp_idx], expected_w13_scale)
        assert_same(f"w2 weight mismatch for TP {tp_idx}", w2_weight_bufs[tp_idx], expected_w2_weight)
        assert_same(f"w2 scale mismatch for TP {tp_idx}", w2_scale_bufs[tp_idx], expected_w2_scale)
    print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
def test_bf16_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with native BF16 weights (no scales).

    Loads random BF16 expert weights into an ``AMXBF16_MOE`` operator, writes
    every expert into per-TP destination byte buffers, reports the
    time/throughput of one full pass, and compares the buffers byte-for-byte:

    * w13: gate/up are split into ``gpu_tp_count`` contiguous parts; each
      expert slot holds the gate part followed by the up part.
    * w2: down is split along its intermediate columns, so each of the
      ``hidden_size`` rows contributes its TP column slice.

    Scale pointers are passed (the interface requires them) but point at
    one-element dummy buffers.

    Args:
        gpu_tp_count: number of GPU tensor-parallel parts to split into.

    Returns:
        True on success.

    Raises:
        ValueError: if ``intermediate_size`` is not divisible by ``gpu_tp_count``.
        AssertionError: if any destination buffer differs from the expected layout.
    """
    torch.manual_seed(123)
    expert_num = 16
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536
    # Buffer sizing and the expected layouts assume an even split across TP parts.
    if intermediate_size % gpu_tp_count != 0:
        raise ValueError(
            f"[BF16] intermediate_size={intermediate_size} is not divisible by gpu_tp_count={gpu_tp_count}"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    weights = allocate_weights_bf16(expert_num, hidden_size, intermediate_size)
    cfg.gate_proj = weights["gate_proj"].data_ptr()
    cfg.up_proj = weights["up_proj"].data_ptr()
    cfg.down_proj = weights["down_proj"].data_ptr()
    cfg.gate_scale = 0
    cfg.up_scale = 0
    cfg.down_scale = 0
    moe = kt_kernel_ext.moe.AMXBF16_MOE(cfg)
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_elems = weights["per_mat_weight_elems"]
    # Sizes per TP part (BF16 = 2 bytes per element).
    weight_elems_per_expert_per_tp = per_mat_weight_elems // gpu_tp_count
    weight_bytes_per_expert_per_tp = weight_elems_per_expert_per_tp * 2
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp

    # Destination byte buffers (weights only).
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    # Placeholder scale buffers: not used for BF16 but required by the interface.
    w13_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]
    print(f"[BF16] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[BF16] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")

    def get_expert_ptrs(expert_id):
        """Return per-TP destination addresses for ``expert_id`` (scale pointers are dummies)."""
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []
        for tp_idx in range(gpu_tp_count):
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr())  # Not used
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr())  # Not used
        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    def write_all_experts():
        """Submit one write task per expert, waiting for each before the next."""
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
            cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing: one full pass over all experts.
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6
    total_bytes = hidden_size * intermediate_size * gpu_experts * 3 * 2  # BF16 = 2 bytes
    print(f"[BF16] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"[BF16] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    def assert_same(label, actual, expected):
        """Raise AssertionError unless ``actual`` is byte-identical to ``expected``."""
        if torch.equal(actual, expected):
            return
        if actual.shape != expected.shape:
            raise AssertionError(
                f"[BF16] {label}: shape {tuple(actual.shape)} != expected {tuple(expected.shape)}"
            )
        diff_positions = (actual != expected).nonzero()
        first_diff_idx = diff_positions[0].item() if diff_positions.numel() else -1
        raise AssertionError(f"[BF16] {label} at index {first_diff_idx}")

    # Reinterpret the BF16 sources as raw bytes and view them so each TP slice
    # is a plain index:
    #   gate/up : (expert, tp, per-TP byte chunk)
    #   down    : (expert, row, tp, per-TP column bytes)
    gate_bytes = weights["gate_proj"].view(torch.uint8).view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    up_bytes = weights["up_proj"].view(torch.uint8).view(expert_num, gpu_tp_count, -1)[:gpu_experts]
    down_bytes = weights["down_proj"].view(torch.uint8).view(expert_num, hidden_size, gpu_tp_count, -1)[:gpu_experts]

    for tp_idx in range(gpu_tp_count):
        # Per expert: gate TP slice, then up TP slice.
        expected_w13_weight = torch.stack((gate_bytes[:, tp_idx], up_bytes[:, tp_idx]), dim=1).reshape(-1)
        # Per expert: every row's TP column slice, rows in order.
        expected_w2_weight = down_bytes[:, :, tp_idx].reshape(-1)
        assert_same(f"w13 weight mismatch for TP {tp_idx}", w13_weight_bufs[tp_idx], expected_w13_weight)
        assert_same(f"w2 weight mismatch for TP {tp_idx}", w2_weight_bufs[tp_idx], expected_w2_weight)
    print(f"[BF16] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
def test_with_tp(quant_mode: str, gpu_tp_count: int):
    """Dispatch to the write-buffer test for ``quant_mode`` at ``gpu_tp_count``.

    Returns whatever the selected test returns.

    Raises:
        ValueError: if ``quant_mode`` is not "fp8", "fp8_perchannel" or "bf16".
    """
    if quant_mode == "fp8":
        runner = test_fp8_write_buffer
    elif quant_mode == "fp8_perchannel":
        runner = test_fp8_perchannel_write_buffer
    elif quant_mode == "bf16":
        runner = test_bf16_write_buffer
    else:
        raise ValueError(f"Unsupported quant_mode: {quant_mode}")
    return runner(gpu_tp_count)
def main(quant_modes=None, tp_values=None):
    """Run the write-buffer tests for every (quant mode, TP count) pair.

    Each failure is caught, reported with its traceback and recorded, so the
    remaining combinations still run. A summary is printed at the end, and
    the process exits with status 1 if anything failed.

    Args:
        quant_modes: modes to test; defaults to ["fp8", "fp8_perchannel", "bf16"].
        tp_values: GPU tensor-parallel counts to test; defaults to [1, 2, 4].
    """
    import traceback

    if quant_modes is None:
        quant_modes = ["fp8", "fp8_perchannel", "bf16"]
    if tp_values is None:
        tp_values = [1, 2, 4]
    # (mode, tp) -> (passed, message). The explicit flag avoids inferring the
    # status from message text, which an exception message could contain.
    results = {}
    for quant_mode in quant_modes:
        print("\n" + "=" * 60)
        print(f"Testing {quant_mode.upper()} write_weight_scale_to_buffer")
        print("=" * 60)
        for tp in tp_values:
            print(f"\n--- Testing {quant_mode.upper()} with gpu_tp_count = {tp} ---")
            try:
                test_with_tp(quant_mode, tp)
            except Exception as e:  # top-level test boundary: record and keep going
                results[(quant_mode, tp)] = (False, f"FAILED: {e}")
                print(f"[{quant_mode.upper()}] TP={tp} FAILED: {e}")
                traceback.print_exc()
            else:
                results[(quant_mode, tp)] = (True, "PASSED")
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for (mode, tp), (passed, result) in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] {mode.upper()} TP={tp}: {result}")
    if all(passed for passed, _ in results.values()):
        print("\nALL TESTS PASSED")
    else:
        print("\nSOME TESTS FAILED")
        sys.exit(1)
if __name__ == "__main__":
    # Any number of modes may be given (e.g. `fp8 bf16`); with none, all run.
    # Every argument is validated rather than silently ignoring extras.
    supported_modes = ("fp8", "fp8_perchannel", "bf16")
    requested_modes = list(dict.fromkeys(arg.lower() for arg in sys.argv[1:]))  # dedupe, keep order
    unknown_modes = [m for m in requested_modes if m not in supported_modes]
    if unknown_modes:
        print(f"Unknown mode: {unknown_modes[0]}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
        sys.exit(1)
    main(requested_modes or None)