#!/usr/bin/env python
# coding=utf-8
"""
MoE Performance Comparison Script

Compares performance between KTransformers AMX MoE and SGL CPU MoE implementations.
"""

import os
import sys
import time
import json
import platform
import subprocess
import argparse
import logging
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
from pathlib import Path

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Environment configuration
@dataclass
class EnvironmentConfig:
    malloc_conf: str = "oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
    jemalloc_path: str = "/home/xwy/Projects/jemalloc/lib/libjemalloc.so"

    def apply(self):
        os.environ['MALLOC_CONF'] = self.malloc_conf
        if os.path.exists(self.jemalloc_path):
            # LD_PRELOAD set here only takes effect for child processes (the
            # numactl subprocess runs below); the current interpreter has
            # already been loaded and is unaffected.
            os.environ['LD_PRELOAD'] = self.jemalloc_path
        else:
            logger.warning(f"jemalloc not found at {self.jemalloc_path}")


# Apply environment configuration
env_config = EnvironmentConfig()
env_config.apply()

# Add paths for both implementations
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))
sys.path.insert(0, '/home/xwy/Projects/sgl-cpu-tests')

import torch

# Try importing both implementations
try:
    import kt_kernel_ext
    KTRANSFORMERS_AVAILABLE = True
    logger.info("KTransformers kt_kernel_ext loaded successfully")
except ImportError as e:
    KTRANSFORMERS_AVAILABLE = False
    logger.warning(f"KTransformers kt_kernel_ext not available: {e}")

try:
    from sgl_kernel.common_ops import fused_experts_cpu
    from sgl_kernel.common_ops import convert_weight_packed
    SGL_AVAILABLE = True
    logger.info("SGL kernel loaded successfully")
except ImportError as e:
    SGL_AVAILABLE = False
    logger.warning(f"SGL kernel not available: {e}")

# Try importing int4 support
try:
    # For SGL INT4, check whether the sglang-jianan directory exists
    sglang_path = "/home/xwy/Projects/sglang-jianan"
    if os.path.exists(sglang_path) and os.path.exists(os.path.join(sglang_path, "benchmark/kernels/int4_moe/benchmark_int4_moe.py")):
        SGL_INT4_AVAILABLE = True
        logger.info("SGL INT4 support available (via sglang-jianan)")
    else:
        SGL_INT4_AVAILABLE = False
        logger.warning("SGL INT4 support not available: sglang-jianan directory not found")
except Exception as e:
    SGL_INT4_AVAILABLE = False
    logger.warning(f"SGL INT4 support not available: {e}")


def get_cpu_count() -> int:
    """Get logical CPU core count (including hyperthreading)"""
    cpu_count = None

    # Method 1: os.cpu_count()
    try:
        cpu_count = os.cpu_count()
        if cpu_count and cpu_count > 0:
            logger.info(f"Detected {cpu_count} logical CPU cores via os.cpu_count()")
            return cpu_count
    except Exception as e:
        logger.debug(f"os.cpu_count() failed: {e}")

    # Method 2: Check /proc/cpuinfo
    try:
        with open('/proc/cpuinfo', 'r') as f:
            cpu_count = sum(1 for line in f if line.strip().startswith('processor'))
        if cpu_count > 0:
            logger.info(f"Detected {cpu_count} logical CPU cores via /proc/cpuinfo")
            return cpu_count
    except Exception as e:
        logger.debug(f"Failed to read /proc/cpuinfo: {e}")

    # Default fallback
    logger.warning("Could not detect CPU count, defaulting to 32")
    return 32


def get_physical_cpu_count() -> int:
    """Get physical CPU core count (excluding hyperthreading)"""

    # Method 1: Try the lscpu command
    try:
        result = subprocess.run(['lscpu'], capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            cores_per_socket = None
            sockets = None
            for line in result.stdout.split('\n'):
                if 'Core(s) per socket:' in line:
                    cores_per_socket = int(line.split(':')[1].strip())
                elif 'Socket(s):' in line:
                    sockets = int(line.split(':')[1].strip())

            if cores_per_socket and sockets:
                physical_cores = cores_per_socket * sockets
                logger.info(f"Detected {physical_cores} physical CPU cores via lscpu")
                return physical_cores
    except Exception as e:
        logger.debug(f"lscpu failed: {e}")

    # Method 2: Check /sys/devices/system/cpu/
    try:
        cpu_path = '/sys/devices/system/cpu/'
        if os.path.exists(cpu_path):
            # Count unique physical core IDs, keyed on (package, core) so
            # cores that share an id across sockets are not collapsed together
            physical_cores = set()
            for cpu_dir in os.listdir(cpu_path):
                if cpu_dir.startswith('cpu') and cpu_dir[3:].isdigit():
                    core_id_path = os.path.join(cpu_path, cpu_dir, 'topology/core_id')
                    pkg_id_path = os.path.join(cpu_path, cpu_dir, 'topology/physical_package_id')
                    if os.path.exists(core_id_path) and os.path.exists(pkg_id_path):
                        with open(core_id_path, 'r') as f:
                            core_id = f.read().strip()
                        with open(pkg_id_path, 'r') as f:
                            pkg_id = f.read().strip()
                        physical_cores.add((pkg_id, core_id))

            if physical_cores:
                count = len(physical_cores)
                logger.info(f"Detected {count} physical CPU cores via sysfs")
                return count
    except Exception as e:
        logger.debug(f"Failed to check sysfs: {e}")

    # Method 3: Parse /proc/cpuinfo for unique (physical id, core id) pairs
    try:
        with open('/proc/cpuinfo', 'r') as f:
            content = f.read()
        cores = set()
        current_physical_id = None

        for line in content.split('\n'):
            if line.startswith('physical id'):
                current_physical_id = line.split(':')[1].strip()
            elif line.startswith('core id') and current_physical_id is not None:
                core_id = line.split(':')[1].strip()
                cores.add(f"{current_physical_id}:{core_id}")

        if cores:
            count = len(cores)
            logger.info(f"Detected {count} physical CPU cores via /proc/cpuinfo")
            return count
    except Exception as e:
        logger.debug(f"Failed to parse /proc/cpuinfo: {e}")

    # Fallback: assume hyperthreading is enabled and divide logical cores by 2
    try:
        logical_count = get_cpu_count()
        if logical_count > 0:
            # Assume hyperthreading, so physical cores = logical cores / 2
            physical_count = logical_count // 2
            logger.warning(f"Could not detect physical cores directly. Assuming hyperthreading enabled: "
                           f"{logical_count} logical cores -> {physical_count} physical cores")
            return physical_count
    except Exception:
        pass

    # Default fallback
    logger.warning("Could not detect physical CPU count, defaulting to 32")
    return 32


# Test configuration dataclass.
# The defaults below correspond to a DeepSeek-V3-style MoE layer (7168 hidden,
# 2048 intermediate per expert, 256 routed experts, 8 active per token).
@dataclass
class TestConfig:
    expert_num: int = 256
    hidden_size: int = 7168
    intermediate_size: int = 2048
    max_len: int = 25600
    num_experts_per_tok: int = 8
    layer_num: int = 5
    warm_up_iter: int = 100
    test_iter: int = 10000
    qlen_values: Optional[List[int]] = None
    thread_count_values: Optional[List[int]] = None

    def __post_init__(self):
        if self.qlen_values is None:
            self.qlen_values = [1, 4, 16, 64, 256, 1024, 2048]
        if self.thread_count_values is None:
            # Default to physical CPU core count
            physical_cores = get_physical_cpu_count()
            self.thread_count_values = [physical_cores]

    @property
    def total_configurations(self) -> int:
        return len(self.qlen_values) * len(self.thread_count_values)


def get_numa_count() -> int:
    """Get NUMA node count from the system, with multiple fallback methods"""
    # Method 1: Try numactl
    try:
        result = subprocess.run(['numactl', '--hardware'],
                                capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            for line in result.stdout.split('\n'):
                if 'available:' in line and 'nodes' in line:
                    parts = line.split()
                    if len(parts) >= 2 and parts[1].isdigit():
                        numa_count = int(parts[1])
                        logger.info(f"Detected {numa_count} NUMA nodes via numactl")
                        return numa_count
    except (subprocess.TimeoutExpired, FileNotFoundError) as e:
        logger.debug(f"numactl not available: {e}")

    # Method 2: Check /sys/devices/system/node/
    try:
        node_path = '/sys/devices/system/node/'
        if os.path.exists(node_path):
            numa_dirs = [d for d in os.listdir(node_path) if d.startswith('node')]
            if numa_dirs:
                numa_count = len(numa_dirs)
                logger.info(f"Detected {numa_count} NUMA nodes via sysfs")
                return numa_count
    except Exception as e:
        logger.debug(f"Failed to check sysfs: {e}")

    # Default fallback
    logger.warning("Could not detect NUMA configuration, defaulting to 2 nodes")
    return 2


# System configuration
@dataclass
class SystemConfig:
    numa_count: int = 0
    cpu_cores: int = 0

    def __post_init__(self):
        if self.numa_count == 0:
            self.numa_count = get_numa_count()
        if self.cpu_cores == 0:
            self.cpu_cores = get_cpu_count()


sys_config = SystemConfig()


@dataclass
class ThreadConfig:
    thread_count: int
    threads_per_numa: int
    sgl_thread_count: int
    numa_prefix: str

    @classmethod
    def from_thread_count(cls, thread_count: int, numa_count: int, cpu_cores: int) -> 'ThreadConfig':
        """Create thread configuration for a specific thread count"""
        # Validate thread count
        if thread_count > cpu_cores:
            logger.warning(f"thread_count ({thread_count}) > cpu_cores ({cpu_cores}), using all cores")
            thread_count = cpu_cores

        threads_per_numa = thread_count // numa_count
        sgl_thread_count = threads_per_numa
        last_core = sgl_thread_count - 1
        numa_prefix = f"numactl --physcpubind=0-{last_core} --membind=0"

        return cls(
            thread_count=thread_count,
            threads_per_numa=threads_per_numa,
            sgl_thread_count=sgl_thread_count,
            numa_prefix=numa_prefix
        )
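
# NOTE on the SGL runs: the subprocess benchmarks are pinned to NUMA node 0
# (numactl --physcpubind=0-<n> --membind=0) and given intermediate_size /
# numa_count, i.e. they measure a single tensor-parallel shard of the MoE
# layer rather than the whole machine. This assumes CPUs 0..n-1 all live on
# node 0, which holds for typical linear CPU numbering but not on every
# platform.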


def get_system_info() -> Dict[str, Any]:
    """Get comprehensive system information"""
    info = {}

    # Basic system info
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node
    info["release"] = uname.release
    info["machine"] = uname.machine
    info["cpu_count"] = sys_config.cpu_cores
    info["numa_nodes"] = sys_config.numa_count

    # CPU model information
    if os.path.exists('/proc/cpuinfo'):
        try:
            with open('/proc/cpuinfo', 'r') as f:
                cpu_info = f.read()
            for line in cpu_info.split('\n'):
                if "model name" in line:
                    info["cpu_model"] = line.split(":", 1)[1].strip()
                    break
            # Check for CPU features
            if "flags" in cpu_info:
                flags_line = next(line for line in cpu_info.split('\n') if "flags" in line)
                flags = flags_line.split(":", 1)[1].strip().split()
                info["cpu_features"] = {
                    "avx2": "avx2" in flags,
                    "avx512": any(f.startswith("avx512") for f in flags),
                    "amx": any("amx" in f for f in flags)
                }
        except Exception as e:
            logger.debug(f"Failed to read CPU info: {e}")

    # Memory information
    try:
        import psutil
        mem = psutil.virtual_memory()
        info["total_memory_gb"] = round(mem.total / (1024**3), 2)
        info["available_memory_gb"] = round(mem.available / (1024**3), 2)
    except ImportError:
        pass

    # Python and PyTorch versions
    info["python_version"] = sys.version.split()[0]
    info["torch_version"] = torch.__version__
    info["cuda_available"] = torch.cuda.is_available()
    if torch.cuda.is_available():
        info["cuda_version"] = torch.version.cuda

    return info


@dataclass
class BenchmarkResult:
    implementation: str
    quant_mode: str
    qlen: int
    thread_count: int
    total_time: float
    time_per_iter_us: float
    bandwidth_gbs: float
    tflops: float
    iterations: int

    def to_dict(self) -> Dict:
        return asdict(self)


@dataclass
class CheckpointState:
    """State information for checkpoint/resume functionality"""
    test_config: TestConfig
    completed_configs: List[Tuple[int, int, str, str]]  # (thread_count, qlen, implementation, quant_mode)
    results: List[BenchmarkResult]
    start_time: str
    last_update: str

    def to_dict(self) -> Dict:
        return {
            'test_config': asdict(self.test_config),
            'completed_configs': self.completed_configs,
            'results': [r.to_dict() for r in self.results],
            'start_time': self.start_time,
            'last_update': self.last_update
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'CheckpointState':
        test_config = TestConfig(**data['test_config'])
        results = [BenchmarkResult(**r) for r in data['results']]
        # JSON round-trips tuples as lists; convert back to (hashable) tuples
        # so the resume logic can build a set from completed_configs
        completed_configs = [tuple(c) for c in data['completed_configs']]
        return cls(
            test_config=test_config,
            completed_configs=completed_configs,
            results=results,
            start_time=data['start_time'],
            last_update=data['last_update']
        )
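
# The checkpoint file produced by to_dict() is plain JSON shaped roughly like:
#   {
#     "test_config": {...TestConfig fields...},
#     "completed_configs": [[64, 1, "KTransformers", "int8"], ...],
#     "results": [{...BenchmarkResult fields...}, ...],
#     "start_time": "...", "last_update": "..."
#   }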


class CheckpointManager:
    """Manages checkpoint saving and loading"""

    def __init__(self, checkpoint_dir: str = None):
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else Path.cwd() / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_file = self.checkpoint_dir / "moe_benchmark_checkpoint.json"
        self.interrupted = False

        # Set up signal handlers for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def _signal_handler(self, signum, frame):
        logger.warning(f"Received signal {signum}, will save checkpoint after current test...")
        self.interrupted = True

    def save_checkpoint(self, state: CheckpointState):
        """Save checkpoint to file"""
        state.last_update = datetime.now().isoformat()

        # Save to a temporary file first for atomicity
        temp_file = self.checkpoint_file.with_suffix('.tmp')
        try:
            with open(temp_file, 'w') as f:
                json.dump(state.to_dict(), f, indent=2)

            # Atomically rename
            temp_file.replace(self.checkpoint_file)
            logger.info(f"Checkpoint saved: {len(state.results)} results, "
                        f"{len(state.completed_configs)} configs completed")
        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")
            if temp_file.exists():
                temp_file.unlink()

    def load_checkpoint(self) -> Optional[CheckpointState]:
        """Load checkpoint from file if it exists"""
        if not self.checkpoint_file.exists():
            return None

        try:
            with open(self.checkpoint_file, 'r') as f:
                data = json.load(f)
            state = CheckpointState.from_dict(data)
            logger.info(f"Loaded checkpoint: {len(state.results)} results, "
                        f"{len(state.completed_configs)} configs completed")
            logger.info(f"Checkpoint started at {state.start_time}, last updated {state.last_update}")
            return state
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return None

    def clear_checkpoint(self):
        """Remove checkpoint file"""
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()
            logger.info("Checkpoint cleared")


def bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int,
                            thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Benchmark the KTransformers AMX MoE implementation"""
    if not KTRANSFORMERS_AVAILABLE:
        logger.error("KTransformers not available, skipping benchmark")
        return None

    # Adjust iterations based on qlen to maintain a reasonable runtime
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)

    logger.info(f"Testing KTransformers MoE: quant={quant_mode}, qlen={qlen}, threads={thread_config.thread_count}, "
                f"iterations={adjusted_iterations} (warmup={adjusted_warmup})")

    # Set thread count for this test
    os.environ['OMP_NUM_THREADS'] = str(thread_config.thread_count)

    try:
        with torch.inference_mode():
            # Set up worker config with consistent threads per NUMA node
            worker_config = kt_kernel_ext.WorkerPoolConfig()
            worker_config.subpool_count = sys_config.numa_count
            worker_config.subpool_numa_map = list(range(sys_config.numa_count))
            worker_config.subpool_thread_count = [thread_config.threads_per_numa] * sys_config.numa_count
            CPUInfer = kt_kernel_ext.CPUInfer(worker_config)

            # Create MoE layers
            moes = []
            gate_projs = []
            up_projs = []
            down_projs = []

            logger.debug(f"Creating {test_config.layer_num} MoE layers...")
            for i in range(test_config.layer_num):
                gate_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size),
                                        dtype=torch.float32).contiguous()
                up_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size),
                                      dtype=torch.float32).contiguous()
                down_proj = torch.randn((test_config.expert_num, test_config.hidden_size, test_config.intermediate_size),
                                        dtype=torch.float32).contiguous()

                config = kt_kernel_ext.moe.MOEConfig(
                    test_config.expert_num, test_config.num_experts_per_tok,
                    test_config.hidden_size, test_config.intermediate_size)
                config.max_len = test_config.max_len
                config.gate_proj = gate_proj.data_ptr()
                config.up_proj = up_proj.data_ptr()
                config.down_proj = down_proj.data_ptr()
                config.pool = CPUInfer.backend_

                if quant_mode == "bf16":
                    moe = kt_kernel_ext.moe.AMXBF16_MOE(config)
                elif quant_mode == "int8":
                    moe = kt_kernel_ext.moe.AMXInt8_MOE(config)
                elif quant_mode == "int4":
                    moe = kt_kernel_ext.moe.AMXInt4_MOE(config)
                else:
                    raise ValueError(f"Unsupported quantization mode: {quant_mode}")

                CPUInfer.submit(moe.load_weights_task())
                CPUInfer.sync()
                # Keep the source tensors alive so the raw data_ptr()s handed
                # to the kernel stay valid
                gate_projs.append(gate_proj)
                up_projs.append(up_proj)
                down_projs.append(down_proj)
                moes.append(moe)

            # Prepare test data
            logger.debug("Preparing test data...")
            gen_iter = 1000
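            # expert_ids trick: rand(...).argsort(dim=-1)[:, :topk] draws
            # num_experts_per_tok *distinct* expert ids uniformly at random
            # for each token, mimicking what a real router would produce.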
            expert_ids = torch.rand(gen_iter * qlen, test_config.expert_num).argsort(dim=-1)[
                :, :test_config.num_experts_per_tok
            ].reshape(gen_iter, qlen * test_config.num_experts_per_tok).contiguous()

            weights = torch.rand((gen_iter, qlen, test_config.num_experts_per_tok),
                                 dtype=torch.float32).contiguous()
            input_tensor = torch.randn((test_config.layer_num, qlen, test_config.hidden_size),
                                       dtype=torch.bfloat16).contiguous()
            output_tensor = torch.empty((test_config.layer_num, qlen, test_config.hidden_size),
                                        dtype=torch.bfloat16).contiguous()
            bsz_tensor = torch.tensor([qlen], dtype=torch.int32)

            # Warmup
            logger.debug(f"Running {adjusted_warmup} warmup iterations...")
            for i in range(adjusted_warmup):
                layer_idx = i % test_config.layer_num
                gen_idx = i % gen_iter
                CPUInfer.submit(
                    moes[layer_idx].forward_task(
                        bsz_tensor.data_ptr(),
                        test_config.num_experts_per_tok,
                        expert_ids[gen_idx].data_ptr(),
                        weights[gen_idx].data_ptr(),
                        input_tensor[layer_idx].data_ptr(),
                        output_tensor[layer_idx].data_ptr(),
                        False,
                    )
                )
                CPUInfer.sync()

            # Benchmark
            logger.debug(f"Running {adjusted_iterations} benchmark iterations...")
            start = time.perf_counter()
            for i in range(adjusted_iterations):
                layer_idx = i % test_config.layer_num
                gen_idx = i % gen_iter
                CPUInfer.submit(
                    moes[layer_idx].forward_task(
                        bsz_tensor.data_ptr(),
                        test_config.num_experts_per_tok,
                        expert_ids[gen_idx].data_ptr(),
                        weights[gen_idx].data_ptr(),
                        input_tensor[layer_idx].data_ptr(),
                        output_tensor[layer_idx].data_ptr(),
                        False,
                    )
                )
                CPUInfer.sync()
            end = time.perf_counter()

            # Calculate metrics
            total_time = end - start
            time_per_iter_us = total_time / adjusted_iterations * 1e6

            # Bytes per weight element, by quantization mode
            bytes_per_elem = {
                "bf16": 2.0,
                "int8": 1.0,
                "int4": 0.5
            }.get(quant_mode, 2.0)
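            # Weight-traffic model: each token activates topk of E experts, so
            # a given expert is hit with probability topk/E = 8/256 = 1/32 and
            # the expected number of distinct experts touched per iteration is
            # E * (1 - (31/32)**qlen). Since topk * (1/8) = 1 for topk = 8, the
            # expression below reduces to hidden * intermediate * 3 matrices *
            # E_active * bytes_per_elem. Note the hardcoded 1/8 and 31/32
            # assume the default topk=8, expert_num=256 configuration.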
            # Memory bandwidth calculation (GB/s)
            memory_per_iter = (
                test_config.hidden_size * test_config.intermediate_size * 3 *
                test_config.num_experts_per_tok *
                (1/8 * test_config.expert_num * (1 - (31/32)**qlen)) * bytes_per_elem
            )
            bandwidth_gbs = memory_per_iter * adjusted_iterations / total_time / 1e9

            # FLOPS calculation (TFLOPS): 3 GEMMs per expert, 2 FLOPs per MAC
            flops_per_iter = (
                test_config.hidden_size * test_config.intermediate_size * qlen * 3 *
                test_config.num_experts_per_tok * 2
            )
            tflops = flops_per_iter * adjusted_iterations / total_time / 1e12

            logger.info(f"Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")

            return BenchmarkResult(
                implementation="KTransformers",
                quant_mode=quant_mode,
                qlen=qlen,
                thread_count=thread_config.thread_count,
                total_time=total_time,
                time_per_iter_us=time_per_iter_us,
                bandwidth_gbs=bandwidth_gbs,
                tflops=tflops,
                iterations=adjusted_iterations
            )

    except Exception as e:
        logger.error(f"KTransformers benchmark failed: {e}", exc_info=True)
        return None


def run_sgl_int4_with_numactl(test_config: TestConfig, qlen: int,
                              thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Run the SGL INT4 benchmark under numactl in a subprocess"""
    if not SGL_INT4_AVAILABLE:
        logger.error("SGL INT4 not available, skipping benchmark")
        return None

    # Calculate SGL intermediate size (divided across NUMA nodes)
    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count

    # Adjust iterations based on qlen to maintain a reasonable runtime
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)

    logger.info(f"Testing SGL INT4: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), "
                f"threads per NUMA: {thread_config.sgl_thread_count}")

    script_content = f'''
import sys
sys.path.insert(0, '/home/xwy/Projects/sglang-jianan')
sys.path.insert(0, '/home/xwy/Projects/sglang-jianan/test')

import os
import torch
import numpy as np
import sgl_kernel
from srt.cpu.utils import autoawq_to_int4pack
import time

torch.manual_seed(1111)
M, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}
layer_num = {test_config.layer_num}
group_size = 128
kernel = torch.ops.sgl_kernel

# Prepare int4 data
dtype = torch.bfloat16
device = "cpu"

# Generate input activations for all layers
input_tensors = [torch.rand(M, K, dtype=dtype, device=device) / np.sqrt(K) for _ in range(layer_num)]

# Generate weights and pack them for each layer
all_awq_w13_weight_pack = []
all_awq_w13_zero_pack = []
all_awq_w13_scales_pack = []
all_awq_w2_weight_pack = []
all_awq_w2_zero_pack = []
all_awq_w2_scales_pack = []

# Generate expert routing scores (different for each iteration)
gen_iter = 1000
all_topk_weights = []
all_topk_ids = []

for gen_idx in range(gen_iter):
    score = torch.rand(M, E, dtype=dtype, device=device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    all_topk_weights.append(topk_weight)
    all_topk_ids.append(topk_ids.to(torch.int32))

print("Creating " + str(layer_num) + " MoE layers...")
for layer_idx in range(layer_num):
    # Generate INT4-quantized weights for each expert
    # w13: fused gate and up projection (K -> 2*N)
    awq_w13_weight = torch.randint(-127, 128, (E, K, 2 * N // 8), device=device).to(torch.int)
    awq_w13_zero = torch.randint(0, 10, (E, K // group_size, 2 * N // 8), device=device).to(torch.int)
    awq_w13_scales = torch.rand(E, K // group_size, 2 * N, dtype=dtype, device=device)

    # w2: down projection (N -> K)
    awq_w2_weight = torch.randint(-127, 128, (E, N, K // 8), device=device).to(torch.int)
    awq_w2_zero = torch.randint(0, 10, (E, N // group_size, K // 8), device=device).to(torch.int)
    awq_w2_scales = torch.rand(E, N // group_size, K, dtype=dtype, device=device)

    # Pack weights for the optimized kernel
    awq_w13_weight_pack = []
    awq_w13_zero_pack = []
    awq_w13_scales_pack = []
    awq_w2_weight_pack = []
    awq_w2_zero_pack = []
    awq_w2_scales_pack = []

    for i in range(E):
        packed_weight_13, packed_zero_13, packed_scales_13 = autoawq_to_int4pack(
            awq_w13_weight[i], awq_w13_zero[i], awq_w13_scales[i], False
        )
        awq_w13_weight_pack.append(packed_weight_13)
        awq_w13_zero_pack.append(packed_zero_13)
        awq_w13_scales_pack.append(packed_scales_13)

        packed_weight_2, packed_zero_2, packed_scales_2 = autoawq_to_int4pack(
            awq_w2_weight[i], awq_w2_zero[i], awq_w2_scales[i], False
        )
        awq_w2_weight_pack.append(packed_weight_2)
        awq_w2_zero_pack.append(packed_zero_2)
        awq_w2_scales_pack.append(packed_scales_2)

    all_awq_w13_weight_pack.append(torch.stack(awq_w13_weight_pack).detach())
    all_awq_w13_zero_pack.append(torch.stack(awq_w13_zero_pack).detach())
    all_awq_w13_scales_pack.append(torch.stack(awq_w13_scales_pack).detach())
    all_awq_w2_weight_pack.append(torch.stack(awq_w2_weight_pack).detach())
    all_awq_w2_zero_pack.append(torch.stack(awq_w2_zero_pack).detach())
    all_awq_w2_scales_pack.append(torch.stack(awq_w2_scales_pack).detach())

# Warmup
print("Running " + str({adjusted_warmup}) + " warmup iterations...")
for i in range({adjusted_warmup}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    out = kernel.fused_experts_cpu(
        input_tensors[layer_idx],
        all_awq_w13_weight_pack[layer_idx],
        all_awq_w2_weight_pack[layer_idx],
        all_topk_weights[gen_idx],
        all_topk_ids[gen_idx],
        False,  # inplace
        False,  # use_int8_w8a8
        False,  # use_fp8_w8a16
        True,   # use_int4_w4a16
        all_awq_w13_scales_pack[layer_idx],
        all_awq_w2_scales_pack[layer_idx],
        all_awq_w13_zero_pack[layer_idx],
        all_awq_w2_zero_pack[layer_idx],
        None,  # block_size
        None,  # a1_scale
        None,  # a2_scale
        True,  # is_vnni
    )

# Benchmark
print("Running " + str({adjusted_iterations}) + " benchmark iterations...")
start = time.perf_counter()
for i in range({adjusted_iterations}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    out = kernel.fused_experts_cpu(
        input_tensors[layer_idx],
        all_awq_w13_weight_pack[layer_idx],
        all_awq_w2_weight_pack[layer_idx],
        all_topk_weights[gen_idx],
        all_topk_ids[gen_idx],
        False,
        False,
        False,
        True,
        all_awq_w13_scales_pack[layer_idx],
        all_awq_w2_scales_pack[layer_idx],
        all_awq_w13_zero_pack[layer_idx],
        all_awq_w2_zero_pack[layer_idx],
        None,
        None,
        None,
        True,
    )
end = time.perf_counter()

total_time = end - start
time_per_iter_us = total_time / {adjusted_iterations} * 1e6

# Calculate performance metrics for int4
bytes_per_elem = 0.5  # int4
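# Same expected-distinct-experts traffic model as in bench_ktransformers_moe;
# the hardcoded 1/8 and 31/32 likewise assume topk=8 and expert_num=256.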
memory_per_iter = (
    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} *
    (1/8 * {test_config.expert_num} * (1 - (31/32)**{qlen})) * bytes_per_elem
)
bandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9

# FLOPS calculation
flops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2
tflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12

print(f"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}")
'''

    # Create a temporary script in the sglang-jianan directory
    sglang_path = "/home/xwy/Projects/sglang-jianan"
    temp_script = f"{sglang_path}/temp_sgl_int4_bench_{os.getpid()}_{qlen}.py"

    try:
        with open(temp_script, 'w') as f:
            f.write(script_content)

        # Set up environment
        env = os.environ.copy()
        env['MALLOC_CONF'] = env_config.malloc_conf
        if os.path.exists(env_config.jemalloc_path):
            env['LD_PRELOAD'] = env_config.jemalloc_path
        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)

        # Run with numactl from the sglang-jianan directory
        cmd = f"cd {sglang_path} && {thread_config.numa_prefix} python3 {temp_script}"
        logger.debug(f"Running SGL INT4 command: {cmd}")

        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, env=env, timeout=300)

        if result.returncode == 0:
            # Parse result
            for line in result.stdout.split('\n'):
                if line.startswith('SGL_RESULT:'):
                    parts = line.replace('SGL_RESULT:', '').split(',')
                    if len(parts) >= 4:
                        try:
                            total_time = float(parts[0])
                            time_per_iter_us = float(parts[1])
                            bandwidth_gbs = float(parts[2])
                            tflops = float(parts[3])

                            logger.info(f"SGL INT4 Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")

                            return BenchmarkResult(
                                implementation="SGL",
                                quant_mode="int4",
                                qlen=qlen,
                                thread_count=thread_config.thread_count,
                                total_time=total_time,
                                time_per_iter_us=time_per_iter_us,
                                bandwidth_gbs=bandwidth_gbs,
                                tflops=tflops,
                                iterations=adjusted_iterations
                            )
                        except ValueError as e:
                            logger.error(f"Failed to parse SGL INT4 results: {e}")
        else:
            logger.error(f"SGL INT4 subprocess failed with code {result.returncode}")
            logger.error(f"STDOUT: {result.stdout}")
            logger.error(f"STDERR: {result.stderr}")

    except subprocess.TimeoutExpired:
        logger.error("SGL INT4 benchmark timed out")
    except Exception as e:
        logger.error(f"SGL INT4 benchmark error: {e}", exc_info=True)
    finally:
        # Clean up the temporary script
        if os.path.exists(temp_script):
            try:
                os.remove(temp_script)
            except OSError:
                pass

    return None


def run_sgl_with_numactl(test_config: TestConfig, qlen: int,
                         thread_config: ThreadConfig) -> Optional[BenchmarkResult]:
    """Run the SGL INT8 benchmark under numactl in a subprocess"""
    if not SGL_AVAILABLE:
        logger.error("SGL not available, skipping benchmark")
        return None

    # Calculate SGL intermediate size (divided across NUMA nodes)
    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count

    # Adjust iterations based on qlen to maintain a reasonable runtime
    adjusted_iterations = test_config.test_iter
    adjusted_warmup = test_config.warm_up_iter
    if qlen >= 1024:
        adjusted_iterations = max(10, test_config.test_iter // 100)
        adjusted_warmup = max(5, test_config.warm_up_iter // 20)
    elif qlen >= 256:
        adjusted_iterations = max(50, test_config.test_iter // 20)
        adjusted_warmup = max(10, test_config.warm_up_iter // 10)
    elif qlen >= 64:
        adjusted_iterations = max(100, test_config.test_iter // 10)
        adjusted_warmup = max(20, test_config.warm_up_iter // 5)
    elif qlen >= 16:
        adjusted_iterations = max(200, test_config.test_iter // 5)
        adjusted_warmup = max(40, test_config.warm_up_iter // 2)

    logger.info(f"Testing SGL INT8: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), "
                f"threads per NUMA: {thread_config.sgl_thread_count}")

    script_content = f'''
import sys
sys.path.insert(0, "/home/xwy/Projects/sgl-cpu-tests")

import os
import torch
from sgl_kernel.common_ops import fused_experts_cpu as fused_experts
from sgl_kernel.common_ops import convert_weight_packed
import time

torch.manual_seed(1111)
M, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}
layer_num = {test_config.layer_num}

# Generate expert routing scores (different for each iteration)
gen_iter = 1000
all_topk_weights = []
all_topk_ids = []

for gen_idx in range(gen_iter):
    score = torch.randn(M, E).to(dtype=torch.bfloat16)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    all_topk_weights.append(topk_weight)
    all_topk_ids.append(topk_ids.to(torch.int32))

prepack = True
inplace = True
use_int4_w4a16 = False

# Create multiple layers
print("Creating " + str(layer_num) + " MoE layers...")
inputs = []
packed_w1s_int8 = []
packed_w2s_int8 = []
w1_s_list = []
w2_s_list = []

for layer_idx in range(layer_num):
    input_tensor = torch.randn(M, K).to(dtype=torch.bfloat16)

    # int8 weights
    w1_int8 = torch.randn(E, 2 * N, K).to(dtype=torch.int8)
    w2_int8 = torch.randn(E, K, N).to(dtype=torch.int8)
    packed_w1_int8 = convert_weight_packed(w1_int8)
    packed_w2_int8 = convert_weight_packed(w2_int8)
    w1_s = torch.rand(E, 2 * N)
    w2_s = torch.rand(E, K)

    inputs.append(input_tensor)
    packed_w1s_int8.append(packed_w1_int8)
    packed_w2s_int8.append(packed_w2_int8)
    w1_s_list.append(w1_s)
    w2_s_list.append(w2_s)

# Warmup
print("Running " + str({adjusted_warmup}) + " warmup iterations...")
for i in range({adjusted_warmup}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx],
                  all_topk_weights[gen_idx], all_topk_ids[gen_idx],
                  inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx],
                  None, None, None, None, None, prepack)

# Benchmark
print("Running " + str({adjusted_iterations}) + " benchmark iterations...")
start = time.perf_counter()
for i in range({adjusted_iterations}):
    layer_idx = i % layer_num
    gen_idx = i % gen_iter
    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx],
                  all_topk_weights[gen_idx], all_topk_ids[gen_idx],
                  inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx],
                  None, None, None, None, None, prepack)
end = time.perf_counter()

total_time = end - start
time_per_iter_us = total_time / {adjusted_iterations} * 1e6

# Calculate performance metrics for int8
bytes_per_elem = 1.0  # int8
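# Traffic model as in bench_ktransformers_moe (assumes topk=8, expert_num=256).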
memory_per_iter = (
    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} *
    (1/8 * {test_config.expert_num} * (1 - (31/32)**{qlen})) * bytes_per_elem
)
bandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9

# FLOPS calculation
flops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2
tflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12

print(f"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}")
'''

    # Create a temporary script
    temp_script = f"/tmp/sgl_bench_{os.getpid()}_{qlen}.py"

    try:
        with open(temp_script, 'w') as f:
            f.write(script_content)

        # Set up environment
        env = os.environ.copy()
        env['MALLOC_CONF'] = env_config.malloc_conf
        if os.path.exists(env_config.jemalloc_path):
            env['LD_PRELOAD'] = env_config.jemalloc_path
        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)

        # Run with numactl
        cmd = f"{thread_config.numa_prefix} python3 {temp_script}"
        logger.debug(f"Running SGL command: {cmd}")

        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, env=env, timeout=300)

        if result.returncode == 0:
            # Parse result
            for line in result.stdout.split('\n'):
                if line.startswith('SGL_RESULT:'):
                    parts = line.replace('SGL_RESULT:', '').split(',')
                    if len(parts) >= 4:
                        try:
                            total_time = float(parts[0])
                            time_per_iter_us = float(parts[1])
                            bandwidth_gbs = float(parts[2])
                            tflops = float(parts[3])

                            logger.info(f"SGL Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, "
                                        f"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}")

                            return BenchmarkResult(
                                implementation="SGL",
                                quant_mode="int8",
                                qlen=qlen,
                                thread_count=thread_config.thread_count,
                                total_time=total_time,
                                time_per_iter_us=time_per_iter_us,
                                bandwidth_gbs=bandwidth_gbs,
                                tflops=tflops,
                                iterations=adjusted_iterations
                            )
                        except ValueError as e:
                            logger.error(f"Failed to parse SGL results: {e}")
        else:
            logger.error(f"SGL subprocess failed with code {result.returncode}: {result.stderr}")

    except subprocess.TimeoutExpired:
        logger.error("SGL benchmark timed out")
    except Exception as e:
        logger.error(f"SGL benchmark error: {e}", exc_info=True)
    finally:
        # Clean up the temporary script
        if os.path.exists(temp_script):
            try:
                os.remove(temp_script)
            except OSError:
                pass

    return None


def save_results(results: List[BenchmarkResult], test_config: TestConfig, filename: str = None) -> str:
    """Save benchmark results to a JSON file"""
    if not filename:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"moe_comparison_{timestamp}.json"

    output_data = {
        "timestamp": datetime.now().isoformat(),
        "test_configuration": asdict(test_config),
        "system_info": get_system_info(),
        "results": [r.to_dict() for r in results],
        "summary": {
            "total_benchmarks": len(results),
            "implementations_tested": list(set(r.implementation for r in results)),
            "quantization_modes": list(set(r.quant_mode for r in results)),
            "qlen_values_tested": sorted(set(r.qlen for r in results)),
            "thread_counts_tested": sorted(set(r.thread_count for r in results))
        }
    }

    with open(filename, 'w') as f:
        json.dump(output_data, f, indent=2)

    logger.info(f"Results saved to: {filename}")
    return filename


def print_summary_table(results: List[BenchmarkResult]):
    """Print a formatted summary table of results"""
    if not results:
        return

    print("\n" + "=" * 100)
    print("PERFORMANCE SUMMARY")
    print("=" * 100)
    print(f"{'Implementation':<15} {'Quant':<6} {'Threads':<8} {'QLen':<8} {'Time(μs)':<12} "
          f"{'BW(GB/s)':<12} {'TFLOPS':<10} {'Speedup':<10}")
    print("-" * 100)

    # Group by configuration for easier comparison; the first row printed for
    # each (thread_count, qlen) pair serves as the 1.00x baseline, and larger
    # speedup values mean faster than that baseline
    baseline_times = {}

    for result in sorted(results, key=lambda r: (r.thread_count, r.qlen, r.implementation, r.quant_mode)):
        key = (result.thread_count, result.qlen)

        if key not in baseline_times:
            baseline_times[key] = result.time_per_iter_us
            speedup = "1.00x"
        else:
            speedup = f"{baseline_times[key] / result.time_per_iter_us:.2f}x"

        print(f"{result.implementation:<15} {result.quant_mode:<6} {result.thread_count:<8} "
              f"{result.qlen:<8} {result.time_per_iter_us:<12.2f} {result.bandwidth_gbs:<12.2f} "
              f"{result.tflops:<10.2f} {speedup:<10}")


def main():
    parser = argparse.ArgumentParser(description="Compare MoE performance between KTransformers and SGL")
    parser.add_argument("--qlen", type=int, nargs="+", help="Sequence lengths to test")
    parser.add_argument("--threads", type=int, nargs="+", help="Thread counts to test")
    parser.add_argument("--iterations", type=int, help="Number of test iterations")
    parser.add_argument("--warmup", type=int, help="Number of warmup iterations")
    parser.add_argument("--output", type=str, help="Output filename for results")
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint if available")
    parser.add_argument("--checkpoint-dir", type=str, help="Directory for checkpoint files")
    parser.add_argument("--no-checkpoint", action="store_true", help="Disable checkpoint saving")
    parser.add_argument("--framework", choices=["all", "ktransformers", "sgl"], default="all",
                        help="Framework to test (default: all)")
    parser.add_argument("--precision", choices=["all", "int8", "int4"], default="all",
                        help="Precision to test (default: all)")

    args = parser.parse_args()

    # Configure logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Create test configuration
    test_config = TestConfig()
    if args.qlen:
        test_config.qlen_values = args.qlen
    if args.threads:
        test_config.thread_count_values = args.threads
    if args.iterations is not None:
        test_config.test_iter = args.iterations
    if args.warmup is not None:
        test_config.warm_up_iter = args.warmup

    # Determine which frameworks to test
    test_ktransformers = args.framework in ["all", "ktransformers"] and KTRANSFORMERS_AVAILABLE
    test_sgl = args.framework in ["all", "sgl"] and (SGL_AVAILABLE or SGL_INT4_AVAILABLE)

    # Determine which precisions to test
    if args.precision == "all":
        test_precisions = ["int8", "int4"]
    else:
        test_precisions = [args.precision]

    # Print configuration
    logger.info("MoE Performance Comparison")
    logger.info("=" * 60)
    logger.info("System configuration:")
    logger.info(f"  CPU cores: {sys_config.cpu_cores}")
    logger.info(f"  NUMA nodes: {sys_config.numa_count}")
    logger.info("Test parameters:")
    logger.info(f"  Expert count: {test_config.expert_num}")
    logger.info(f"  Hidden size: {test_config.hidden_size}")
    logger.info(f"  Intermediate size: {test_config.intermediate_size}")
    logger.info(f"  Experts per token: {test_config.num_experts_per_tok}")
    logger.info(f"  Test iterations: {test_config.test_iter}")
    logger.info(f"  Warmup iterations: {test_config.warm_up_iter}")
    logger.info("Testing configurations:")
    logger.info(f"  QLEN values: {test_config.qlen_values}")
    logger.info(f"  Thread counts: {test_config.thread_count_values}")
    logger.info(f"  Frameworks: {args.framework}")
    logger.info(f"  Precisions: {args.precision}")
    logger.info(f"  Total configs: {test_config.total_configurations}")
    print()

    # Check availability
    if not KTRANSFORMERS_AVAILABLE and not SGL_AVAILABLE:
        logger.error("Neither KTransformers nor SGL is available. Cannot run benchmarks.")
        return 1

    # Initialize checkpoint manager
    checkpoint_mgr = CheckpointManager(args.checkpoint_dir) if not args.no_checkpoint else None

    # Load checkpoint if resuming
    checkpoint_state = None
    completed_configs = set()
    all_results = []
    start_time = datetime.now().isoformat()

    if args.resume and checkpoint_mgr:
        checkpoint_state = checkpoint_mgr.load_checkpoint()
        if checkpoint_state:
            # Verify that the configuration matches
            if (checkpoint_state.test_config.qlen_values != test_config.qlen_values or
                    checkpoint_state.test_config.thread_count_values != test_config.thread_count_values):
                logger.warning("Checkpoint configuration doesn't match current configuration")
                response = input("Continue with checkpoint anyway? (y/n): ")
                if response.lower() != 'y':
                    logger.info("Starting fresh run")
                    checkpoint_state = None

        if checkpoint_state:
            all_results = checkpoint_state.results
            completed_configs = set(checkpoint_state.completed_configs)
            start_time = checkpoint_state.start_time
            logger.info(f"Resuming from checkpoint with {len(all_results)} results")

    # Create checkpoint state if not loaded
    if not checkpoint_state and checkpoint_mgr:
        checkpoint_state = CheckpointState(
            test_config=test_config,
            completed_configs=[],
            results=[],
            start_time=start_time,
            last_update=start_time
        )

    config_count = 0
    total_configs_to_run = 0

    # Calculate the number of configs still to run
    for thread_count in test_config.thread_count_values:
        for qlen in test_config.qlen_values:
            if test_ktransformers:
                for quant_mode in test_precisions:
                    if (thread_count, qlen, "KTransformers", quant_mode) not in completed_configs:
                        total_configs_to_run += 1
            if test_sgl:
                if "int8" in test_precisions and SGL_AVAILABLE:
                    if (thread_count, qlen, "SGL", "int8") not in completed_configs:
                        total_configs_to_run += 1
                if "int4" in test_precisions and SGL_INT4_AVAILABLE:
                    if (thread_count, qlen, "SGL", "int4") not in completed_configs:
                        total_configs_to_run += 1

    logger.info(f"Total configurations to run: {total_configs_to_run}")

    # Test all combinations
    for thread_count in test_config.thread_count_values:
        thread_config = ThreadConfig.from_thread_count(thread_count, sys_config.numa_count, sys_config.cpu_cores)
        logger.info(f"\nThread Configuration: {thread_count} total ({thread_config.threads_per_numa} per NUMA)")

        for qlen in test_config.qlen_values:
            # Check for interrupt
            if checkpoint_mgr and checkpoint_mgr.interrupted:
                logger.warning("Interrupt detected, saving checkpoint and exiting...")
                if checkpoint_state:
                    checkpoint_state.results = all_results
                    checkpoint_state.completed_configs = list(completed_configs)
                    checkpoint_mgr.save_checkpoint(checkpoint_state)
                return 2

            logger.info(f"\n--- Configuration: threads={thread_count}, qlen={qlen} ---")

            # Test KTransformers
            if test_ktransformers:
                for quant_mode in test_precisions:
                    config_key = (thread_count, qlen, "KTransformers", quant_mode)
                    if config_key in completed_configs:
                        logger.info(f"Skipping already completed: KTransformers-{quant_mode}")
                        continue

                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")

                    result = bench_ktransformers_moe(test_config, quant_mode, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)

                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)

            # Test SGL int8. Note: this block must not `continue` on an
            # already-completed config, or the SGL int4 test below would be
            # skipped for this qlen as well.
            if test_sgl and "int8" in test_precisions and SGL_AVAILABLE:
                config_key = (thread_count, qlen, "SGL", "int8")
                if config_key in completed_configs:
                    logger.info("Skipping already completed: SGL-int8")
                else:
                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")

                    logger.info(f"Testing SGL MoE (int8): qlen={qlen}, threads={thread_count}")
                    sgl_intermediate = test_config.intermediate_size // sys_config.numa_count
                    sgl_threads_per_numa = thread_config.sgl_thread_count
                    logger.info(f"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> "
                                f"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}")

                    result = run_sgl_with_numactl(test_config, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)

                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)

            # Test SGL int4
            if test_sgl and "int4" in test_precisions and SGL_INT4_AVAILABLE:
                config_key = (thread_count, qlen, "SGL", "int4")
                if config_key in completed_configs:
                    logger.info("Skipping already completed: SGL-int4")
                else:
                    config_count += 1
                    logger.info(f"Progress: {config_count}/{total_configs_to_run}")

                    logger.info(f"Testing SGL MoE (int4): qlen={qlen}, threads={thread_count}")
                    sgl_intermediate = test_config.intermediate_size // sys_config.numa_count
                    sgl_threads_per_numa = thread_config.sgl_thread_count
                    logger.info(f"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> "
                                f"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}")

                    result = run_sgl_int4_with_numactl(test_config, qlen, thread_config)
                    if result:
                        all_results.append(result)
                        completed_configs.add(config_key)

                        # Save checkpoint after each successful test
                        if checkpoint_mgr and checkpoint_state:
                            checkpoint_state.results = all_results
                            checkpoint_state.completed_configs = list(completed_configs)
                            checkpoint_mgr.save_checkpoint(checkpoint_state)

    # Final summary
    if all_results:
        print_summary_table(all_results)

        # Save results
        output_file = save_results(all_results, test_config, args.output)

        print(f"\nTotal benchmarks completed: {len(all_results)}")
        print(f"Results saved to: {output_file}")

        # Clear checkpoint on successful completion
        if checkpoint_mgr and config_count == total_configs_to_run:
            checkpoint_mgr.clear_checkpoint()
            logger.info("All tests completed successfully, checkpoint cleared")
        elif checkpoint_mgr and config_count < total_configs_to_run:
            logger.warning(f"Only {config_count}/{total_configs_to_run} configurations completed")
            logger.info("Checkpoint preserved for resuming")

        # Print the best performer per configuration
        print("\nBest performers by configuration:")
        from itertools import groupby

        sorted_results = sorted(all_results, key=lambda r: (r.qlen, r.thread_count, r.time_per_iter_us))
        for key, group in groupby(sorted_results, key=lambda r: (r.qlen, r.thread_count)):
            qlen, threads = key
            best = next(group)
            print(f"  QLen={qlen}, Threads={threads}: {best.implementation}-{best.quant_mode} "
                  f"({best.time_per_iter_us:.2f}μs, {best.tflops:.2f} TFLOPS)")
    else:
        logger.error("No successful benchmarks completed.")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())