#!/usr/bin/env python
# coding=utf-8
"""
Bench MOE kernel with runtime tiling params (N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK)
- Demonstrates how to get/set tiling params from Python via kt_kernel_ext.moe.tiling
- Runs a small benchmark similar to bench_moe_kernel.py

Usage examples:
  # 1) Just run with defaults (int8)
  python bench_moe_kernel_tiling.py --quant int8

  # 2) Override tiling params for INT8
  python bench_moe_kernel_tiling.py --quant int8 \
    --n_block_up_gate 32 --n_block_down 64 --n_block 64 --m_block 320 --k_block 7168

  # 3) Set both INT8 and INT4 tiling params (if INT4 kernel is available on your platform)
  python bench_moe_kernel_tiling.py --quant int4 --set_all \
    --n_block_up_gate 256 --n_block_down 1024 --n_block 64 --m_block 320 --k_block 7168
"""
import argparse
import os
import sys
import time

# Keep external BLAS single-threaded; the kernel drives its own thread pool.
os.environ.setdefault("BLAS_NUM_THREADS", "1")
# Make the locally built extension in ../build importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

import torch  # noqa: E402
from kt_kernel import kt_kernel_ext as ce  # noqa: E402
from tqdm import tqdm  # noqa: E402
def maybe_get_class(module, name):
    """Return attribute *name* from *module*, or None when it is absent."""
    return getattr(module, name, None)
def main():
    """Parse CLI args, optionally override MOE tiling params, and benchmark the MOE kernel.

    Flow: print current tiling defaults, apply any user overrides (filling
    unspecified fields from the current defaults), build per-layer random
    expert weights, load them into the selected int8/int4 kernel, then run
    warm-up and timed forward passes and report time / bandwidth / TFLOPS.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant", choices=["int8", "int4"], default="int8")
    parser.add_argument("--expert_num", type=int, default=256)
    parser.add_argument("--hidden_size", type=int, default=7168)
    parser.add_argument("--intermediate_size", type=int, default=2048)
    parser.add_argument("--num_experts_per_tok", type=int, default=8)
    parser.add_argument("--max_len", type=int, default=25600)
    parser.add_argument("--layer_num", type=int, default=1)
    parser.add_argument("--qlen", type=int, default=1024)
    parser.add_argument("--warm_up_iter", type=int, default=200)
    parser.add_argument("--test_iter", type=int, default=500)
    parser.add_argument("--threads", type=int, default=160, help="CPUInfer initialization param")

    # Tiling params (None means "keep the kernel's current default")
    parser.add_argument("--set_all", action="store_true", help="Apply tiling to both INT8 and INT4 kernels")
    parser.add_argument("--n_block_up_gate", type=int, default=None)
    parser.add_argument("--n_block_down", type=int, default=None)
    parser.add_argument("--n_block", type=int, default=None)
    parser.add_argument("--m_block", type=int, default=None)
    parser.add_argument("--k_block", type=int, default=None)
    parser.add_argument("--n_block_up_gate_prefi", type=int, default=None)
    parser.add_argument("--n_block_down_prefi", type=int, default=None)

    args = parser.parse_args()

    # Show current tiling defaults for the selected quantization.
    if args.quant == "int8":
        print("[tiling] default int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("[tiling] default int4:", ce.moe.tiling.get_int4())

    # Apply overrides only if the user supplied at least one tiling flag.
    if any(v is not None for v in [args.n_block_up_gate, args.n_block_down, args.n_block, args.m_block, args.k_block]):

        def fill_defaults(getter):
            # Fill missing values with current defaults to avoid overwriting
            # unrelated params; the setter expects the full 7-tuple.
            cur = getter()
            return (
                args.n_block_up_gate if args.n_block_up_gate is not None else int(cur["n_block_up_gate"]),
                args.n_block_down if args.n_block_down is not None else int(cur["n_block_down"]),
                args.n_block if args.n_block is not None else int(cur["n_block"]),
                args.m_block if args.m_block is not None else int(cur["m_block"]),
                args.k_block if args.k_block is not None else int(cur["k_block"]),
                (
                    args.n_block_up_gate_prefi
                    if args.n_block_up_gate_prefi is not None
                    else int(cur["n_block_up_gate_prefi"])
                ),
                args.n_block_down_prefi if args.n_block_down_prefi is not None else int(cur["n_block_down_prefi"]),
            )

        if args.set_all and hasattr(ce.moe.tiling, "set_all"):
            # Defaults for unspecified fields are taken from the int8 config.
            nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
            ce.moe.tiling.set_all(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
            print("[tiling] set_all ->", ce.moe.tiling.get_int8())
            if hasattr(ce.moe.tiling, "get_int4"):
                print("[tiling] set_all -> int4:", ce.moe.tiling.get_int4())
        else:
            if args.quant == "int8":
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)
                ce.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int8 ->", ce.moe.tiling.get_int8())
            elif args.quant == "int4" and hasattr(ce.moe.tiling, "set_int4"):
                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int4)
                ce.moe.tiling.set_int4(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)
                print("[tiling] set_int4 ->", ce.moe.tiling.get_int4())

    # Warn about divisibility expectations; kernels assume specific blocking
    # - Some helpers assert n % N_BLOCK == 0, etc. Ensure your dims/tiles align.
    print("[note] Ensure your selected tiling parameters are compatible with hidden/intermediate sizes and blocking.")

    # Initialize CPUInfer with the requested worker-thread count.
    CPUInfer = ce.CPUInfer(args.threads)

    # Select MOE kernel binding for the requested quantization; int4 may be
    # absent on platforms where that kernel was not built.
    moe_cls = None
    if args.quant == "int8":
        moe_cls = maybe_get_class(ce.moe, "Int8_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int8 kernel binding 'Int8_KERNEL_MOE' not found.")
        bytes_per_elem = 1.0
    else:
        moe_cls = maybe_get_class(ce.moe, "Int4_KERNEL_MOE")
        if moe_cls is None:
            raise RuntimeError("Int4 kernel binding 'Int4_KERNEL_MOE' not available on this platform.")
        bytes_per_elem = 0.5

    # Prepare config/weights (random fp32 expert projections per layer).
    expert_num = args.expert_num
    hidden_size = args.hidden_size
    intermediate_size = args.intermediate_size
    num_experts_per_tok = args.num_experts_per_tok
    layer_num = args.layer_num
    max_len = args.max_len

    # Identity expert routing map (physical expert i == logical expert i).
    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()

    moes = []
    # Keep Python references to the weight tensors so the raw pointers handed
    # to the C++ side stay valid for the lifetime of the benchmark.
    gate_projs, up_projs, down_projs = [], [], []

    for layer_idx in range(layer_num):
        gate_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        up_proj = torch.randn(
            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device="cpu"
        ).contiguous()
        down_proj = torch.randn(
            (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device="cpu"
        ).contiguous()

        cfg = ce.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
        cfg.max_len = max_len
        cfg.gate_proj = gate_proj.data_ptr()
        cfg.up_proj = up_proj.data_ptr()
        cfg.down_proj = down_proj.data_ptr()
        cfg.pool = CPUInfer.backend_

        moe = moe_cls(cfg)
        # Quantize/load the fp32 weights into the kernel's internal layout.
        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
        CPUInfer.sync()

        gate_projs.append(gate_proj)
        up_projs.append(up_proj)
        down_projs.append(down_proj)
        moes.append(moe)

    qlen = args.qlen
    warm_up_iter = args.warm_up_iter
    test_iter = args.test_iter

    # Random top-k expert selection per token: argsort of uniform noise gives
    # k distinct expert ids per token, flattened to (test_iter, qlen * k).
    expert_ids = (
        torch.rand(test_iter * qlen, expert_num)
        .argsort(dim=-1)[:, :num_experts_per_tok]
        .reshape(test_iter, qlen * num_experts_per_tok)
        .to("cpu")
        .contiguous()
    )
    weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32).to("cpu").contiguous()
    input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to("cpu").contiguous()
    bsz_tensor = torch.tensor([qlen], dtype=torch.int32).to("cpu").contiguous()

    # Warmup. expert_ids/weights only have test_iter rows, so wrap the index
    # with `% test_iter` — otherwise any --warm_up_iter > --test_iter raises
    # IndexError.
    # NOTE(review): the measured loop passes a trailing `False` to
    # forward_task that the warm-up omits; presumably warm-up relies on that
    # flag's default — confirm the default matches `False`.
    for i in tqdm(range(warm_up_iter), desc="Warm-up"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i % test_iter].data_ptr(),
                weights[i % test_iter].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
            )
        )
        CPUInfer.sync()

    # Measure: submit + sync per iteration so each forward completes before
    # the next is timed.
    start = time.perf_counter()
    for i in tqdm(range(test_iter), desc="Testing"):
        CPUInfer.submit(
            moes[i % layer_num].forward_task(
                bsz_tensor.data_ptr(),
                num_experts_per_tok,
                expert_ids[i].data_ptr(),
                weights[i].data_ptr(),
                input_tensor[i % layer_num].data_ptr(),
                output_tensor[i % layer_num].data_ptr(),
                False,
            )
        )
        CPUInfer.sync()
    end = time.perf_counter()

    total_time = end - start
    time_per_iter_us = total_time / test_iter * 1e6
    # Bytes moved per token: 3 projections (gate/up/down) of
    # hidden * intermediate weights per activated expert.
    bandwidth_gbs = (
        hidden_size * intermediate_size * 3 * num_experts_per_tok * qlen * bytes_per_elem * test_iter / total_time / 1e9
    )
    # FLOPs: 2 (multiply-add) per weight element per activated expert.
    flops_tflops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12

    print("\n=== Results ===")
    print("quant:", args.quant)
    if hasattr(ce.moe.tiling, "get_int8") and args.quant == "int8":
        print("tiling int8:", ce.moe.tiling.get_int8())
    if hasattr(ce.moe.tiling, "get_int4") and args.quant == "int4":
        print("tiling int4:", ce.moe.tiling.get_int4())
    print("time (s):", total_time)
    print("iter:", test_iter)
    print("time per iter (us):", time_per_iter_us)
    print("bandwidth (GB/s):", bandwidth_gbs)
    print("TFLOPS:", flops_tflops)
# Script entry point: run the benchmark when executed directly.
if __name__ == "__main__":
    main()