Files
ktransformers/kt-kernel/bench/bench_moe_kernel_tiling.py
2025-12-17 19:46:32 +08:00

233 lines
9.8 KiB
Python

#!/usr/bin/env python
# coding=utf-8
"""
Bench MOE kernel with runtime tiling params (N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK)
- Demonstrates how to get/set tiling params from Python via kt_kernel_ext.moe.tiling
- Runs a small benchmark similar to bench_moe_kernel.py
Usage examples:
# 1) Just run with defaults (int8)
python bench_moe_kernel_tiling.py --quant int8
# 2) Override tiling params for INT8
python bench_moe_kernel_tiling.py --quant int8 \
--n_block_up_gate 32 --n_block_down 64 --n_block 64 --m_block 320 --k_block 7168
# 3) Set both INT8 and INT4 tiling params (if INT4 kernel is available on your platform)
python bench_moe_kernel_tiling.py --quant int4 --set_all \
--n_block_up_gate 256 --n_block_down 1024 --n_block 64 --m_block 320 --k_block 7168
"""
import os
import sys
import time
import argparse
os.environ.setdefault("BLAS_NUM_THREADS", "1")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch # noqa: E402
from kt_kernel import kt_kernel_ext as ce # noqa: E402
from tqdm import tqdm # noqa: E402
def maybe_get_class(module, name):
    """Return attribute ``name`` from ``module``, or None if it does not exist.

    Used to probe optional kernel bindings (e.g. the INT4 MOE kernel) that may
    be absent on some platforms without raising AttributeError.
    """
    # Single getattr with a default replaces the hasattr/getattr double lookup.
    return getattr(module, name, None)
def main():
    """Run the MOE kernel benchmark with optional runtime tiling overrides.

    Steps:
      1. Parse CLI arguments (model dims, iteration counts, tiling overrides).
      2. Optionally override the kernel tiling parameters via
         ``ce.moe.tiling.set_int8`` / ``set_int4`` / ``set_all``, filling any
         flags left at None from the kernel's current defaults.
      3. Build ``layer_num`` MOE kernel instances with random fp32 weights.
      4. Warm up, then time ``test_iter`` forward passes and print latency,
         effective weight bandwidth (GB/s), and TFLOPS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant", choices=["int8", "int4"], default="int8")
    parser.add_argument("--expert_num", type=int, default=256)
    parser.add_argument("--hidden_size", type=int, default=7168)
    parser.add_argument("--intermediate_size", type=int, default=2048)
    parser.add_argument("--num_experts_per_tok", type=int, default=8)
    parser.add_argument("--max_len", type=int, default=25600)
    parser.add_argument("--layer_num", type=int, default=1)
    parser.add_argument("--qlen", type=int, default=1024)
    parser.add_argument("--warm_up_iter", type=int, default=200)
    parser.add_argument("--test_iter", type=int, default=500)
    parser.add_argument("--threads", type=int, default=160, help="CPUInfer initialization param")
    # Tiling params (None means "keep the kernel's current default")
    parser.add_argument("--set_all", action="store_true", help="Apply tiling to both INT8 and INT4 kernels")
    parser.add_argument("--n_block_up_gate", type=int, default=None)
    parser.add_argument("--n_block_down", type=int, default=None)
    parser.add_argument("--n_block", type=int, default=None)
    parser.add_argument("--m_block", type=int, default=None)
    parser.add_argument("--k_block", type=int, default=None)
    parser.add_argument("--n_block_up_gate_prefi", type=int, default=None)
    parser.add_argument("--n_block_down_prefi", type=int, default=None)
    args = parser.parse_args()

    # Show current tiling defaults for the selected quantization
    if args.quant == "int8":
        print("[tiling] default int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("[tiling] default int4:", ce.moe.tiling.get_int4())

    # Apply overrides only if at least one core tiling flag was provided
    if any(v is not None for v in [args.n_block_up_gate, args.n_block_down, args.n_block, args.m_block, args.k_block]):

        def fill_defaults(getter):
            # Fill missing values with current defaults to avoid overwriting unrelated params.
            # `getter()` is assumed to return a mapping with these keys (see get_int8/get_int4).
            cur = getter()
            return (
                args.n_block_up_gate if args.n_block_up_gate is not None else int(cur["n_block_up_gate"]),
                args.n_block_down if args.n_block_down is not None else int(cur["n_block_down"]),
                args.n_block if args.n_block is not None else int(cur["n_block"]),
                args.m_block if args.m_block is not None else int(cur["m_block"]),
                args.k_block if args.k_block is not None else int(cur["k_block"]),
                (
                    args.n_block_up_gate_prefi
                    if args.n_block_up_gate_prefi is not None
                    else int(cur["n_block_up_gate_prefi"])
                ),
                args.n_block_down_prefi if args.n_block_down_prefi is not None else int(cur["n_block_down_prefi"]),
            )

        if args.set_all and hasattr(ce.moe.tiling, "set_all"):
            # set_all applies one parameter set to every kernel variant; defaults
            # are taken from the int8 getter.
            nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
            ce.moe.tiling.set_all(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
            print("[tiling] set_all ->", ce.moe.tiling.get_int8())
            if hasattr(ce.moe.tiling, "get_int4"):
                print("[tiling] set_all -> int4:", ce.moe.tiling.get_int4())
        else:
            if args.quant == "int8":
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
                ce.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int8 ->", ce.moe.tiling.get_int8())
            elif args.quant == "int4" and hasattr(ce.moe.tiling, "set_int4"):
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int4)
                ce.moe.tiling.set_int4(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int4 ->", ce.moe.tiling.get_int4())

    # Warn about divisibility expectations; kernels assume specific blocking
    # - Some helpers assert n % N_BLOCK == 0, etc. Ensure your dims/tiles align.
    print("[note] Ensure your selected tiling parameters are compatible with hidden/intermediate sizes and blocking.")

    # Initialize CPUInfer (thread pool shared by every kernel instance below)
    CPUInfer = ce.CPUInfer(args.threads)

    # Select MOE kernel binding; bytes_per_elem feeds the bandwidth estimate
    # (int8 = 1 byte per weight element, int4 = half a byte).
    moe_cls = None
    if args.quant == "int8":
        moe_cls = maybe_get_class(ce.moe, "Int8_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int8 kernel binding 'Int8_KERNEL_MOE' not found.")
        bytes_per_elem = 1.0
    else:
        moe_cls = maybe_get_class(ce.moe, "Int4_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int4 kernel binding 'Int4_KERNEL_MOE' not available on this platform.")
        bytes_per_elem = 0.5

    # Prepare config/weights
    expert_num = args.expert_num
    hidden_size = args.hidden_size
    intermediate_size = args.intermediate_size
    num_experts_per_tok = args.num_experts_per_tok
    layer_num = args.layer_num
    max_len = args.max_len

    # Identity placement: physical expert i maps to logical expert i
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()

    moes = []
    # Keep the fp32 weight tensors alive for the whole run: the kernels receive
    # raw pointers (data_ptr) into them, so dropping the Python references
    # would free memory the C++ side may still read.
    gate_projs, up_projs, down_projs = [], [], []
    for layer_idx in range(layer_num):
        gate_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        up_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        down_proj = torch.randn(
            (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        cfg = ce.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
        cfg.max_len = max_len
        cfg.gate_proj = gate_proj.data_ptr()
        cfg.up_proj = up_proj.data_ptr()
        cfg.down_proj = down_proj.data_ptr()
        cfg.pool = CPUInfer.backend_
        moe = moe_cls(cfg)
        # Quantize/pack the fp32 weights into the kernel's internal layout.
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()
        gate_projs.append(gate_proj)
        up_projs.append(up_proj)
        down_projs.append(down_proj)
        moes.append(moe)

    qlen = args.qlen
    warm_up_iter = args.warm_up_iter
    test_iter = args.test_iter

    # Per-iteration routing: distinct top-k expert ids per token, sampled by
    # argsorting random scores (guarantees no duplicate experts per token).
    expert_ids = (
        torch.rand(test_iter * qlen, expert_num)
        .argsort(dim=-1)[:, :num_experts_per_tok]
        .reshape(test_iter, qlen * num_experts_per_tok)
        .to("cpu")
        .contiguous()
    )
    weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32).to("cpu").contiguous()
    input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    bsz_tensor = torch.tensor([qlen], dtype=torch.int32).to("cpu").contiguous()

    # Warmup (not timed). Routing tensors only have `test_iter` rows, so wrap
    # the index with `% test_iter` — previously `expert_ids[i]` raised
    # IndexError whenever --warm_up_iter exceeded --test_iter.
    for i in tqdm(range(warm_up_iter), desc="Warm-up"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i % test_iter].data_ptr(),
                weights[i % test_iter].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
            )
        )
        CPUInfer.sync()

    # Measure
    start = time.perf_counter()
    for i in tqdm(range(test_iter), desc="Testing"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i].data_ptr(),
                weights[i].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
                False,
            )
        )
        CPUInfer.sync()
    end = time.perf_counter()

    total_time = end - start
    time_per_iter_us = total_time / test_iter * 1e6
    # Weight traffic: 3 matrices (gate/up/down) of hidden*intermediate elements
    # per selected expert per token, at bytes_per_elem bytes each.
    bandwidth_gbs = (
        hidden_size * intermediate_size * 3 * num_experts_per_tok * qlen * bytes_per_elem * test_iter / total_time / 1e9
    )
    # 2 FLOPs (multiply + add) per weight element.
    flops_tflops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12

    print("\n=== Results ===")
    print("quant:", args.quant)
    if hasattr(ce.moe.tiling, "get_int8") and args.quant == "int8":
        print("tiling int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("tiling int4:", ce.moe.tiling.get_int4())
    print("time (s):", total_time)
    print("iter:", test_iter)
    print("time per iter (us):", time_per_iter_us)
    print("bandwidth (GB/s):", bandwidth_gbs)
    print("TFLOPS:", flops_tflops)
if __name__ == "__main__":
main()