mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
766 lines
33 KiB
Python
766 lines
33 KiB
Python
"""
|
|
Test write_weight_scale_to_buffer for AMX MOE operators.
|
|
|
|
Supports:
|
|
- FP8: FP8 weights (1 byte) + float32 scales (block-wise)
|
|
- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales
|
|
- BF16: Native BF16 weights (2 bytes), no scales
|
|
|
|
Usage:
|
|
python test_write_buffer.py # Run all modes
|
|
python test_write_buffer.py fp8 # Run FP8 only
|
|
python test_write_buffer.py fp8_perchannel # Run FP8 per-channel only
|
|
python test_write_buffer.py bf16 # Run BF16 only
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
import torch
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
|
|
|
from kt_kernel import kt_kernel_ext
|
|
from kt_kernel_ext import CPUInfer
|
|
|
|
|
|
def make_cpu_infer(thread_num=80):
    """Create the CPUInfer object the tests submit their tasks to.

    Args:
        thread_num: Value forwarded unchanged to the ``CPUInfer`` constructor
            (presumably the worker thread count -- confirm against the
            extension's API).

    Returns:
        A newly constructed ``CPUInfer`` instance.
    """
    infer = CPUInfer(thread_num)
    return infer
|
|
|
|
|
|
def div_up(a, b):
    """Return ``a / b`` rounded up to the next integer (for positive ``b``).

    Implemented as floor division of ``a + b - 1`` by ``b``.
    """
    quotient, _ = divmod(a + b - 1, b)
    return quotient
|
|
|
|
|
|
def build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    """Build the MOEConfig used by the block-wise FP8 test.

    Args:
        cpuinfer: Object whose ``backend_`` attribute is stored as the config's pool.
        expert_num, num_experts_per_tok, hidden_size, intermediate_size:
            Passed positionally, in this order, to ``MOEConfig``.
        group_size: Stored in ``quant_config.group_size``.

    Returns:
        The populated ``MOEConfig``.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    # Quantization settings, applied in a fixed order: 8 bits (FP8), the
    # caller's group size, and no zero point.
    quant_settings = (
        ("bits", 8),
        ("group_size", group_size),
        ("zero_point", False),
    )
    for field_name, value in quant_settings:
        setattr(config.quant_config, field_name, value)
    config.pool = cpuinfer.backend_
    return config
|
|
|
|
|
|
def build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
    """Build the MOEConfig used by the FP8 per-channel test.

    Args:
        cpuinfer: Object whose ``backend_`` attribute is stored as the config's pool.
        expert_num, num_experts_per_tok, hidden_size, intermediate_size:
            Passed positionally, in this order, to ``MOEConfig``.

    Returns:
        The populated ``MOEConfig`` with ``quant_config.per_channel`` enabled.
    """
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = 1
    # Quantization settings, applied in a fixed order. group_size is set to 0
    # because it is not used in per-channel mode.
    quant_settings = (
        ("bits", 8),
        ("group_size", 0),
        ("zero_point", False),
        ("per_channel", True),
    )
    for field_name, value in quant_settings:
        setattr(config.quant_config, field_name, value)
    config.pool = cpuinfer.backend_
    return config
|
|
|
|
|
|
def build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):
    """Build the MOEConfig used by the BF16 test (quant_config is left untouched).

    Args:
        cpuinfer: Object whose ``backend_`` attribute is stored as the config's pool.
        expert_num, num_experts_per_tok, hidden_size, intermediate_size:
            Passed positionally, in this order, to ``MOEConfig``.

    Returns:
        The populated ``MOEConfig``.
    """
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num,
        num_experts_per_tok,
        hidden_size,
        intermediate_size,
    )
    config.max_len = 1
    config.pool = cpuinfer.backend_
    return config
|
|
|
|
|
|
def allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size):
    """Create random FP8 weights and block-wise float32 scales for all experts.

    Each weight matrix takes ``hidden_size * intermediate_size`` bytes (one
    byte per element). Scales are one float32 per ``group_size`` x
    ``group_size`` block, with block counts rounded up along each dimension.

    Returns:
        A dict holding the flat ``gate_q``/``up_q``/``down_q`` uint8 tensors,
        the flat ``gate_scale``/``up_scale``/``down_scale`` float32 tensors,
        and the per-matrix sizes ``per_mat_weight_bytes``,
        ``per_mat_scale_elems_gate_up`` and ``per_mat_scale_elems_down``.
    """
    weight_bytes = hidden_size * intermediate_size

    # Ceiling division: number of scale blocks along each dimension.
    blocks_intermediate = (intermediate_size + group_size - 1) // group_size
    blocks_hidden = (hidden_size + group_size - 1) // group_size

    # gate/up use n=intermediate_size, k=hidden_size; down swaps the two.
    scale_elems_gate_up = blocks_intermediate * blocks_hidden
    scale_elems_down = blocks_hidden * blocks_intermediate

    def random_bytes():
        return torch.randint(0, 256, (expert_num * weight_bytes,), dtype=torch.uint8)

    def random_scales(elems_per_expert):
        return torch.randn(expert_num * elems_per_expert, dtype=torch.float32)

    # The tensors are drawn in this exact order so the RNG stream is consumed
    # the same way on every run.
    allocated = {}
    for key in ("gate_q", "up_q", "down_q"):
        allocated[key] = random_bytes()
    allocated["gate_scale"] = random_scales(scale_elems_gate_up)
    allocated["up_scale"] = random_scales(scale_elems_gate_up)
    allocated["down_scale"] = random_scales(scale_elems_down)

    allocated["per_mat_weight_bytes"] = weight_bytes
    allocated["per_mat_scale_elems_gate_up"] = scale_elems_gate_up
    allocated["per_mat_scale_elems_down"] = scale_elems_down
    return allocated
|
|
|
|
|
|
def allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):
    """Create random FP8 weights and per-channel float32 scales for all experts.

    Each weight matrix takes ``hidden_size * intermediate_size`` bytes. Gate
    and up get ``intermediate_size`` scales per expert (one per output
    channel); down gets ``hidden_size`` scales per expert.

    Returns:
        A dict holding the flat ``gate_q``/``up_q``/``down_q`` uint8 tensors,
        the flat ``gate_scale``/``up_scale``/``down_scale`` float32 tensors,
        and the per-matrix sizes ``per_mat_weight_bytes``,
        ``per_mat_scale_elems_gate_up`` and ``per_mat_scale_elems_down``.
    """
    weight_bytes = hidden_size * intermediate_size
    gate_up_scale_elems = intermediate_size
    down_scale_elems = hidden_size

    # Generators are consumed left to right, so the draw order is
    # gate, up, down for the weights and then gate, up, down for the scales.
    weight_shape = (expert_num * weight_bytes,)
    gate_q, up_q, down_q = (torch.randint(0, 256, weight_shape, dtype=torch.uint8) for _ in range(3))
    gate_scale, up_scale = (
        torch.randn(expert_num * gate_up_scale_elems, dtype=torch.float32) for _ in range(2)
    )
    down_scale = torch.randn(expert_num * down_scale_elems, dtype=torch.float32)

    tensors = {
        "gate_q": gate_q,
        "up_q": up_q,
        "down_q": down_q,
        "gate_scale": gate_scale,
        "up_scale": up_scale,
        "down_scale": down_scale,
    }
    sizes = {
        "per_mat_weight_bytes": weight_bytes,
        "per_mat_scale_elems_gate_up": gate_up_scale_elems,
        "per_mat_scale_elems_down": down_scale_elems,
    }
    return {**tensors, **sizes}
|
|
|
|
|
|
def allocate_weights_bf16(expert_num, hidden_size, intermediate_size):
    """Create random BF16 weights for all experts (BF16 mode has no scales).

    Returns:
        A dict holding the flat ``gate_proj``/``up_proj``/``down_proj``
        bfloat16 tensors plus ``per_mat_weight_bytes`` and
        ``per_mat_weight_elems`` for a single matrix.
    """
    elems_per_matrix = hidden_size * intermediate_size
    bytes_per_matrix = 2 * elems_per_matrix  # two bytes per BF16 element

    # Drawn in gate, up, down order so the RNG stream is consumed consistently.
    names = ("gate_proj", "up_proj", "down_proj")
    result = {name: torch.randn(expert_num * elems_per_matrix, dtype=torch.bfloat16) for name in names}
    result["per_mat_weight_bytes"] = bytes_per_matrix
    result["per_mat_weight_elems"] = elems_per_matrix
    return result
|
|
|
|
|
|
def test_fp8_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with FP8 weights (block-wise scales).

    Loads random FP8 weights and float32 block scales into an ``AMXFP8_MOE``,
    asks it to write every expert into per-TP buffers (two warm-up passes and
    one timed pass), then compares each buffer with slices of the source
    tensors.

    Expected layout per TP part, expert after expert:
      * w13: the TP's contiguous 1/``gpu_tp_count`` chunk of the flat gate
        matrix followed by the same chunk of the up matrix (weights and
        scales alike).
      * w2: the down matrix viewed as ``hidden_size`` rows of
        ``intermediate_size`` bytes, keeping only the TP's column range of
        every row; scales likewise on the block grid.

    Args:
        gpu_tp_count: Number of tensor-parallel parts to split the weights into.

    Returns:
        True when every TP buffer matches the expected data.

    Raises:
        ValueError: If ``gpu_tp_count`` does not evenly divide the
            intermediate size and its scale-block count (the size arithmetic
            below would otherwise truncate silently).
        AssertionError: If a written buffer differs from the expected data.
    """
    torch.manual_seed(123)

    expert_num = 256
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536
    group_size = 128

    # Scale-block grid of the down matrix (n = hidden, k = intermediate).
    n_blocks_n = div_up(hidden_size, group_size)
    n_blocks_k = div_up(intermediate_size, group_size)

    # Every per-TP size below uses floor division; reject TP counts that
    # would make those divisions lossy instead of mis-sizing the buffers.
    if gpu_tp_count < 1 or intermediate_size % gpu_tp_count or n_blocks_k % gpu_tp_count:
        raise ValueError(
            f"[FP8] gpu_tp_count={gpu_tp_count} must evenly divide intermediate_size={intermediate_size} "
            f"and its {n_blocks_k} scale blocks"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
    weights = allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)

    # The config receives raw addresses; `weights` keeps the tensors alive.
    cfg.gate_proj = weights["gate_q"].data_ptr()
    cfg.up_proj = weights["up_q"].data_ptr()
    cfg.down_proj = weights["down_q"].data_ptr()
    cfg.gate_scale = weights["gate_scale"].data_ptr()
    cfg.up_scale = weights["up_scale"].data_ptr()
    cfg.down_scale = weights["down_scale"].data_ptr()

    moe = kt_kernel_ext.moe.AMXFP8_MOE(cfg)

    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]

    # Calculate sizes per TP part
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)

    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # Create buffer lists (w13 holds gate + up, hence the factor 2)
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]

    print(f"[FP8] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[FP8] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"[FP8] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")

    def get_expert_ptrs(expert_id):
        # Byte offsets of this expert inside each TP buffer. They are the
        # same for every TP part, so compute them once per expert. Scale
        # buffers are float32, hence the factor 4.
        w13_weight_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
        w13_scale_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up * 4
        w2_weight_offset = expert_id * weight_bytes_per_expert_per_tp
        w2_scale_offset = expert_id * scale_elems_per_expert_per_tp_down * 4
        return (
            [buf.data_ptr() + w13_weight_offset for buf in w13_weight_bufs],
            [buf.data_ptr() + w13_scale_offset for buf in w13_scale_bufs],
            [buf.data_ptr() + w2_weight_offset for buf in w2_weight_bufs],
            [buf.data_ptr() + w2_scale_offset for buf in w2_scale_bufs],
        )

    def write_all_experts():
        # One full pass: submit a write task per expert, then wait for all of them.
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6

    total_bytes = (
        hidden_size * intermediate_size * gpu_experts * 3
        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
    )
    print(f"[FP8] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    # bytes / (ms * 1e6) == bytes per nanosecond == GB/s
    print(f"[FP8] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    # Verify correctness.
    # View the flat source tensors per expert so each TP slice is a strided
    # view rather than a Python loop over rows.
    gate_q = weights["gate_q"].view(expert_num, per_mat_weight_bytes)[:gpu_experts]
    up_q = weights["up_q"].view(expert_num, per_mat_weight_bytes)[:gpu_experts]
    down_q = weights["down_q"].view(expert_num, hidden_size, intermediate_size)[:gpu_experts]
    gate_scale = weights["gate_scale"].view(expert_num, per_mat_scale_elems_gate_up)[:gpu_experts]
    up_scale = weights["up_scale"].view(expert_num, per_mat_scale_elems_gate_up)[:gpu_experts]
    down_scale = weights["down_scale"].view(expert_num, n_blocks_n, n_blocks_k)[:gpu_experts]

    weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
    tp_slice_weight_size = intermediate_size // gpu_tp_count
    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count

    def check_weight(label, actual, expected, tp_idx):
        # Byte-exact comparison that reports the first differing index.
        if torch.equal(actual, expected):
            return
        diff_mask = actual != expected
        first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
        raise AssertionError(f"[FP8] {label} weight mismatch for TP {tp_idx} at index {first_diff_idx}")

    for tp_idx in range(gpu_tp_count):
        weight_span = slice(tp_idx * weight13_per_tp, (tp_idx + 1) * weight13_per_tp)
        scale_span = slice(tp_idx * scale13_per_tp, (tp_idx + 1) * scale13_per_tp)
        down_weight_cols = slice(tp_idx * tp_slice_weight_size, (tp_idx + 1) * tp_slice_weight_size)
        down_scale_cols = slice(tp_idx * n_blocks_k_per_tp, (tp_idx + 1) * n_blocks_k_per_tp)

        # Stacking on dim=1 interleaves gate and up per expert:
        # expert0-gate, expert0-up, expert1-gate, ...
        expected_w13_weight = torch.stack((gate_q[:, weight_span], up_q[:, weight_span]), dim=1).reshape(-1)
        expected_w13_scale = torch.stack((gate_scale[:, scale_span], up_scale[:, scale_span]), dim=1).reshape(-1)
        # Down is split along columns: keep this TP's column range of every row.
        expected_w2_weight = down_q[:, :, down_weight_cols].reshape(-1)
        expected_w2_scale = down_scale[:, :, down_scale_cols].reshape(-1)

        check_weight("w13", w13_weight_bufs[tp_idx], expected_w13_weight, tp_idx)

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            raise AssertionError(f"[FP8] w13 scale mismatch for TP {tp_idx}")

        check_weight("w2", w2_weight_bufs[tp_idx], expected_w2_weight, tp_idx)

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            raise AssertionError(f"[FP8] w2 scale mismatch for TP {tp_idx}")

    print(f"[FP8] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
|
|
|
|
|
|
def test_fp8_perchannel_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with FP8 per-channel weights.

    Loads random FP8 weights and per-channel float32 scales into an
    ``AMXFP8PerChannel_MOE``, asks it to write every expert into per-TP
    buffers (two warm-up passes and one timed pass), then compares each
    buffer with slices of the source tensors.

    Expected layout per TP part, expert after expert:
      * w13: the TP's contiguous 1/``gpu_tp_count`` chunk of the flat gate
        matrix followed by the same chunk of the up matrix (weights and
        scales alike).
      * w2 weights: the down matrix viewed as ``hidden_size`` rows of
        ``intermediate_size`` bytes, keeping only the TP's column range.
      * w2 scales: the expert's full down scale vector, identical for every TP.

    Args:
        gpu_tp_count: Number of tensor-parallel parts to split the weights into.

    Returns:
        True when every TP buffer matches the expected data.

    Raises:
        ValueError: If ``gpu_tp_count`` does not evenly divide the
            intermediate size (the size arithmetic below would otherwise
            truncate silently).
        AssertionError: If a written buffer differs from the expected data.
    """
    torch.manual_seed(123)

    expert_num = 256
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536

    # Every per-TP size below uses floor division; reject TP counts that
    # would make those divisions lossy instead of mis-sizing the buffers.
    if gpu_tp_count < 1 or intermediate_size % gpu_tp_count:
        raise ValueError(
            f"[FP8_PERCHANNEL] gpu_tp_count={gpu_tp_count} must evenly divide "
            f"intermediate_size={intermediate_size}"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)

    # The config receives raw addresses; `weights` keeps the tensors alive.
    cfg.gate_proj = weights["gate_q"].data_ptr()
    cfg.up_proj = weights["up_q"].data_ptr()
    cfg.down_proj = weights["down_q"].data_ptr()
    cfg.gate_scale = weights["gate_scale"].data_ptr()
    cfg.up_scale = weights["up_scale"].data_ptr()
    cfg.down_scale = weights["down_scale"].data_ptr()

    moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)

    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_bytes = weights["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = weights["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = weights["per_mat_scale_elems_down"]

    # Sizes per TP part. Down scales are not split: each TP gets all of them.
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    gpu_n_w13 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_gate_up = gpu_n_w13
    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down

    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # w13 holds gate + up, hence the factor 2.
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]

    print(f"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")

    def get_expert_ptrs(expert_id):
        # Byte offsets of this expert inside each TP buffer. They are the
        # same for every TP part, so compute them once per expert. Scale
        # buffers are float32, hence the factor 4.
        w13_weight_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
        w13_scale_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up * 4
        w2_weight_offset = expert_id * weight_bytes_per_expert_per_tp
        w2_scale_offset = expert_id * scale_elems_per_expert_per_tp_down * 4
        return (
            [buf.data_ptr() + w13_weight_offset for buf in w13_weight_bufs],
            [buf.data_ptr() + w13_scale_offset for buf in w13_scale_bufs],
            [buf.data_ptr() + w2_weight_offset for buf in w2_weight_bufs],
            [buf.data_ptr() + w2_scale_offset for buf in w2_scale_bufs],
        )

    def write_all_experts():
        # One full pass: submit a write task per expert, then wait for all of them.
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6

    total_bytes = (
        hidden_size * intermediate_size * gpu_experts * 3
        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4
    )
    print(f"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    # bytes / (ms * 1e6) == bytes per nanosecond == GB/s
    print(f"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    # Verify correctness.
    # View the flat source tensors per expert so each TP slice is a strided
    # view rather than a Python loop over rows.
    gate_q = weights["gate_q"].view(expert_num, per_mat_weight_bytes)[:gpu_experts]
    up_q = weights["up_q"].view(expert_num, per_mat_weight_bytes)[:gpu_experts]
    down_q = weights["down_q"].view(expert_num, hidden_size, intermediate_size)[:gpu_experts]
    gate_scale = weights["gate_scale"].view(expert_num, per_mat_scale_elems_gate_up)[:gpu_experts]
    up_scale = weights["up_scale"].view(expert_num, per_mat_scale_elems_gate_up)[:gpu_experts]
    down_scale = weights["down_scale"].view(expert_num, per_mat_scale_elems_down)[:gpu_experts]

    weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
    tp_slice_weight_size = intermediate_size // gpu_tp_count

    def check_weight(label, actual, expected, tp_idx):
        # Byte-exact comparison that reports the first differing index.
        if torch.equal(actual, expected):
            return
        diff_mask = actual != expected
        first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
        raise AssertionError(f"[FP8_PERCHANNEL] {label} weight mismatch for TP {tp_idx} at index {first_diff_idx}")

    for tp_idx in range(gpu_tp_count):
        weight_span = slice(tp_idx * weight13_per_tp, (tp_idx + 1) * weight13_per_tp)
        scale_span = slice(tp_idx * scale13_per_tp, (tp_idx + 1) * scale13_per_tp)
        down_weight_cols = slice(tp_idx * tp_slice_weight_size, (tp_idx + 1) * tp_slice_weight_size)

        # Stacking on dim=1 interleaves gate and up per expert:
        # expert0-gate, expert0-up, expert1-gate, ...
        expected_w13_weight = torch.stack((gate_q[:, weight_span], up_q[:, weight_span]), dim=1).reshape(-1)
        expected_w13_scale = torch.stack((gate_scale[:, scale_span], up_scale[:, scale_span]), dim=1).reshape(-1)
        # Down weights are split along columns; down scales are kept whole.
        expected_w2_weight = down_q[:, :, down_weight_cols].reshape(-1)
        expected_w2_scale = down_scale.reshape(-1)

        check_weight("w13", w13_weight_bufs[tp_idx], expected_w13_weight, tp_idx)

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            raise AssertionError(f"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}")

        check_weight("w2", w2_weight_bufs[tp_idx], expected_w2_weight, tp_idx)

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            raise AssertionError(f"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}")

    print(f"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
|
|
|
|
|
|
def test_bf16_write_buffer(gpu_tp_count):
    """Test write_weight_scale_to_buffer with BF16 weights (no scales).

    Loads random BF16 weights into an ``AMXBF16_MOE``, asks it to write every
    expert into per-TP byte buffers (two warm-up passes and one timed pass),
    then compares each buffer byte-for-byte with slices of the source tensors.

    Expected layout per TP part, expert after expert:
      * w13: the TP's contiguous 1/``gpu_tp_count`` chunk of the flat gate
        matrix followed by the same chunk of the up matrix.
      * w2: the down matrix viewed as ``hidden_size`` rows of
        ``intermediate_size`` BF16 elements, keeping only the TP's column range.

    Args:
        gpu_tp_count: Number of tensor-parallel parts to split the weights into.

    Returns:
        True when every TP buffer matches the expected data.

    Raises:
        ValueError: If ``gpu_tp_count`` does not evenly divide the
            intermediate size (the size arithmetic below would otherwise
            truncate silently).
        AssertionError: If a written buffer differs from the expected data.
    """
    torch.manual_seed(123)

    expert_num = 16
    gpu_experts = expert_num
    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536

    # Every per-TP size below uses floor division; reject TP counts that
    # would make those divisions lossy instead of mis-sizing the buffers.
    if gpu_tp_count < 1 or intermediate_size % gpu_tp_count:
        raise ValueError(
            f"[BF16] gpu_tp_count={gpu_tp_count} must evenly divide intermediate_size={intermediate_size}"
        )

    cpuinfer = make_cpu_infer()
    cfg = build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    weights = allocate_weights_bf16(expert_num, hidden_size, intermediate_size)

    # The config receives raw addresses; `weights` keeps the tensors alive.
    cfg.gate_proj = weights["gate_proj"].data_ptr()
    cfg.up_proj = weights["up_proj"].data_ptr()
    cfg.down_proj = weights["down_proj"].data_ptr()
    # BF16 has no scales, so the scale pointers are set to 0.
    cfg.gate_scale = 0
    cfg.up_scale = 0
    cfg.down_scale = 0

    moe = kt_kernel_ext.moe.AMXBF16_MOE(cfg)

    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    per_mat_weight_elems = weights["per_mat_weight_elems"]

    # Calculate sizes per TP part (BF16 = 2 bytes per element)
    weight_elems_per_expert_per_tp = per_mat_weight_elems // gpu_tp_count
    weight_bytes_per_expert_per_tp = weight_elems_per_expert_per_tp * 2

    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp

    # Create buffer lists (BF16: weights only; w13 holds gate + up, hence the factor 2)
    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    # Placeholder scale buffers (not used for BF16 but needed for the interface)
    w13_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]

    print(f"[BF16] GPU TP count: {gpu_tp_count}, Experts: {expert_num}")
    print(f"[BF16] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")

    def get_expert_ptrs(expert_id):
        # Byte offsets of this expert inside each TP weight buffer. They are
        # the same for every TP part, so compute them once per expert.
        w13_weight_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
        w2_weight_offset = expert_id * weight_bytes_per_expert_per_tp
        return (
            [buf.data_ptr() + w13_weight_offset for buf in w13_weight_bufs],
            [buf.data_ptr() for buf in w13_scale_bufs],  # Not used
            [buf.data_ptr() + w2_weight_offset for buf in w2_weight_bufs],
            [buf.data_ptr() for buf in w2_scale_bufs],  # Not used
        )

    def write_all_experts():
        # One full pass: submit a write task per expert, then wait for all of them.
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Warm up
    for _ in range(2):
        write_all_experts()

    # Timing
    begin_time = time.perf_counter_ns()
    write_all_experts()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1e6

    total_bytes = hidden_size * intermediate_size * gpu_experts * 3 * 2  # BF16 = 2 bytes
    print(f"[BF16] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    # bytes / (ms * 1e6) == bytes per nanosecond == GB/s
    print(f"[BF16] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

    # Verify correctness (BF16: weights only, no scales).
    # Reinterpret the BF16 data as raw bytes and view it per expert so each
    # TP slice is a strided view rather than a Python loop over rows.
    per_mat_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes
    row_bytes = intermediate_size * 2
    gate_bytes = weights["gate_proj"].view(torch.uint8).view(expert_num, per_mat_bytes)[:gpu_experts]
    up_bytes = weights["up_proj"].view(torch.uint8).view(expert_num, per_mat_bytes)[:gpu_experts]
    down_bytes = weights["down_proj"].view(torch.uint8).view(expert_num, hidden_size, row_bytes)[:gpu_experts]

    weight_bytes_per_tp = per_mat_bytes // gpu_tp_count
    tp_slice_bytes = (intermediate_size // gpu_tp_count) * 2

    def check_weight(label, actual, expected, tp_idx):
        # Byte-exact comparison that reports the first differing index.
        if torch.equal(actual, expected):
            return
        diff_mask = actual != expected
        first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
        raise AssertionError(f"[BF16] {label} weight mismatch for TP {tp_idx} at index {first_diff_idx}")

    for tp_idx in range(gpu_tp_count):
        weight_span = slice(tp_idx * weight_bytes_per_tp, (tp_idx + 1) * weight_bytes_per_tp)
        down_cols = slice(tp_idx * tp_slice_bytes, (tp_idx + 1) * tp_slice_bytes)

        # Stacking on dim=1 interleaves gate and up per expert:
        # expert0-gate, expert0-up, expert1-gate, ...
        expected_w13_weight = torch.stack((gate_bytes[:, weight_span], up_bytes[:, weight_span]), dim=1).reshape(-1)
        # Down matrix: sliced column-wise, keeping this TP's byte range of every row.
        expected_w2_weight = down_bytes[:, :, down_cols].reshape(-1)

        check_weight("w13", w13_weight_bufs[tp_idx], expected_w13_weight, tp_idx)
        check_weight("w2", w2_weight_bufs[tp_idx], expected_w2_weight, tp_idx)

    print(f"[BF16] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)")
    return True
|
|
|
|
|
|
def test_with_tp(quant_mode: str, gpu_tp_count: int):
    """Run the write-buffer test for one quant mode and one TP count.

    Args:
        quant_mode: One of ``"fp8"``, ``"fp8_perchannel"`` or ``"bf16"``.
        gpu_tp_count: Passed through to the selected test function.

    Returns:
        Whatever the selected test function returns.

    Raises:
        ValueError: If ``quant_mode`` is not one of the supported modes.
    """
    runners = {
        "fp8": test_fp8_write_buffer,
        "fp8_perchannel": test_fp8_perchannel_write_buffer,
        "bf16": test_bf16_write_buffer,
    }
    runner = runners.get(quant_mode)
    if runner is None:
        raise ValueError(f"Unsupported quant_mode: {quant_mode}")
    return runner(gpu_tp_count)
|
|
|
|
|
|
def main(quant_modes=None):
    """Run the write-buffer tests for the given quant modes and print a summary.

    Every mode is run with ``gpu_tp_count`` of 1, 2 and 4. A failing
    combination is recorded and the remaining ones still run.

    Args:
        quant_modes: Iterable of mode names; ``None`` runs
            ``"fp8"``, ``"fp8_perchannel"`` and ``"bf16"``.

    Exits the process with status 1 (via ``sys.exit``) if any combination failed.
    """
    if quant_modes is None:
        quant_modes = ["fp8", "fp8_perchannel", "bf16"]

    tp_values = [1, 2, 4]
    all_passed = True
    results = {}

    for quant_mode in quant_modes:
        print("\n" + "=" * 60)
        print(f"Testing {quant_mode.upper()} write_weight_scale_to_buffer")
        print("=" * 60)

        for tp in tp_values:
            print(f"\n--- Testing {quant_mode.upper()} with gpu_tp_count = {tp} ---")
            try:
                test_with_tp(quant_mode, tp)
            except Exception as e:
                # Top-level boundary: record the failure and keep going so
                # the summary covers every combination.
                results[(quant_mode, tp)] = f"FAILED: {e}"
                all_passed = False
                print(f"[{quant_mode.upper()}] TP={tp} FAILED: {e}")
            else:
                results[(quant_mode, tp)] = "PASSED"

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for (mode, tp), result in results.items():
        # Compare for equality: a substring check would label a failure as
        # PASS whenever its error message happens to contain "PASSED".
        status = "PASS" if result == "PASSED" else "FAIL"
        print(f" [{status}] {mode.upper()} TP={tp}: {result}")

    if all_passed:
        print("\nALL TESTS PASSED")
    else:
        print("\nSOME TESTS FAILED")
        sys.exit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # With no command-line argument every mode is run; otherwise the first
    # argument (case-insensitive) selects a single mode.
    if len(sys.argv) <= 1:
        main()
    else:
        mode = sys.argv[1].lower()
        if mode not in ("fp8", "fp8_perchannel", "bf16"):
            print(f"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'")
            sys.exit(1)
        main([mode])
|