Files
ktransformers/kt-kernel/bench/compare_moe_performance.py
2025-11-03 15:19:52 +08:00

1371 lines
56 KiB
Python

#!/usr/bin/env python
# coding=utf-8
"""
MoE Performance Comparison Script
Compares performance between KTransformers AMX MoE and SGL CPU MoE implementations
"""
import os
import sys
import time
import json
import platform
import subprocess
import argparse
import logging
import signal
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Environment configuration
@dataclass
class EnvironmentConfig:
    """Allocator-related environment settings applied before the kernels load."""
    malloc_conf: str = "oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
    jemalloc_path: str = "/home/xwy/Projects/jemalloc/lib/libjemalloc.so"

    def apply(self):
        """Export MALLOC_CONF and, when the library exists, LD_PRELOAD jemalloc."""
        os.environ['MALLOC_CONF'] = self.malloc_conf
        if not os.path.exists(self.jemalloc_path):
            logger.warning(f"jemalloc not found at {self.jemalloc_path}")
            return
        os.environ['LD_PRELOAD'] = self.jemalloc_path
# Apply environment configuration
# NOTE(review): LD_PRELOAD set via os.environ after startup does not affect the
# current process (the dynamic linker reads it at exec time); it only propagates
# to the subprocesses launched below -- confirm this is the intent.
env_config = EnvironmentConfig()
env_config.apply()
# Add paths for both implementations
# kt_kernel_ext is expected in ../build relative to this script; the SGL kernel
# lives in a hard-coded local checkout.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
sys.path.insert(0, '/home/xwy/Projects/sgl-cpu-tests')
# torch is imported only after the allocator environment above is configured.
import torch
# Try importing both implementations
# Each block sets a module-level availability flag instead of failing hard, so
# the script can still benchmark whichever implementation is present.
try:
    import kt_kernel_ext
    KTRANSFORMERS_AVAILABLE = True
    logger.info("KTransformers kt_kernel_ext loaded successfully")
except ImportError as e:
    KTRANSFORMERS_AVAILABLE = False
    logger.warning(f"KTransformers kt_kernel_ext not available: {e}")
try:
    from sgl_kernel.common_ops import fused_experts_cpu
    from sgl_kernel.common_ops import convert_weight_packed
    SGL_AVAILABLE = True
    logger.info("SGL kernel loaded successfully")
except ImportError as e:
    SGL_AVAILABLE = False
    logger.warning(f"SGL kernel not available: {e}")
# Try importing int4 support
try:
    # For SGL INT4, we'll check if the sglang-jianan directory exists
    # NOTE(review): redundant re-import; os is already imported at file top.
    import os
    sglang_path = "/home/xwy/Projects/sglang-jianan"
    if os.path.exists(sglang_path) and os.path.exists(os.path.join(sglang_path, "benchmark/kernels/int4_moe/benchmark_int4_moe.py")):
        SGL_INT4_AVAILABLE = True
        logger.info("SGL INT4 support available (via sglang-jianan)")
    else:
        SGL_INT4_AVAILABLE = False
        logger.warning("SGL INT4 support not available: sglang-jianan directory not found")
except Exception as e:
    SGL_INT4_AVAILABLE = False
    logger.warning(f"SGL INT4 support not available: {e}")
def get_cpu_count() -> int:
    """Get logical CPU core count (including hyperthreading)"""
    # Preferred: ask the interpreter directly.
    try:
        detected = os.cpu_count()
        if detected and detected > 0:
            logger.info(f"Detected {detected} logical CPU cores via os.cpu_count()")
            return detected
    except Exception as e:
        logger.debug(f"os.cpu_count() failed: {e}")
    # Fallback: count 'processor' entries in /proc/cpuinfo (Linux only).
    try:
        with open('/proc/cpuinfo', 'r') as f:
            detected = sum(1 for line in f if line.strip().startswith('processor'))
        if detected > 0:
            logger.info(f"Detected {detected} logical CPU cores via /proc/cpuinfo")
            return detected
    except Exception as e:
        logger.debug(f"Failed to read /proc/cpuinfo: {e}")
    # Last resort: a conservative hard-coded guess.
    logger.warning("Could not detect CPU count, defaulting to 32")
    return 32
def get_physical_cpu_count() -> int:
    """Get physical CPU core count (excluding hyperthreading).

    Tries lscpu, then sysfs topology, then /proc/cpuinfo; as a last resort
    assumes 2-way SMT over the logical count, defaulting to 32.
    """
    # Method 1: Try lscpu command
    try:
        result = subprocess.run(['lscpu'], capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            cores_per_socket = None
            sockets = None
            for line in result.stdout.split('\n'):
                if 'Core(s) per socket:' in line:
                    cores_per_socket = int(line.split(':')[1].strip())
                elif 'Socket(s):' in line:
                    sockets = int(line.split(':')[1].strip())
            if cores_per_socket and sockets:
                physical_cores = cores_per_socket * sockets
                logger.info(f"Detected {physical_cores} physical CPU cores via lscpu")
                return physical_cores
    except Exception as e:
        logger.debug(f"lscpu failed: {e}")
    # Method 2: Check /sys/devices/system/cpu/
    try:
        cpu_path = '/sys/devices/system/cpu/'
        if os.path.exists(cpu_path):
            # Count unique physical cores. core_id is only unique within one
            # socket, so qualify it with physical_package_id to avoid
            # undercounting multi-socket hosts (bug fix).
            physical_cores = set()
            for cpu_dir in os.listdir(cpu_path):
                if cpu_dir.startswith('cpu') and cpu_dir[3:].isdigit():
                    core_id_path = os.path.join(cpu_path, cpu_dir, 'topology/core_id')
                    pkg_id_path = os.path.join(cpu_path, cpu_dir, 'topology/physical_package_id')
                    if os.path.exists(core_id_path):
                        with open(core_id_path, 'r') as f:
                            core_id = f.read().strip()
                        pkg_id = '0'
                        if os.path.exists(pkg_id_path):
                            with open(pkg_id_path, 'r') as f:
                                pkg_id = f.read().strip()
                        physical_cores.add((pkg_id, core_id))
            if physical_cores:
                count = len(physical_cores)
                logger.info(f"Detected {count} physical CPU cores via sysfs")
                return count
    except Exception as e:
        logger.debug(f"Failed to check sysfs: {e}")
    # Method 3: Parse /proc/cpuinfo for unique (physical id, core id) pairs
    try:
        with open('/proc/cpuinfo', 'r') as f:
            content = f.read()
        cores = set()
        current_physical_id = None
        for line in content.split('\n'):
            if line.startswith('physical id'):
                current_physical_id = line.split(':')[1].strip()
            elif line.startswith('core id') and current_physical_id is not None:
                core_id = line.split(':')[1].strip()
                cores.add(f"{current_physical_id}:{core_id}")
        if cores:
            count = len(cores)
            logger.info(f"Detected {count} physical CPU cores via /proc/cpuinfo")
            return count
    except Exception as e:
        logger.debug(f"Failed to parse /proc/cpuinfo: {e}")
    # Fallback: assume hyperthreading is enabled and divide logical cores by 2
    try:
        logical_count = get_cpu_count()
        if logical_count > 0:
            physical_count = logical_count // 2
            logger.warning(f"Could not detect physical cores directly. Assuming hyperthreading enabled: {logical_count} logical cores -> {physical_count} physical cores")
            return physical_count
    except Exception:  # was a bare `except:`, which would also swallow SystemExit/KeyboardInterrupt
        pass
    # Default fallback
    logger.warning("Could not detect physical CPU count, defaulting to 32")
    return 32
# Test configuration dataclass
@dataclass
class TestConfig:
    """Benchmark sweep parameters: MoE model geometry plus iteration counts."""
    expert_num: int = 256
    hidden_size: int = 7168
    intermediate_size: int = 2048
    max_len: int = 25600
    num_experts_per_tok: int = 8
    layer_num: int = 5
    warm_up_iter: int = 100
    test_iter: int = 10000
    # None means "fill with defaults in __post_init__" (a mutable list cannot be
    # a plain dataclass default). Annotations fixed: these were `List[int] = None`.
    qlen_values: Optional[List[int]] = None
    thread_count_values: Optional[List[int]] = None

    def __post_init__(self):
        if self.qlen_values is None:
            self.qlen_values = [1, 4, 16, 64, 256, 1024, 2048]
        if self.thread_count_values is None:
            # Default to physical CPU core count
            physical_cores = get_physical_cpu_count()
            self.thread_count_values = [physical_cores]

    @property
    def total_configurations(self) -> int:
        """Number of (qlen, thread_count) combinations in the sweep."""
        return len(self.qlen_values) * len(self.thread_count_values)
def get_numa_count() -> int:
    """Get NUMA node count from system with multiple fallback methods"""
    # Method 1: parse `numactl --hardware`, looking for "available: N nodes (...)".
    try:
        proc = subprocess.run(['numactl', '--hardware'],
                              capture_output=True, text=True, timeout=5)
        if proc.returncode == 0:
            for row in proc.stdout.split('\n'):
                if 'available:' not in row or 'nodes' not in row:
                    continue
                tokens = row.split()
                if len(tokens) >= 2 and tokens[1].isdigit():
                    detected = int(tokens[1])
                    logger.info(f"Detected {detected} NUMA nodes via numactl")
                    return detected
    except (subprocess.TimeoutExpired, FileNotFoundError) as e:
        logger.debug(f"numactl not available: {e}")
    # Method 2: count node* entries under /sys/devices/system/node/.
    try:
        node_path = '/sys/devices/system/node/'
        if os.path.exists(node_path):
            node_dirs = [entry for entry in os.listdir(node_path) if entry.startswith('node')]
            if node_dirs:
                detected = len(node_dirs)
                logger.info(f"Detected {detected} NUMA nodes via sysfs")
                return detected
    except Exception as e:
        logger.debug(f"Failed to check sysfs: {e}")
    # Last resort: assume a common dual-socket layout.
    logger.warning("Could not detect NUMA configuration, defaulting to 2 nodes")
    return 2
# System configuration
@dataclass
class SystemConfig:
    """Detected host topology; zero fields are probed at construction time."""
    numa_count: int = 0    # number of NUMA nodes (0 = auto-detect)
    cpu_cores: int = 0     # logical CPU count (0 = auto-detect)

    def __post_init__(self):
        # Zero means "not provided": fall back to probing the host.
        self.numa_count = self.numa_count or get_numa_count()
        self.cpu_cores = self.cpu_cores or get_cpu_count()
# Module-level singleton holding the detected host topology
# (detection probes run once, at import time).
sys_config = SystemConfig()
@dataclass
class ThreadConfig:
    """Thread placement derived from a requested total thread count."""
    thread_count: int      # total worker threads across all NUMA nodes
    threads_per_numa: int  # threads assigned to each NUMA node
    sgl_thread_count: int  # thread count handed to the SGL subprocess
    numa_prefix: str       # numactl prefix pinning the SGL run to node 0

    @classmethod
    def from_thread_count(cls, thread_count: int, numa_count: int, cpu_cores: int) -> 'ThreadConfig':
        """Create thread configuration for a specific thread count"""
        # Clamp to the number of available cores.
        if thread_count > cpu_cores:
            logger.warning(f"thread_count ({thread_count}) > cpu_cores ({cpu_cores}), using all cores")
            thread_count = cpu_cores
        per_numa = thread_count // numa_count
        # SGL runs on a single NUMA node, so it receives one node's share.
        return cls(
            thread_count=thread_count,
            threads_per_numa=per_numa,
            sgl_thread_count=per_numa,
            numa_prefix=f"numactl --physcpubind=0-{per_numa - 1} --membind=0",
        )
def get_system_info() -> Dict[str, any]:
    """Get comprehensive system information.

    Returns a JSON-serializable dict with OS/CPU identity, NUMA/core counts,
    CPU feature flags (AVX2/AVX-512/AMX), optional memory stats (psutil),
    and Python/PyTorch/CUDA versions.
    """
    # NOTE(review): `any` in the annotation is the builtin, not typing.Any —
    # harmless at runtime but almost certainly a typo for `Any`.
    info = {}
    # Basic system info
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node
    info["release"] = uname.release
    info["machine"] = uname.machine
    info["cpu_count"] = sys_config.cpu_cores
    info["numa_nodes"] = sys_config.numa_count
    # CPU model information (Linux only)
    if os.path.exists('/proc/cpuinfo'):
        try:
            with open('/proc/cpuinfo', 'r') as f:
                cpu_info = f.read()
            for line in cpu_info.split('\n'):
                if "model name" in line:
                    info["cpu_model"] = line.split(":", 1)[1].strip()
                    break
            # Check for CPU features (from the first "flags" line)
            if "flags" in cpu_info:
                flags_line = next(line for line in cpu_info.split('\n') if "flags" in line)
                flags = flags_line.split(":", 1)[1].strip().split()
                info["cpu_features"] = {
                    "avx2": "avx2" in flags,
                    "avx512": any(f.startswith("avx512") for f in flags),
                    "amx": any("amx" in f for f in flags)
                }
        except Exception as e:
            logger.debug(f"Failed to read CPU info: {e}")
    # Memory information (psutil is optional; skipped when missing)
    try:
        import psutil
        mem = psutil.virtual_memory()
        info["total_memory_gb"] = round(mem.total / (1024**3), 2)
        info["available_memory_gb"] = round(mem.available / (1024**3), 2)
    except ImportError:
        pass
    # Python and PyTorch versions
    info["python_version"] = sys.version.split()[0]
    info["torch_version"] = torch.__version__
    info["cuda_available"] = torch.cuda.is_available()
    if torch.cuda.is_available():
        info["cuda_version"] = torch.version.cuda
    return info
@dataclass
class BenchmarkResult:
    """One benchmark measurement for a single (implementation, config) pair."""
    implementation: str      # e.g. "KTransformers" or "SGL"
    quant_mode: str          # "bf16" / "int8" / "int4"
    qlen: int
    thread_count: int
    total_time: float        # wall-clock seconds for all iterations
    time_per_iter_us: float
    bandwidth_gbs: float
    tflops: float
    iterations: int

    def to_dict(self) -> Dict:
        """Serialize to a plain dict (for JSON output)."""
        return asdict(self)
@dataclass
class CheckpointState:
    """State information for checkpoint/resume functionality"""
    test_config: TestConfig
    # Each entry identifies one finished run: (thread_count, qlen, implementation, quant_mode)
    completed_configs: List[Tuple[int, int, str, str]]
    results: List[BenchmarkResult]
    start_time: str
    last_update: str

    def to_dict(self) -> Dict:
        """Serialize to a JSON-compatible dict."""
        payload = {
            'test_config': asdict(self.test_config),
            'completed_configs': self.completed_configs,
            'results': [entry.to_dict() for entry in self.results],
            'start_time': self.start_time,
            'last_update': self.last_update,
        }
        return payload

    @classmethod
    def from_dict(cls, data: Dict) -> 'CheckpointState':
        """Rebuild a CheckpointState from a dict produced by to_dict()."""
        return cls(
            test_config=TestConfig(**data['test_config']),
            completed_configs=data['completed_configs'],
            results=[BenchmarkResult(**raw) for raw in data['results']],
            start_time=data['start_time'],
            last_update=data['last_update'],
        )
class CheckpointManager:
    """Manages checkpoint saving and loading.

    Installs SIGINT/SIGTERM handlers so a long benchmark run can be
    interrupted gracefully: the handler only sets a flag, and the caller is
    expected to poll `interrupted` between tests and save a checkpoint.
    """
    def __init__(self, checkpoint_dir: str = None):
        # Default to ./checkpoints under the current working directory.
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else Path.cwd() / "checkpoints"
        self.checkpoint_dir.mkdir(exist_ok=True)
        self.checkpoint_file = self.checkpoint_dir / "moe_benchmark_checkpoint.json"
        # Flipped by the signal handler; polled by the benchmark loop.
        self.interrupted = False
        # Set up signal handler for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _signal_handler(self, signum, frame):
        # Do not exit immediately: let the in-flight test finish so the
        # checkpoint stays consistent.
        logger.warning(f"Received signal {signum}, will save checkpoint after current test...")
        self.interrupted = True

    def save_checkpoint(self, state: CheckpointState):
        """Save checkpoint to file"""
        state.last_update = datetime.now().isoformat()
        # Save to temporary file first for atomicity
        temp_file = self.checkpoint_file.with_suffix('.tmp')
        try:
            with open(temp_file, 'w') as f:
                json.dump(state.to_dict(), f, indent=2)
            # Atomically rename
            temp_file.replace(self.checkpoint_file)
            logger.info(f"Checkpoint saved: {len(state.results)} results, {len(state.completed_configs)} configs completed")
        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            # Remove the partial temp file so a later save starts clean.
            if temp_file.exists():
                temp_file.unlink()

    def load_checkpoint(self) -> Optional[CheckpointState]:
        """Load checkpoint from file if exists"""
        if not self.checkpoint_file.exists():
            return None
        try:
            with open(self.checkpoint_file, 'r') as f:
                data = json.load(f)
            state = CheckpointState.from_dict(data)
            logger.info(f"Loaded checkpoint: {len(state.results)} results, {len(state.completed_configs)} configs completed")
            logger.info(f"Checkpoint started at {state.start_time}, last updated {state.last_update}")
            return state
        except Exception as e:
            # A corrupt checkpoint is treated as absent rather than fatal.
            logger.error(f"Failed to load checkpoint: {e}")
            return None

    def clear_checkpoint(self):
        """Remove checkpoint file"""
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()
            logger.info("Checkpoint cleared")
def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
                            thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Benchmark KTransformers AMX MoE implementation.

    Args:
        test_config: Model geometry and iteration counts.
        quant_mode: One of "bf16", "int8", "int4"; selects the AMX kernel class.
        qlen: Number of tokens per forward call.
        thread_config: Thread placement for the CPUInfer worker pool.

    Returns:
        A BenchmarkResult on success, or None when unavailable or failed.
    """
    if not KTRANSFORMERS_AVAILABLE:
        logger.error("KTransformers not available, skipping benchmark")
        return None
    # Adjust iterations based on qlen to maintain reasonable runtime
    # (larger batches take longer per iteration, so run fewer of them).
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)
    logger.info(f"Testing KTransformers MoE: quant={quant_mode}, qlen={qlen}, threads={thread_config.thread_count}, "
                f"iterations={adjusted_iterations} (warmup={adjusted_warmup})")
    # Set thread count for this test
    os.environ['OMP_NUM_THREADS'] = str(thread_config.thread_count)
    try:
        with torch.inference_mode():
            # Setup worker config with consistent threads per NUMA
            worker_config = kt_kernel_ext.WorkerPoolConfig()
            worker_config.subpool_count = sys_config.numa_count
            worker_config.subpool_numa_map = list(range(sys_config.numa_count))
            worker_config.subpool_thread_count = [thread_config.threads_per_numa] * sys_config.numa_count
            CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
            # Create MoE layers
            moes = []
            gate_projs = []
            up_projs = []
            down_projs = []
            logger.debug(f"Creating {test_config.layer_num} MoE layers...")
            for i in range(test_config.layer_num):
                gate_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size),
                                        dtype=torch.float32).contiguous()
                up_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size),
                                      dtype=torch.float32).contiguous()
                down_proj = torch.randn((test_config.expert_num, test_config.hidden_size, test_config.intermediate_size),
                                        dtype=torch.float32).contiguous()
                config = kt_kernel_ext.moe.MOEConfig(
                    test_config.expert_num, test_config.num_experts_per_tok,
                    test_config.hidden_size, test_config.intermediate_size)
                config.max_len = test_config.max_len
                # Weights are handed to the native side as raw pointers; the
                # Python tensors are kept alive in the *_projs lists below.
                config.gate_proj = gate_proj.data_ptr()
                config.up_proj = up_proj.data_ptr()
                config.down_proj = down_proj.data_ptr()
                config.pool = CPUInfer.backend_
                if quant_mode == "bf16":
                    moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
                elif quant_mode == "int8":
                    moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
                elif quant_mode == "int4":
                    moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
                else:
                    raise ValueError(f"Unsupported quantization mode: {quant_mode}")
                CPUInfer.submit(moe.load_weights_task())
                CPUInfer.sync()
                gate_projs.append(gate_proj)
                up_projs.append(up_proj)
                down_projs.append(down_proj)
                moes.append(moe)
            # Prepare test data
            logger.debug("Preparing test data...")
            # Pre-generate gen_iter distinct routings (top-k without replacement
            # via argsort of random scores) so iterations vary expert selection.
            gen_iter = 1000
            expert_ids = torch.rand(gen_iter * qlen, test_config.expert_num).argsort(dim=-1)[
                :, :test_config.num_experts_per_tok
            ].reshape(gen_iter, qlen * test_config.num_experts_per_tok).contiguous()
            weights = torch.rand((gen_iter, qlen, test_config.num_experts_per_tok),
                                 dtype=torch.float32).contiguous()
            input_tensor = torch.randn((test_config.layer_num, qlen, test_config.hidden_size),
                                       dtype=torch.bfloat16).contiguous()
            output_tensor = torch.empty((test_config.layer_num, qlen, test_config.hidden_size),
                                        dtype=torch.bfloat16).contiguous()
            bsz_tensor = torch.tensor([qlen], dtype=torch.int32)
            # Warmup
            logger.debug(f"Running {adjusted_warmup} warmup iterations...")
            for i in range(adjusted_warmup):
                layer_idx = i % test_config.layer_num
                gen_idx = i % gen_iter
                CPUInfer.submit(
                    moes[layer_idx].forward_task(
                        bsz_tensor.data_ptr(),
                        test_config.num_experts_per_tok,
                        expert_ids[gen_idx].data_ptr(),
                        weights[gen_idx].data_ptr(),
                        input_tensor[layer_idx].data_ptr(),
                        output_tensor[layer_idx].data_ptr(),
                        False,
                    )
                )
                CPUInfer.sync()
            # Benchmark
            logger.debug(f"Running {adjusted_iterations} benchmark iterations...")
            start = time.perf_counter()
            for i in range(adjusted_iterations):
                layer_idx = i % test_config.layer_num
                gen_idx = i % gen_iter
                CPUInfer.submit(
                    moes[layer_idx].forward_task(
                        bsz_tensor.data_ptr(),
                        test_config.num_experts_per_tok,
                        expert_ids[gen_idx].data_ptr(),
                        weights[gen_idx].data_ptr(),
                        input_tensor[layer_idx].data_ptr(),
                        output_tensor[layer_idx].data_ptr(),
                        False,
                    )
                )
                CPUInfer.sync()
            end = time.perf_counter()
            # Calculate metrics
            total_time = end - start
            time_per_iter_us = total_time / adjusted_iterations * 1e6
            # Bytes per element based on quantization
            bytes_per_elem = {
                "bf16": 2.0,
                "int8": 1.0,
                "int4": 0.5
            }.get(quant_mode, 2.0)
            # Memory bandwidth calculation (GB/s).
            # The num_experts_per_tok * (1/8 * expert_num * (1-(31/32)**qlen))
            # product appears to model the expected number of distinct experts
            # touched by qlen tokens, with the 31/32 ratio hard-coded for
            # top-8-of-256 routing -- TODO confirm derivation before reusing
            # with other expert_num / top-k values.
            memory_per_iter = (
                test_config.hidden_size * test_config.intermediate_size * 3 *
                test_config.num_experts_per_tok *
                (1/8 * test_config.expert_num * (1-(31/32)**qlen)) * bytes_per_elem
            )
            bandwidth_gbs = memory_per_iter * adjusted_iterations / total_time / 1e9
            # FLOPS calculation (TFLOPS): 3 projections (gate/up/down) per
            # routed expert, 2 FLOPs per multiply-accumulate.
            flops_per_iter = (
                test_config.hidden_size * test_config.intermediate_size * qlen * 3 *
                test_config.num_experts_per_tok * 2
            )
            tflops = flops_per_iter * adjusted_iterations / total_time / 1e12
            logger.info(f"Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")
            return BenchmarkResult(
                implementation="KTransformers",
                quant_mode=quant_mode,
                qlen=qlen,
                thread_count=thread_config.thread_count,
                total_time=total_time,
                time_per_iter_us=time_per_iter_us,
                bandwidth_gbs=bandwidth_gbs,
                tflops=tflops,
                iterations=adjusted_iterations
            )
    except Exception as e:
        logger.error(f"KTransformers benchmark failed: {e}", exc_info=True)
        return None
def run_sgl_int4_with_numactl(test_config: TestConfig, qlen: int,
                              thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Run SGL INT4 benchmark with numactl in subprocess.

    Generates a standalone Python script, runs it pinned to NUMA node 0 via
    numactl, and parses a single "SGL_RESULT:..." line from its stdout.

    Returns:
        A BenchmarkResult (quant_mode="int4") or None on any failure.
    """
    if not SGL_INT4_AVAILABLE:
        logger.error("SGL INT4 not available, skipping benchmark")
        return None
    # Calculate SGL intermediate size (divided by NUMA nodes):
    # SGL is pinned to a single node, so it benchmarks one node's share of
    # the expert width to keep per-node work comparable.
    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count
    # Adjust iterations based on qlen to maintain reasonable runtime
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)
    logger.info(f"Testing SGL INT4: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), "
                f"threads per NUMA: {thread_config.sgl_thread_count}")
    # The generated benchmark script. All {placeholders} are baked in here;
    # the child process has no further configuration.
    script_content = f'''
import sys
sys.path.insert(0, '/home/xwy/Projects/sglang-jianan')
sys.path.insert(0, '/home/xwy/Projects/sglang-jianan/test')
import os
import torch
import numpy as np
import sgl_kernel
from srt.cpu.utils import autoawq_to_int4pack
import time
torch.manual_seed(1111)
M, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}
layer_num = {test_config.layer_num}
group_size = 128
kernel = torch.ops.sgl_kernel
# Prepare int4 data
dtype = torch.bfloat16
device = "cpu"
# Generate input activations for all layers
input_tensors = [torch.rand(M, K, dtype=dtype, device=device) / np.sqrt(K) for _ in range(layer_num)]
# Generate weights and pack for each layer
all_awq_w13_weight_pack = []
all_awq_w13_zero_pack = []
all_awq_w13_scales_pack = []
all_awq_w2_weight_pack = []
all_awq_w2_zero_pack = []
all_awq_w2_scales_pack = []
# Generate expert routing scores (different for each iteration)
gen_iter = 1000
all_topk_weights = []
all_topk_ids = []
for gen_idx in range(gen_iter):
    score = torch.rand(M, E, dtype=dtype, device=device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    all_topk_weights.append(topk_weight)
    all_topk_ids.append(topk_ids.to(torch.int32))
print("Creating " + str(layer_num) + " MoE layers...")
for layer_idx in range(layer_num):
    # Generate INT4 quantized weights for each expert
    # w1: gate and up projection (K -> 2*N)
    awq_w13_weight = torch.randint(-127, 128, (E, K, 2 * N // 8), device=device).to(torch.int)
    awq_w13_zero = torch.randint(0, 10, (E, K // group_size, 2 * N // 8), device=device).to(torch.int)
    awq_w13_scales = torch.rand(E, K // group_size, 2 * N, dtype=dtype, device=device)
    # w2: down projection (N -> K)
    awq_w2_weight = torch.randint(-127, 128, (E, N, K // 8), device=device).to(torch.int)
    awq_w2_zero = torch.randint(0, 10, (E, N // group_size, K // 8), device=device).to(torch.int)
    awq_w2_scales = torch.rand(E, N // group_size, K, dtype=dtype, device=device)
    # Pack weights for optimized kernel
    awq_w13_weight_pack = []
    awq_w13_zero_pack = []
    awq_w13_scales_pack = []
    awq_w2_weight_pack = []
    awq_w2_zero_pack = []
    awq_w2_scales_pack = []
    for i in range(E):
        packed_weight_13, packed_zero_13, packed_scales_13 = autoawq_to_int4pack(
            awq_w13_weight[i], awq_w13_zero[i], awq_w13_scales[i], False
        )
        awq_w13_weight_pack.append(packed_weight_13)
        awq_w13_zero_pack.append(packed_zero_13)
        awq_w13_scales_pack.append(packed_scales_13)
        packed_weight_2, packed_zero_2, packed_scales_2 = autoawq_to_int4pack(
            awq_w2_weight[i], awq_w2_zero[i], awq_w2_scales[i], False
        )
        awq_w2_weight_pack.append(packed_weight_2)
        awq_w2_zero_pack.append(packed_zero_2)
        awq_w2_scales_pack.append(packed_scales_2)
    all_awq_w13_weight_pack.append(torch.stack(awq_w13_weight_pack).detach())
    all_awq_w13_zero_pack.append(torch.stack(awq_w13_zero_pack).detach())
    all_awq_w13_scales_pack.append(torch.stack(awq_w13_scales_pack).detach())
    all_awq_w2_weight_pack.append(torch.stack(awq_w2_weight_pack).detach())
    all_awq_w2_zero_pack.append(torch.stack(awq_w2_zero_pack).detach())
    all_awq_w2_scales_pack.append(torch.stack(awq_w2_scales_pack).detach())
# Warmup
print("Running " + str({adjusted_warmup}) + " warmup iterations...")
for i in range({adjusted_warmup}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    out = kernel.fused_experts_cpu(
        input_tensors[layer_idx],
        all_awq_w13_weight_pack[layer_idx],
        all_awq_w2_weight_pack[layer_idx],
        all_topk_weights[gen_idx],
        all_topk_ids[gen_idx],
        False, # inplace
        False, # use_int8_w8a8
        False, # use_fp8_w8a16
        True, # use_int4_w4a16
        all_awq_w13_scales_pack[layer_idx],
        all_awq_w2_scales_pack[layer_idx],
        all_awq_w13_zero_pack[layer_idx],
        all_awq_w2_zero_pack[layer_idx],
        None, # block_size
        None, # a1_scale
        None, # a2_scale
        True, # is_vnni
    )
# Benchmark
print("Running " + str({adjusted_iterations}) + " benchmark iterations...")
start = time.perf_counter()
for i in range({adjusted_iterations}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    out = kernel.fused_experts_cpu(
        input_tensors[layer_idx],
        all_awq_w13_weight_pack[layer_idx],
        all_awq_w2_weight_pack[layer_idx],
        all_topk_weights[gen_idx],
        all_topk_ids[gen_idx],
        False,
        False,
        False,
        True,
        all_awq_w13_scales_pack[layer_idx],
        all_awq_w2_scales_pack[layer_idx],
        all_awq_w13_zero_pack[layer_idx],
        all_awq_w2_zero_pack[layer_idx],
        None,
        None,
        None,
        True,
    )
end = time.perf_counter()
total_time = end - start
time_per_iter_us = total_time / {adjusted_iterations} * 1e6
# Calculate performance metrics for int4
bytes_per_elem = 0.5 # int4
memory_per_iter = (
    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} *
    (1/8 * {test_config.expert_num} * (1-(31/32)**{qlen})) * bytes_per_elem
)
bandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9
# FLOPS calculation
flops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2
tflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12
print(f"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}")
'''
    # Create temporary script in sglang-jianan directory
    # (the child must run from there so `from srt.cpu.utils import ...` resolves).
    sglang_path = "/home/xwy/Projects/sglang-jianan"
    temp_script = f"{sglang_path}/temp_sgl_int4_bench_{os.getpid()}_{qlen}.py"
    try:
        with open(temp_script, 'w') as f:
            f.write(script_content)
        # Setup environment (mirror this process's allocator settings)
        env = os.environ.copy()
        env['MALLOC_CONF'] = env_config.malloc_conf
        if os.path.exists(env_config.jemalloc_path):
            env['LD_PRELOAD'] = env_config.jemalloc_path
        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)
        # Run with numactl from the sglang-jianan directory
        cmd = f"cd {sglang_path} && {thread_config.numa_prefix} python3 {temp_script}"
        logger.debug(f"Running SGL INT4 command: {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, env=env, timeout=300)
        if result.returncode == 0:
            # Parse result: the child prints one machine-readable line
            # "SGL_RESULT:<total_time>,<per_iter_us>,<bw_gbs>,<tflops>".
            for line in result.stdout.split('\n'):
                if line.startswith('SGL_RESULT:'):
                    parts = line.replace('SGL_RESULT:', '').split(',')
                    if len(parts) >= 4:
                        try:
                            total_time = float(parts[0])
                            time_per_iter_us = float(parts[1])
                            bandwidth_gbs = float(parts[2])
                            tflops = float(parts[3])
                            logger.info(f"SGL INT4 Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")
                            return BenchmarkResult(
                                implementation="SGL",
                                quant_mode="int4",
                                qlen=qlen,
                                thread_count=thread_config.thread_count,
                                total_time=total_time,
                                time_per_iter_us=time_per_iter_us,
                                bandwidth_gbs=bandwidth_gbs,
                                tflops=tflops,
                                iterations=adjusted_iterations
                            )
                        except ValueError as e:
                            logger.error(f"Failed to parse SGL INT4 results: {e}")
        else:
            logger.error(f"SGL INT4 subprocess failed with code {result.returncode}")
            logger.error(f"STDOUT: {result.stdout}")
            logger.error(f"STDERR: {result.stderr}")
    except subprocess.TimeoutExpired:
        logger.error("SGL INT4 benchmark timed out")
    except Exception as e:
        logger.error(f"SGL INT4 benchmark error: {e}", exc_info=True)
    finally:
        # Clean up: best-effort removal of the generated script.
        if os.path.exists(temp_script):
            try:
                os.remove(temp_script)
            except:
                pass
    return None
def run_sgl_with_numactl(test_config: TestConfig, qlen: int,
                         thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Run SGL benchmark with numactl in subprocess (INT8 path).

    Generates a standalone Python script, runs it pinned to NUMA node 0 via
    numactl, and parses a single "SGL_RESULT:..." line from its stdout.

    Returns:
        A BenchmarkResult (quant_mode="int8") or None on any failure.
    """
    if not SGL_AVAILABLE:
        logger.error("SGL not available, skipping benchmark")
        return None
    # Calculate SGL intermediate size (divided by NUMA nodes):
    # the child is pinned to one node, so it works on one node's share.
    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count
    # Adjust iterations based on qlen to maintain reasonable runtime
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)
    logger.info(f"Testing SGL INT8: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), "
                f"threads per NUMA: {thread_config.sgl_thread_count}")
    # The generated benchmark script; all {placeholders} are baked in here.
    script_content = f'''
import sys
sys.path.insert(0, "/home/xwy/Projects/sgl-cpu-tests")
import os
import torch
from sgl_kernel.common_ops import fused_experts_cpu as fused_experts
from sgl_kernel.common_ops import convert_weight_packed
import time
torch.manual_seed(1111)
M, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}
layer_num = {test_config.layer_num}
# Generate expert routing scores (different for each iteration)
gen_iter = 1000
all_topk_weights = []
all_topk_ids = []
for gen_idx in range(gen_iter):
    score = torch.randn(M, E).to(dtype=torch.bfloat16)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    all_topk_weights.append(topk_weight)
    all_topk_ids.append(topk_ids.to(torch.int32))
prepack = True
inplace = True
use_int4_w4a16 = False
# Create multiple layers
print("Creating " + str(layer_num) + " MoE layers...")
inputs = []
packed_w1s_int8 = []
packed_w2s_int8 = []
w1_s_list = []
w2_s_list = []
for layer_idx in range(layer_num):
    input_tensor = torch.randn(M, K).to(dtype=torch.bfloat16)
    # int8 weights
    w1_int8 = torch.randn(E, 2 * N, K).to(dtype=torch.int8)
    w2_int8 = torch.randn(E, K, N).to(dtype=torch.int8)
    packed_w1_int8 = convert_weight_packed(w1_int8)
    packed_w2_int8 = convert_weight_packed(w2_int8)
    w1_s = torch.rand(E, 2 * N)
    w2_s = torch.rand(E, K)
    inputs.append(input_tensor)
    packed_w1s_int8.append(packed_w1_int8)
    packed_w2s_int8.append(packed_w2_int8)
    w1_s_list.append(w1_s)
    w2_s_list.append(w2_s)
# Warmup
print("Running " + str({adjusted_warmup}) + " warmup iterations...")
for i in range({adjusted_warmup}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx],
                  all_topk_weights[gen_idx], all_topk_ids[gen_idx],
                  inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx],
                  None, None, None, None, None, prepack)
# Benchmark
print("Running " + str({adjusted_iterations}) + " benchmark iterations...")
start = time.perf_counter()
for i in range({adjusted_iterations}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx],
                  all_topk_weights[gen_idx], all_topk_ids[gen_idx],
                  inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx],
                  None, None, None, None, None, prepack)
end = time.perf_counter()
total_time = end - start
time_per_iter_us = total_time / {adjusted_iterations} * 1e6
# Calculate performance metrics for int8
bytes_per_elem = 1.0 # int8
memory_per_iter = (
    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} *
    (1/8 * {test_config.expert_num} * (1-(31/32)**{qlen})) * bytes_per_elem
)
bandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9
# FLOPS calculation
flops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2
tflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12
print(f"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}")
'''
    # Create temporary script
    temp_script = f"/tmp/sgl_bench_{os.getpid()}_{qlen}.py"
    try:
        with open(temp_script, 'w') as f:
            f.write(script_content)
        # Setup environment (mirror this process's allocator settings)
        env = os.environ.copy()
        env['MALLOC_CONF'] = env_config.malloc_conf
        if os.path.exists(env_config.jemalloc_path):
            env['LD_PRELOAD'] = env_config.jemalloc_path
        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)
        # Run with numactl
        cmd = f"{thread_config.numa_prefix} python3 {temp_script}"
        logger.debug(f"Running SGL command: {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, env=env, timeout=300)
        if result.returncode == 0:
            # Parse result: single "SGL_RESULT:t,us,gbs,tflops" line on stdout.
            for line in result.stdout.split('\n'):
                if line.startswith('SGL_RESULT:'):
                    parts = line.replace('SGL_RESULT:', '').split(',')
                    if len(parts) >= 4:
                        try:
                            total_time = float(parts[0])
                            time_per_iter_us = float(parts[1])
                            bandwidth_gbs = float(parts[2])
                            tflops = float(parts[3])
                            logger.info(f"SGL Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")
                            return BenchmarkResult(
                                implementation="SGL",
                                quant_mode="int8",
                                qlen=qlen,
                                thread_count=thread_config.thread_count,
                                total_time=total_time,
                                time_per_iter_us=time_per_iter_us,
                                bandwidth_gbs=bandwidth_gbs,
                                tflops=tflops,
                                iterations=adjusted_iterations
                            )
                        except ValueError as e:
                            logger.error(f"Failed to parse SGL results: {e}")
        else:
            logger.error(f"SGL subprocess failed with code {result.returncode}: {result.stderr}")
    except subprocess.TimeoutExpired:
        logger.error("SGL benchmark timed out")
    except Exception as e:
        logger.error(f"SGL benchmark error: {e}", exc_info=True)
    finally:
        # Clean up: best-effort removal of the generated script.
        if os.path.exists(temp_script):
            try:
                os.remove(temp_script)
            except:
                pass
    return None
def save_results(results: List[BenchmarkResult], test_config: TestConfig, filename: str = None) -> str:
    """Serialize benchmark results plus run metadata to a JSON file.

    Args:
        results: Completed benchmark results to persist.
        test_config: The test configuration used for this run.
        filename: Target path; when falsy, a timestamped name of the form
            ``moe_comparison_YYYYmmdd_HHMMSS.json`` is generated.

    Returns:
        The path the results were written to.
    """
    if not filename:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"moe_comparison_{timestamp}.json"
    output_data = {
        "timestamp": datetime.now().isoformat(),
        "test_configuration": asdict(test_config),
        "system_info": get_system_info(),
        "results": [r.to_dict() for r in results],
        # Aggregate view so a reader can see coverage without scanning rows
        "summary": {
            "total_benchmarks": len(results),
            "implementations_tested": list(set(r.implementation for r in results)),
            "quantization_modes": list(set(r.quant_mode for r in results)),
            "qlen_values_tested": sorted(set(r.qlen for r in results)),
            "thread_counts_tested": sorted(set(r.thread_count for r in results))
        }
    }
    with open(filename, 'w') as f:
        json.dump(output_data, f, indent=2)
    # Fix: previously logged a literal placeholder instead of the actual path.
    logger.info(f"Results saved to: {filename}")
    return filename
def print_summary_table(results: List[BenchmarkResult]):
    """Render a fixed-width comparison table of all benchmark results.

    Rows are ordered by (thread_count, qlen, implementation, quant_mode).
    The first row printed within each (thread_count, qlen) group acts as
    the baseline; subsequent rows show their speedup relative to it.
    Prints nothing when *results* is empty.
    """
    if not results:
        return
    separator = "=" * 100
    print("\n" + separator)
    print("PERFORMANCE SUMMARY")
    print(separator)
    print(f"{'Implementation':<15} {'Quant':<6} {'Threads':<8} {'QLen':<8} {'Time(μs)':<12} {'BW(GB/s)':<12} {'TFLOPS':<10} {'Speedup':<10}")
    print("-" * 100)
    # Baseline per-iteration time for each (threads, qlen) group, keyed by
    # the first row encountered in sorted order.
    group_baseline = {}
    ordered = sorted(results, key=lambda r: (r.thread_count, r.qlen, r.implementation, r.quant_mode))
    for row in ordered:
        group = (row.thread_count, row.qlen)
        baseline = group_baseline.get(group)
        if baseline is None:
            group_baseline[group] = row.time_per_iter_us
            speedup_text = "1.00x"
        else:
            speedup_text = f"{baseline / row.time_per_iter_us:.2f}x"
        print(f"{row.implementation:<15} {row.quant_mode:<6} {row.thread_count:<8} "
              f"{row.qlen:<8} {row.time_per_iter_us:<12.2f} {row.bandwidth_gbs:<12.2f} "
              f"{row.tflops:<10.2f} {speedup_text:<10}")
def main():
    """Drive the full MoE benchmark comparison.

    Parses command-line options, builds the benchmark matrix of
    (thread count, qlen, framework, precision), optionally resumes from a
    checkpoint, runs every pending configuration, then prints a summary
    table and saves results to JSON.

    Returns:
        Process exit code: 0 on success, 1 when no benchmark could run,
        2 when interrupted (a checkpoint is saved so the run can be
        resumed with ``--resume``).
    """
    parser = argparse.ArgumentParser(description="Compare MoE performance between KTransformers and SGL")
    parser.add_argument("--qlen", type=int, nargs="+", help="Sequence lengths to test")
    parser.add_argument("--threads", type=int, nargs="+", help="Thread counts to test")
    parser.add_argument("--iterations", type=int, help="Number of test iterations")
    parser.add_argument("--warmup", type=int, help="Number of warmup iterations")
    parser.add_argument("--output", type=str, help="Output filename for results")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint if available")
    parser.add_argument("--checkpoint-dir", type=str, help="Directory for checkpoint files")
    parser.add_argument("--no-checkpoint", action="store_true", help="Disable checkpoint saving")
    parser.add_argument("--framework", choices=["all", "ktransformers", "sgl"], default="all",
                        help="Framework to test (default: all)")
    parser.add_argument("--precision", choices=["all", "int8", "int4"], default="all",
                        help="Precision to test (default: all)")
    args = parser.parse_args()
    # Configure logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
    # Create test configuration; CLI values override the dataclass defaults
    test_config = TestConfig()
    if args.qlen:
        test_config.qlen_values = args.qlen
    if args.threads:
        test_config.thread_count_values = args.threads
    if args.iterations:
        test_config.test_iter = args.iterations
    # Fix: compare against None so an explicit "--warmup 0" disables warmup
    # instead of being silently ignored by a truthiness check.
    if args.warmup is not None:
        test_config.warm_up_iter = args.warmup
    # Determine which frameworks are both requested and importable
    test_ktransformers = args.framework in ["all", "ktransformers"] and KTRANSFORMERS_AVAILABLE
    test_sgl = args.framework in ["all", "sgl"] and (SGL_AVAILABLE or SGL_INT4_AVAILABLE)
    # Determine which precisions to test
    if args.precision == "all":
        test_precisions = ["int8", "int4"]
    else:
        test_precisions = [args.precision]
    # Print configuration
    logger.info("MoE Performance Comparison")
    logger.info("=" * 60)
    logger.info(f"System configuration:")
    logger.info(f"  CPU cores: {sys_config.cpu_cores}")
    logger.info(f"  NUMA nodes: {sys_config.numa_count}")
    logger.info(f"Test parameters:")
    logger.info(f"  Expert count: {test_config.expert_num}")
    logger.info(f"  Hidden size: {test_config.hidden_size}")
    logger.info(f"  Intermediate size: {test_config.intermediate_size}")
    logger.info(f"  Experts per token: {test_config.num_experts_per_tok}")
    logger.info(f"  Test iterations: {test_config.test_iter}")
    logger.info(f"  Warmup iterations: {test_config.warm_up_iter}")
    logger.info(f"Testing configurations:")
    logger.info(f"  QLEN values: {test_config.qlen_values}")
    logger.info(f"  Thread counts: {test_config.thread_count_values}")
    logger.info(f"  Frameworks: {args.framework}")
    logger.info(f"  Precisions: {args.precision}")
    logger.info(f"  Total configs: {test_config.total_configurations}")
    print()
    # Check availability: nothing to do if neither backend imported
    if not KTRANSFORMERS_AVAILABLE and not SGL_AVAILABLE:
        logger.error("Neither KTransformers nor SGL is available. Cannot run benchmarks.")
        return 1
    # Initialize checkpoint manager (unless checkpointing is disabled)
    checkpoint_mgr = CheckpointManager(args.checkpoint_dir) if not args.no_checkpoint else None
    # Load checkpoint if resuming
    checkpoint_state = None
    completed_configs = set()
    all_results = []
    start_time = datetime.now().isoformat()
    if args.resume and checkpoint_mgr:
        checkpoint_state = checkpoint_mgr.load_checkpoint()
        if checkpoint_state:
            # Verify the checkpoint was produced by a matching test matrix;
            # otherwise let the user decide whether to trust it.
            if (checkpoint_state.test_config.qlen_values != test_config.qlen_values or
                checkpoint_state.test_config.thread_count_values != test_config.thread_count_values):
                logger.warning("Checkpoint configuration doesn't match current configuration")
                response = input("Continue with checkpoint anyway? (y/n): ")
                if response.lower() != 'y':
                    logger.info("Starting fresh run")
                    checkpoint_state = None
        if checkpoint_state:
            all_results = checkpoint_state.results
            completed_configs = set(checkpoint_state.completed_configs)
            start_time = checkpoint_state.start_time
            logger.info(f"Resuming from checkpoint with {len(all_results)} results")
    # Create a fresh checkpoint state if none was loaded
    if not checkpoint_state and checkpoint_mgr:
        checkpoint_state = CheckpointState(
            test_config=test_config,
            completed_configs=[],
            results=[],
            start_time=start_time,
            last_update=start_time
        )
    config_count = 0
    total_configs_to_run = 0
    # Calculate total configs still pending (excludes checkpointed ones)
    for thread_count in test_config.thread_count_values:
        for qlen in test_config.qlen_values:
            if test_ktransformers:
                for quant_mode in test_precisions:
                    if (thread_count, qlen, "KTransformers", quant_mode) not in completed_configs:
                        total_configs_to_run += 1
            if test_sgl:
                if "int8" in test_precisions and SGL_AVAILABLE:
                    if (thread_count, qlen, "SGL", "int8") not in completed_configs:
                        total_configs_to_run += 1
                if "int4" in test_precisions and SGL_INT4_AVAILABLE:
                    if (thread_count, qlen, "SGL", "int4") not in completed_configs:
                        total_configs_to_run += 1
    logger.info(f"Total configurations to run: {total_configs_to_run}")
    # Test all combinations
    for thread_count in test_config.thread_count_values:
        thread_config = ThreadConfig.from_thread_count(thread_count, sys_config.numa_count, sys_config.cpu_cores)
        logger.info(f"\nThread Configuration: {thread_count} total ({thread_config.threads_per_numa} per NUMA)")
        for qlen in test_config.qlen_values:
            # Check for interrupt (SIGINT handled by the checkpoint manager)
            if checkpoint_mgr and checkpoint_mgr.interrupted:
                logger.warning("Interrupt detected, saving checkpoint and exiting...")
                if checkpoint_state:
                    checkpoint_state.results = all_results
                    checkpoint_state.completed_configs = list(completed_configs)
                    checkpoint_mgr.save_checkpoint(checkpoint_state)
                return 2
            logger.info(f"\n--- Configuration: threads={thread_count}, qlen={qlen} ---")
            # Test KTransformers
            if test_ktransformers:
                for quant_mode in test_precisions:
                    config_key = (thread_count, qlen, "KTransformers", quant_mode)
                    if config_key in completed_configs:
                        logger.info(f"Skipping already completed: KTransformers-{quant_mode}")
                        continue
                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")
                    result = bench_ktransformers_moe(test_config, quant_mode, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)
                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)
            # Test SGL int8
            if test_sgl and "int8" in test_precisions and SGL_AVAILABLE:
                config_key = (thread_count, qlen, "SGL", "int8")
                if config_key in completed_configs:
                    # Fix: the previous `continue` here advanced the qlen loop
                    # and silently skipped the SGL int4 test for this qlen on
                    # resumed runs, leaving config_count short of the total.
                    logger.info("Skipping already completed: SGL-int8")
                else:
                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")
                    logger.info(f"Testing SGL MoE (int8): qlen={qlen}, threads={thread_count}")
                    sgl_intermediate = test_config.intermediate_size // sys_config.numa_count
                    sgl_threads_per_numa = thread_config.sgl_thread_count
                    logger.info(f"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> "
                                f"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}")
                    result = run_sgl_with_numactl(test_config, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)
                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)
            # Test SGL int4
            if test_sgl and "int4" in test_precisions and SGL_INT4_AVAILABLE:
                config_key = (thread_count, qlen, "SGL", "int4")
                if config_key in completed_configs:
                    # Same if/else structure as int8 (no `continue`) so any
                    # benchmark added after this one still runs when resuming.
                    logger.info("Skipping already completed: SGL-int4")
                else:
                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")
                    logger.info(f"Testing SGL MoE (int4): qlen={qlen}, threads={thread_count}")
                    sgl_intermediate = test_config.intermediate_size // sys_config.numa_count
                    sgl_threads_per_numa = thread_config.sgl_thread_count
                    logger.info(f"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> "
                                f"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}")
                    result = run_sgl_int4_with_numactl(test_config, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)
                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)
    # Final summary
    if all_results:
        print_summary_table(all_results)
        # Save results
        output_file = save_results(all_results, test_config, args.output)
        print(f"\nTotal benchmarks completed: {len(all_results)}")
        print(f"Results saved to: {output_file}")
        # Clear checkpoint on successful completion; keep it if incomplete
        if checkpoint_mgr and config_count == total_configs_to_run:
            checkpoint_mgr.clear_checkpoint()
            logger.info("All tests completed successfully, checkpoint cleared")
        elif checkpoint_mgr and config_count < total_configs_to_run:
            logger.warning(f"Only {config_count}/{total_configs_to_run} configurations completed")
            logger.info("Checkpoint preserved for resuming")
        # Print best performers per configuration (fastest per-iter time wins)
        print("\nBest performers by configuration:")
        from itertools import groupby
        # Sort so groupby sees contiguous (qlen, threads) groups, with the
        # fastest result first in each group.
        sorted_results = sorted(all_results, key=lambda r: (r.qlen, r.thread_count, r.time_per_iter_us))
        for key, group in groupby(sorted_results, key=lambda r: (r.qlen, r.thread_count)):
            qlen, threads = key
            best = next(group)
            print(f"  QLen={qlen}, Threads={threads}: {best.implementation}-{best.quant_mode} "
                  f"({best.time_per_iter_us:.2f}μs, {best.tflops:.2f} TFLOPS)")
    else:
        logger.error("No successful benchmarks completed.")
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell; raising SystemExit is
    # equivalent to sys.exit() for a script entry point.
    raise SystemExit(main())