mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-15 02:47:22 +00:00
@@ -38,6 +38,7 @@ High-performance kernel operations for KTransformers, featuring CPU-optimized Mo
- ✅ **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights)
- ✅ **AMD CPUs with BLIS**: Supported (for int8 prefill & decode)
- ✅ **Kimi-K2 Native INT4 (RAWINT4)**: Supported on AVX512 CPUs (CPU-GPU shared INT4 weights) - [Guide](../doc/en/Kimi-K2-Thinking-Native.md)
- ✅ **FP8 weights (e.g., MiniMax-M2.1)**: Supported on AVX512 CPUs (CPU-GPU shared FP8 weights) - [Guide](../doc/en/MiniMax-M2.1-Tutorial.md)
## Features
@@ -167,10 +168,57 @@ Simply run the install script - it will auto-detect your CPU and optimize for be
## Verification

After installation, verify that the CLI is working:

```bash
kt version
```

Expected output:

```
KTransformers CLI v0.x.x

Python: 3.11.x
Platform: Linux 5.15.0-xxx-generic
CUDA: 12.x
kt-kernel: 0.x.x (amx)
sglang: 0.x.x
```

You can also verify the Python module directly:

```bash
python -c "from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')"
```

## KT CLI Overview

The `kt` command-line tool provides a unified interface for running and managing KTransformers models:

| Command | Description |
|---------|-------------|
| `kt run <model>` | Start a model inference server with auto-optimized parameters |
| `kt chat` | Interactive chat with a running model server |
| `kt model` | Manage models and storage paths |
| `kt doctor` | Diagnose environment issues and check system compatibility |
| `kt config` | Manage CLI configuration |
| `kt version` | Show version information |

**Quick Start Example:**

```bash
# Start a model server (auto-detects hardware and applies optimal settings)
kt run m2

# In another terminal, chat with the model
kt chat

# Check system compatibility
kt doctor
```

Run `kt --help` for more options, or `kt <command> --help` for command-specific help.

## Integration with SGLang

KT-Kernel can be used standalone via the [Direct Python API](#direct-python-api-usage) or integrated with SGLang for production deployment. This section describes SGLang integration for CPU-GPU heterogeneous inference, where "hot" experts run on GPU and "cold" experts run on CPU for optimal resource utilization.
@@ -361,13 +409,13 @@ python -m sglang.launch_server \

| Parameter | Description | Example Value |
|-----------|-------------|---------------|
| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8`, or `LLAMAFILE` |
| `--kt-weight-path` | Path to quantized CPU weights | `/path/to/cpu-weights` |
| `--kt-cpuinfer` | Number of CPU inference threads | `64` (adjust based on CPU cores) |
| `--kt-threadpool-count` | Number of thread pools for parallel execution | `2` (typically 1-4) |
| `--kt-num-gpu-experts` | Number of experts to keep on GPU | `32` (remaining experts go to CPU) |
| `--kt-max-deferred-experts-per-token` | Number of experts per token to defer for pipelined execution | `2` (`0` to disable, 1-4 recommended) |
| `--kt-gpu-prefill-token-threshold` | Token-count threshold for the prefill strategy (FP8 and RAWINT4 only) | ~`1024` |

**Parameter Guidelines:**

@@ -375,6 +423,7 @@ python -m sglang.launch_server \
- `AMXINT4`: Best performance on AMX CPUs with INT4-quantized weights (may cause a large accuracy drop for some models, e.g., Qwen3-30B-A3B)
- `AMXINT8`: Higher accuracy with INT8-quantized weights on AMX CPUs
- `RAWINT4`: Native INT4 weights shared by CPU and GPU (AMX backend only; currently supports the Kimi-K2-Thinking model). See the [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details.
- `FP8`: FP8 weights shared by CPU and GPU
- `LLAMAFILE`: GGUF-based backend

- **`kt-cpuinfer`**: Set to the number of **physical CPU cores** (not hyperthreads).
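To pick a value for `--kt-cpuinfer`, you need the physical core count rather than the logical (hyperthread) count reported by `os.cpu_count()`. A minimal sketch for Linux, assuming `/proc/cpuinfo` is available; the helper name is illustrative and not part of kt-kernel:

```python
import os

def physical_core_count():
    """Count physical cores on Linux via unique (physical id, core id) pairs.

    Falls back to os.cpu_count() (which counts hyperthreads) when
    /proc/cpuinfo is missing or lacks topology fields.
    """
    try:
        cores = set()
        phys = core = None
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("physical id"):
                    phys = line.split(":", 1)[1].strip()
                elif line.startswith("core id"):
                    core = line.split(":", 1)[1].strip()
                elif not line.strip():
                    if phys is not None and core is not None:
                        cores.add((phys, core))
                    phys = core = None
        if phys is not None and core is not None:
            cores.add((phys, core))
        if cores:
            return len(cores)
    except OSError:
        pass
    return os.cpu_count() or 1

print(physical_core_count())
```

On a hyperthreaded machine this typically prints half of `os.cpu_count()`.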
@@ -400,10 +449,10 @@ python -m sglang.launch_server \
- `1-4`: Deferred execution (recommended range; good latency/quality balance, requires tuning)
- `5-7`: Highest latency reduction, but may introduce noticeable accuracy loss; use with care

- **`kt-gpu-prefill-token-threshold`** (FP8 and RAWINT4 only): Controls the prefill strategy for native FP8 and INT4 inference:
  - **≤ threshold**: Uses hybrid CPU+GPU prefill. No extra VRAM needed, but performance degrades slowly as token count increases.
  - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires one MoE layer of extra VRAM (e.g., ~9GB+ for Kimi-K2-Thinking and ~3.6GB for MiniMax-M2.1).
  - Only applicable when `--kt-method RAWINT4` or `--kt-method FP8` is used.

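The threshold logic can be summarized as a simple decision on the prompt's token count. This is an illustrative sketch, not the actual implementation; the function name and return labels are made up for clarity:

```python
def choose_prefill_strategy(num_tokens: int, threshold: int = 1024) -> str:
    """Sketch of the kt-gpu-prefill-token-threshold decision.

    <= threshold: hybrid CPU+GPU prefill (no extra VRAM; slows gradually
                  as the prompt grows)
    >  threshold: layerwise GPU prefill (scales better for long prompts;
                  needs roughly one extra MoE layer of VRAM)
    """
    return "hybrid_cpu_gpu" if num_tokens <= threshold else "layerwise_gpu"

print(choose_prefill_strategy(400))   # short prompt -> hybrid CPU+GPU
print(choose_prefill_strategy(8192))  # long prompt  -> layerwise GPU
```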
## Direct Python API Usage
kt-kernel/bench/bench_fp8_moe.py (new file, 286 lines)
@@ -0,0 +1,286 @@
"""
Performance benchmark for the FP8 MoE kernel (AVX implementation).

This benchmark measures the performance of the FP8 MoE operator with:
- FP8 (E4M3) weights with 128x128 block-wise scaling
- BF16 activations
- AVX-512 DPBF16 compute path
"""

import os
import sys
import time
import json
import subprocess
import platform

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

import torch
import kt_kernel_ext
from tqdm import tqdm

# Test parameters
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
num_experts_per_tok = 8
fp8_group_size = 128
max_len = 25600

layer_num = 2
qlen = 1024
warm_up_iter = 10
test_iter = 30
CPUINFER_PARAM = 80

CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)

# Result file path
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")


def get_git_commit():
    """Get current git commit info."""
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        result["dirty"] = bool(dirty_output)
        if dirty_output:
            result["dirty_files"] = dirty_output.splitlines()
    except Exception as e:
        result["commit"] = None
        result["error"] = str(e)
    return result


def get_system_info():
    """Get system information."""
    info = {}
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node

    cpu_model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_model = line.split(":", 1)[1].strip()
                        break
        except Exception:
            pass
    info["cpu_model"] = cpu_model
    info["cpu_core_count"] = os.cpu_count()
    return info


def record_results(result, filename=json_path):
    """Append result to JSON file."""
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")


def generate_fp8_weights_direct(shape: tuple, group_size: int = 128):
    """
    Directly generate random FP8 weights and e8m0-format scale_inv.

    Args:
        shape: (expert_num, n, k) - weight tensor shape
        group_size: block size for scaling (128x128 blocks)

    Returns:
        fp8_weights: uint8 tensor with random FP8 E4M3 values
        scale_inv: fp32 tensor with e8m0-format values (powers of 2)
    """
    e, n, k = shape
    n_blocks = n // group_size
    k_blocks = k // group_size

    # Directly generate random FP8 weights as uint8.
    # FP8 E4M3 format: 1 sign + 4 exponent + 3 mantissa bits.
    # Valid range for normal numbers: exp 1-14 (0 is subnormal, 15 is special).
    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()

    # Generate e8m0-format scale_inv (powers of 2).
    # e8m0: 8-bit exponent only, no mantissa, bias = 127.
    # Use random exponents in a reasonable range (here -8 to 8).
    exponents = torch.randint(-8, 9, (e, n_blocks, k_blocks), dtype=torch.int32, device="cuda").to("cpu").contiguous()
    scale_inv = (2.0 ** exponents.float()).to(torch.float32).contiguous()

    return fp8_weights, scale_inv


def bench_fp8_moe():
    """Benchmark FP8 MoE performance."""
    with torch.inference_mode():
        print("=" * 70)
        print("FP8 MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate FP8 weights directly (no quantization from fp32)
        print("\nGenerating FP8 weights directly...")
        torch.manual_seed(42)
        gate_fp8, gate_scales = generate_fp8_weights_direct(
            (expert_num, intermediate_size, hidden_size), fp8_group_size
        )
        up_fp8, up_scales = generate_fp8_weights_direct((expert_num, intermediate_size, hidden_size), fp8_group_size)
        down_fp8, down_scales = generate_fp8_weights_direct(
            (expert_num, hidden_size, intermediate_size), fp8_group_size
        )

        physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers
        print("Building FP8 MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = fp8_group_size
            config.quant_config.zero_point = False

            config.gate_proj = gate_fp8.data_ptr()
            config.up_proj = up_fp8.data_ptr()
            config.down_proj = down_fp8.data_ptr()
            config.gate_scale = gate_scales.data_ptr()
            config.up_scale = up_scales.data_ptr()
            config.down_scale = down_scales.data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data
        print("Generating input data...")
        gen_iter = 1000
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

        # Benchmark
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6

        # FLOPS calculation:
        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12

        # Bandwidth calculation (FP8 = 1 byte per element)
        bytes_per_elem = 1.0
        # Weight memory: gate + up + down per expert
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )  # unit: GB/s

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: FP8 (E4M3) with {fp8_group_size}x{fp8_group_size} block scaling")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "fp8_e4m3",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "fp8_group_size": fp8_group_size,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)

        return tflops, bandwidth


if __name__ == "__main__":
    bench_fp8_moe()

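Two pieces of the benchmark above can be checked in isolation: the interpretation of raw random bytes as FP8 E4M3 values, and the FLOPs/bandwidth accounting, whose `E/topk * (1 - (1 - topk/E)^qlen)` factor estimates how many experts a batch of `qlen` tokens touches at least once. This is a standalone sketch, not part of kt-kernel; the decoder follows the OCP E4M3 convention (bias 7, max magnitude 448, NaN only at S.1111.111):

```python
import math

# --- FP8 E4M3 sanity check (illustrative decoder, not from kt-kernel) ---
def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte: 1 sign, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan          # the single NaN encoding per sign
    if exp == 0:                  # subnormal: no implicit leading 1
        return sign * (man / 8.0) * 2.0 ** (-6)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

print(decode_e4m3(0x38))  # 1.0
print(decode_e4m3(0x7E))  # 448.0, the largest finite E4M3 magnitude

# --- FLOPs / bandwidth model with the benchmark's parameters ---
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
num_experts_per_tok = 8
qlen = 1024

# gate + up + down GEMMs per expert-token, 2 ops per multiply-accumulate
flops_per_expert = 3 * 2 * intermediate_size * hidden_size
total_flops = qlen * num_experts_per_tok * flops_per_expert

# Expected number of distinct experts touched at least once by qlen tokens
# (the factor inside the benchmark's bandwidth formula)
expected_unique = expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen)

# FP8 = 1 byte per weight element, three matrices per touched expert
weight_bytes = expected_unique * 3 * hidden_size * intermediate_size

print(f"GFLOPs per iteration:    {total_flops / 1e9:.1f}")
print(f"expected unique experts: {expected_unique:.2f} / {expert_num}")
print(f"weight traffic per iter: {weight_bytes / 1e9:.2f} GB")
```

With qlen=1024 and top-8 routing over 256 experts, the untouched-expert probability is negligible, so the model predicts essentially all experts (and their ~11 GB of FP8 weights) are streamed every iteration.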
kt-kernel/bench/bench_fp8_write_buffer.py (new file, 294 lines)
@@ -0,0 +1,294 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark write_weight_scale_to_buffer for AMX_FP8_MOE_TP (FP8 weights + float32 scales).

Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.
"""
import json
import os
import platform
import subprocess
import sys
import time

from tqdm import tqdm

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))

from kt_kernel import kt_kernel_ext
from kt_kernel_ext.moe import AMXFP8_MOE
import torch

# Benchmark parameters
expert_num = 256
num_experts_per_tok = 8
gpu_tp_count = 2

warm_up_iter = 3
test_iter = 7

gpu_experts_num = expert_num

hidden_size = 7168
intermediate_size = 2048
group_size = 128  # FP8 uses 128x128 block-wise scales
max_len = 1

physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(80)


def get_git_commit():
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        result["dirty"] = bool(dirty_output)
        if dirty_output:
            result["dirty_files"] = dirty_output.splitlines()
    except Exception as e:
        result["error"] = str(e)
    return result


def get_system_info():
    info = {}
    info["system_name"] = platform.uname().system
    info["node_name"] = platform.uname().node
    info["cpu_core_count"] = os.cpu_count()
    if os.path.exists("/proc/cpuinfo"):
        with open("/proc/cpuinfo", "r") as f:
            for line in f:
                if "model name" in line:
                    info["cpu_model"] = line.split(":", 1)[1].strip()
                    break
    if os.path.exists("/proc/meminfo"):
        with open("/proc/meminfo", "r") as f:
            for line in f:
                if "MemTotal" in line:
                    mem_kb = float(line.split(":", 1)[1].split()[0])
                    info["memory_size_GB"] = round(mem_kb / (1024 * 1024), 2)
                    break
    return info


script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")


def record_results(result, filename=json_path):
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")


def allocate_weights():
    per_mat_weight_bytes = hidden_size * intermediate_size
    n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
    n_blocks_k = (hidden_size + group_size - 1) // group_size
    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k
    per_mat_scale_elems_down = n_blocks_k * n_blocks_n_gate_up

    gate_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    up_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    down_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    gate_scale = (
        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )
    up_scale = (
        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )
    down_scale = (
        torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )

    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    )


def build_moe(layer_idx=0):
    """Build a single MOE instance with the given layer_idx."""
    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    ) = allocate_weights()

    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = max_len
    config.layer_idx = layer_idx
    config.quant_config.bits = 8
    config.quant_config.group_size = group_size
    config.quant_config.zero_point = False
    config.pool = CPUInfer.backend_
    config.gate_proj = gate_q.data_ptr()
    config.up_proj = up_q.data_ptr()
    config.down_proj = down_q.data_ptr()
    config.gate_scale = gate_scale.data_ptr()
    config.up_scale = up_scale.data_ptr()
    config.down_scale = down_scale.data_ptr()

    moe = AMXFP8_MOE(config)
    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    CPUInfer.sync()

    keep_tensors = {
        "gate_q": gate_q,
        "up_q": up_q,
        "down_q": down_q,
        "gate_scale": gate_scale,
        "up_scale": up_scale,
        "down_scale": down_scale,
    }

    buffer_shapes = {
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
        "per_mat_scale_elems_down": per_mat_scale_elems_down,
    }

    return moe, buffer_shapes, keep_tensors


def allocate_buffers(buffer_shapes):
    """Allocate shared output buffers for a single expert."""
    per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"]

    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count
    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down // gpu_tp_count

    # Each buffer stores data for a single expert
    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]

    buffer_ptrs = {
        "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
        "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
        "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
        "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
    }

    keep_tensors = {
        "w13_weight_bufs": w13_weight_bufs,
        "w13_scale_bufs": w13_scale_bufs,
        "w2_weight_bufs": w2_weight_bufs,
        "w2_scale_bufs": w2_scale_bufs,
    }

    return buffer_ptrs, keep_tensors


def bench_write_buffer():
    # Build two MOE instances with different layer_idx
    moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)
    moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)
    moes = [moe_0, moe_1]

    # Allocate shared buffers
    buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)

    total_weights = hidden_size * intermediate_size * expert_num * 3
    total_scale_bytes = (
        (buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"]) * expert_num * 4
    )
    bytes_per_call = total_weights + total_scale_bytes

    # Warm-up: alternate between two MOEs
    for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)
                )
            CPUInfer.sync()

    total_time = 0
    for iter_idx in tqdm(range(test_iter), desc="Testing"):
        start = time.perf_counter()
        # Alternate between two MOEs
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)
                )
            CPUInfer.sync()
        end = time.perf_counter()
        iter_time = end - start
        total_time += iter_time
        print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms")
        time.sleep(0.3)

    # bytes_per_call is for one MOE; we have 2 MOEs
    bytes_per_iter = bytes_per_call * 2
    time_per_iter_ms = total_time / test_iter * 1000
    bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9

    print(f"\n{'='*60}")
    print("FP8 write_weight_scale_to_buffer benchmark (2 MOEs alternating)")
    print(f"{'='*60}")
    print(f"Time per iteration: {time_per_iter_ms:.2f} ms")
    print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
    print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2")
    print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us")

    result = {
        "op": "write_weight_scale_to_buffer_fp8",
        "time_per_iteration_ms": time_per_iter_ms,
        "bandwidth_GBs": bandwidth_gbs,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "test_parameters": {
            "expert_num": expert_num,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "group_size": group_size,
            "gpu_tp_count": gpu_tp_count,
            "bytes_per_iter": bytes_per_iter,
            "num_moes": 2,
        },
    }
    result.update(get_git_commit())
    result.update(get_system_info())
    record_results(result)


if __name__ == "__main__":
    bench_write_buffer()

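The per-TP buffer sizing in `allocate_buffers` splits each expert's weights and scales evenly across GPU TP ranks. A quick arithmetic check with this benchmark's shapes (FP8 weights are 1 byte per element; scales are one float32 per 128x128 block):

```python
# Shapes from bench_fp8_write_buffer.py
hidden_size = 7168
intermediate_size = 2048
group_size = 128
gpu_tp_count = 2

per_mat_weight_bytes = hidden_size * intermediate_size            # FP8: 1 byte/elem
n_blocks_n = (intermediate_size + group_size - 1) // group_size   # 2048/128 = 16
n_blocks_k = (hidden_size + group_size - 1) // group_size         # 7168/128 = 56
per_mat_scale_elems = n_blocks_n * n_blocks_k                     # 16 * 56 = 896

weight_bytes_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_tp = per_mat_scale_elems // gpu_tp_count

print(weight_bytes_per_tp)  # FP8 weight bytes per matrix per TP rank
print(scale_elems_per_tp)   # float32 scale elements per matrix per TP rank
```

Each TP rank thus receives about 7 MB of packed FP8 weight and 448 float32 scales per gate/up/down matrix per expert, which is what the `w13_*`/`w2_*` buffers above are sized to hold.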
@@ -2,6 +2,8 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).
|
||||
|
||||
Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
@@ -17,7 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
|
||||
from kt_kernel import kt_kernel_ext
|
||||
import torch
|
||||
|
||||
# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
|
||||
# Benchmark parameters
|
||||
expert_num = 384
|
||||
num_experts_per_tok = expert_num
|
||||
gpu_tp_count = 4
|
||||
@@ -33,7 +35,7 @@ group_size = 32
|
||||
max_len = 1
|
||||
|
||||
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(96)
|
||||
CPUInfer = kt_kernel_ext.CPUInfer(80)
|
||||
|
||||
|
||||
def get_git_commit():
|
||||
@@ -140,7 +142,8 @@ def allocate_weights():
|
||||
)
|
||||
|
||||
|
||||
def build_moe():
|
||||
def build_moe(layer_idx=0):
|
||||
"""Build a single MOE instance with the given layer_idx."""
|
||||
(
|
||||
gate_q,
|
||||
up_q,
|
||||
@@ -154,6 +157,7 @@ def build_moe():
|
||||
|
||||
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
|
||||
config.max_len = max_len
|
||||
config.layer_idx = layer_idx
|
||||
config.quant_config.bits = 4
|
||||
config.quant_config.group_size = group_size
|
||||
config.quant_config.zero_point = False
|
||||
@@ -170,16 +174,36 @@ def build_moe():
|
||||
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
|
||||
CPUInfer.sync()
|
||||
|
||||
# Buffer sizing per TP
|
||||
keep_tensors = {
|
||||
"gate_q": gate_q,
|
||||
"up_q": up_q,
|
||||
"down_q": down_q,
|
||||
"gate_scale": gate_scale,
|
||||
"up_scale": up_scale,
|
||||
"down_scale": down_scale,
|
||||
}
|
||||
|
||||
buffer_shapes = {
|
||||
"per_mat_weight_bytes": per_mat_weight_bytes,
|
||||
"per_mat_scale_elems": per_mat_scale_elems,
|
||||
}
|
||||
|
||||
return moe, buffer_shapes, keep_tensors
|
||||
|
||||
|
||||
def allocate_buffers(buffer_shapes):
|
||||
"""Allocate shared output buffers for single expert."""
|
||||
per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
|
||||
per_mat_scale_elems = buffer_shapes["per_mat_scale_elems"]
|
||||
|
||||
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
|
||||
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
|
||||
total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
|
||||
total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
|
||||
|
||||
w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
|
||||
w13_scale_bufs = [torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
|
||||
w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
|
||||
w2_scale_bufs = [torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
|
||||
# Each buffer stores data for a single expert
|
||||
w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
|
||||
w13_scale_bufs = [torch.empty(2 * scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
|
||||
w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
|
||||
w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
|
||||
|
||||
buffer_ptrs = {
|
||||
"w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
|
||||
@@ -188,97 +212,89 @@ def build_moe():
        "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
    }

    buffer_shapes = {
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_scale_elems": per_mat_scale_elems,
        "weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp,
        "scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp,
        "total_weight_bytes_per_tp": total_weight_bytes_per_tp,
        "total_scale_elems_per_tp": total_scale_elems_per_tp,
    }

    keep_tensors = {
        "gate_q": gate_q,
        "up_q": up_q,
        "down_q": down_q,
        "gate_scale": gate_scale,
        "up_scale": up_scale,
        "down_scale": down_scale,
        "w13_weight_bufs": w13_weight_bufs,
        "w13_scale_bufs": w13_scale_bufs,
        "w2_weight_bufs": w2_weight_bufs,
        "w2_scale_bufs": w2_scale_bufs,
    }

    return buffer_ptrs, keep_tensors

def bench_write_buffer():
    # Build two MOE instances with different layer_idx
    moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)
    moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)
    moes = [moe_0, moe_1]

    # Allocate shared buffers
    buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)

    total_weights = hidden_size * intermediate_size * expert_num * 3
    # Throughput accounting: scale bytes (bf16) + weight bytes (int4 packed)
    bytes_per_call = total_weights // group_size * 2 + total_weights // 2

    # Warm-up: alternate between two MOEs
    for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(
                        gpu_tp_count=gpu_tp_count,
                        expert_id=expert_id,
                        **buffer_ptrs,
                    )
                )
        CPUInfer.sync()

    total_time = 0
    for iter_idx in tqdm(range(test_iter), desc="Testing"):
        start = time.perf_counter()
        # Alternate between two MOEs
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(
                        gpu_tp_count=gpu_tp_count,
                        expert_id=expert_id,
                        **buffer_ptrs,
                    )
                )
        CPUInfer.sync()
        end = time.perf_counter()
        iter_time = end - start
        total_time += iter_time
        print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms")
        time.sleep(0.3)

    # bytes_per_call is for one MOE, we have 2 MOEs
    bytes_per_iter = bytes_per_call * 2
    time_per_iter_ms = total_time / test_iter * 1000
    bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9

    print(f"\n{'='*60}")
    print("K2 write_weight_scale_to_buffer benchmark (2 MOEs alternating)")
    print(f"{'='*60}")
    print(f"Time per iteration: {time_per_iter_ms:.2f} ms")
    print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
    print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2")
    print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us")

    result = {
        "op": "write_weight_scale_to_buffer_k2",
        "total_time_seconds": total_time,
        "iterations": test_iter,
        "time_per_iteration_ms": time_per_iter_ms,
        "bandwidth_GBs": bandwidth_gbs,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "test_parameters": {
            "expert_num": expert_num,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "group_size": group_size,
            "max_len": max_len,
            "num_experts_per_tok": num_experts_per_tok,
            "gpu_tp_count": gpu_tp_count,
            "gpu_experts_num": gpu_experts_num,
            "warm_up_iter": warm_up_iter,
            "test_iter": test_iter,
            "bytes_per_iter": bytes_per_iter,
            "num_moes": 2,
        },
        "buffer_shapes": buffer_shapes,
        "keep_tensors_alive": list(keep_tensors_0.keys()),
    }
    result.update(get_git_commit())
    result.update(get_system_info())
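The bytes-per-call accounting used in the benchmark (one bf16 scale per weight group plus int4 weights packed two per byte) can be spelled out in isolation. This is a standalone sketch; the dimensions below are illustrative assumptions, not values read from this benchmark's config.

```python
# Standalone sketch of the bytes-per-call accounting.
# Dimensions are illustrative assumptions, not the benchmark's actual config.
hidden_size = 7168
intermediate_size = 2048
expert_num = 384
group_size = 32

# gate, up, and down matrices => 3 weight elements per (hidden, intermediate) pair
total_weights = hidden_size * intermediate_size * expert_num * 3

# One bf16 scale (2 bytes) per group of `group_size` weights,
# plus the int4 weights themselves packed two per byte.
scale_bytes = total_weights // group_size * 2
weight_bytes = total_weights // 2
bytes_per_call = scale_bytes + weight_bytes

print(bytes_per_call)
```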
457 kt-kernel/examples/test_fp8_moe.py Normal file
@@ -0,0 +1,457 @@
"""
Test script for GemmKernel224FP8 (FP8 MoE) kernel validation.

This script:
1. Generates random BF16 weights
2. Quantizes them to FP8 format with 128x128 block-wise scales
3. Runs the FP8 MoE kernel
4. Compares results with a PyTorch reference using dequantized BF16 weights

FP8 format notes:
- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]
- Scale: FP32, shape [expert_num, n // group_size, k // group_size], group_size=128
"""

import os
import sys

sys.path.insert(0, os.path.dirname(__file__) + "/../build")

import torch
from kt_kernel import kt_kernel_ext

torch.manual_seed(42)

# Model config
hidden_size = 3072
intermediate_size = 1536
max_len = 25600

expert_num = 16
num_experts_per_tok = 8

qlen = 100
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 1
fp8_group_size = 128  # FP8 uses 128x128 block quantization
debug_print_count = 16

physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

def act_fn(x):
    """SiLU activation function"""
    return x / (1.0 + torch.exp(-x))


def mlp_torch(input, gate_proj, up_proj, down_proj):
    """Reference MLP computation in PyTorch"""
    gate_buf = torch.mm(input, gate_proj.t())
    up_buf = torch.mm(input, up_proj.t())
    intermediate = act_fn(gate_buf) * up_buf
    ret = torch.mm(intermediate, down_proj.t())
    return ret


def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    """Reference MoE computation in PyTorch"""
    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
    cnts.scatter_(1, expert_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    idxs = expert_ids.view(-1).argsort()
    sorted_tokens = input[idxs // expert_ids.shape[1]]

    outputs = []
    start_idx = 0

    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

    new_x = torch.empty_like(outs)
    new_x[idxs] = outs
    t_output = (
        new_x.view(*expert_ids.shape, -1)
        .type(weights.dtype)
        .mul_(weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    return t_output
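The argsort-based routing in `moe_torch` groups token copies by expert, and `idxs // top_k` maps each copy back to its source token. A tiny standalone illustration (the shapes here are made up for the example):

```python
import torch

# Two tokens, top-2 routing over 3 experts (illustrative shapes only).
expert_ids = torch.tensor([[2, 0], [1, 2]])

# Flatten and sort token copies by expert id.
idxs = expert_ids.view(-1).argsort()

# idxs // top_k recovers which token row each sorted copy came from.
token_of_copy = idxs // expert_ids.shape[1]

# Copies now appear grouped by expert: expert 0's copies first, then 1, then 2.
print(expert_ids.view(-1)[idxs].tolist())
print(token_of_copy.tolist())
```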

# FP8 E4M3 constants
FP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3


def fp8_e4m3_to_float(fp8_val: int) -> float:
    """
    Convert FP8 E4M3 value to float.
    FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits
    """
    sign = (fp8_val >> 7) & 1
    exp = (fp8_val >> 3) & 0xF
    mant = fp8_val & 0x7

    if exp == 0:
        # Subnormal or zero
        if mant == 0:
            return -0.0 if sign else 0.0
        # Subnormal: value = (-1)^sign * 2^(-6) * (0.mant)
        return ((-1) ** sign) * (2**-6) * (mant / 8.0)
    elif exp == 15 and mant == 7:
        # NaN (FP8 E4M3 has no Inf; only exp=15 with mant=111 encodes NaN)
        return float("nan")
    else:
        # Normal: value = (-1)^sign * 2^(exp-7) * (1.mant)
        return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)


def float_to_fp8_e4m3(val: float) -> int:
    """
    Convert float to FP8 E4M3 value.
    """
    import math

    if val != val:  # NaN
        return 0x7F  # NaN representation

    sign = 1 if val < 0 else 0
    val = abs(val)

    if val == 0:
        return sign << 7

    # Clamp to max representable value
    val = min(val, FP8_E4M3_MAX)

    if val < 2**-6:  # Subnormal threshold
        # Subnormal: mantissa steps of 2^-9
        mant = int(round(val / (2**-9)))
        mant = min(mant, 7)
        return (sign << 7) | mant

    # Find exponent
    exp = int(math.floor(math.log2(val))) + 7
    exp = max(1, min(exp, 15))  # Clamp exponent to valid range

    # Calculate mantissa
    mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))

    # Handle mantissa overflow to the next exponent
    if mant > 7:
        mant = 0
        exp += 1
    if exp == 15:
        mant = min(mant, 6)  # exp=15, mant=7 is NaN, so cap at 448

    mant = max(0, min(mant, 7))
    return (sign << 7) | (exp << 3) | mant
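As a sanity check of the E4M3 bit layout described above, a few bit patterns can be decoded by hand. This is a standalone sketch with its own minimal decoder, independent of the helpers in this file:

```python
# Standalone sketch: decode a few FP8 E4M3 bit patterns by hand.
def decode_e4m3(b: int) -> float:
    sign = -1.0 if (b >> 7) & 1 else 1.0
    exp = (b >> 3) & 0xF
    mant = b & 0x7
    if exp == 0:
        return sign * 2.0**-6 * (mant / 8.0)  # subnormal (or zero)
    if exp == 15 and mant == 7:
        return float("nan")  # the only NaN encoding in E4M3
    return sign * 2.0 ** (exp - 7) * (1.0 + mant / 8.0)

print(decode_e4m3(0x38))  # exp=7, mant=0 -> 1.0
print(decode_e4m3(0x7E))  # exp=15, mant=6 -> 448.0 (E4M3 max)
print(decode_e4m3(0x01))  # smallest subnormal -> 2**-9
```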

def quantize_to_fp8_blockwise(weights: torch.Tensor, group_size: int = 128):
    """
    Quantize BF16/FP32 weights to FP8 with block-wise scaling.

    Args:
        weights: [expert_num, n, k] tensor in BF16/FP32
        group_size: Block size for quantization (default 128 for DeepSeek)

    Returns:
        fp8_weights: [expert_num, n, k] uint8 tensor
        scales: [expert_num, n // group_size, k // group_size] FP32 tensor (scale_inv)
    """
    weights_f32 = weights.to(torch.float32)
    e, n, k = weights_f32.shape

    assert n % group_size == 0, f"n ({n}) must be divisible by group_size ({group_size})"
    assert k % group_size == 0, f"k ({k}) must be divisible by group_size ({group_size})"

    n_blocks = n // group_size
    k_blocks = k // group_size

    # Reshape to [e, n_blocks, group_size, k_blocks, group_size]
    reshaped = weights_f32.view(e, n_blocks, group_size, k_blocks, group_size)
    # Move to [e, n_blocks, k_blocks, group_size, group_size] for block processing
    reshaped = reshaped.permute(0, 1, 3, 2, 4)

    # Calculate max abs per block
    max_abs = reshaped.abs().amax(dim=(-2, -1), keepdim=True)
    max_abs = torch.clamp(max_abs, min=1e-12)

    # Scale to FP8 range: scale = max_abs / FP8_MAX
    # We store scale_inv = scale (for dequantization: fp8 * scale)
    scales = (max_abs / FP8_E4M3_MAX).squeeze(-1).squeeze(-1)  # [e, n_blocks, k_blocks]

    # Quantize: q = round(val / scale)
    scaled = reshaped / (scales.unsqueeze(-1).unsqueeze(-1) + 1e-12)

    # Clamp to FP8 representable range
    scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)

    fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)

    # Vectorized FP8 quantization: round to nearest representable E4M3 value
    sign_mask = (scaled < 0).to(torch.uint8) << 7
    abs_scaled = scaled.abs()

    # Subnormal range: 0 < |x| < 2^-6, mantissa steps of 2^-9
    subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)
    subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)

    # Normal values
    normal_mask = abs_scaled >= 2**-6
    log2_val = torch.log2(abs_scaled.clamp(min=2**-9))
    exp = (log2_val.floor() + 7).clamp(1, 15).to(torch.int32)
    mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.int32)
    # exp=15 with mant=7 would encode NaN (0x7F); cap the mantissa at the E4M3 max (448)
    mant = torch.where(exp == 15, mant.clamp(max=6), mant).to(torch.uint8)

    # Combine
    fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)
    fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)

    # Reshape back to [e, n, k]
    fp8_q = fp8_q.permute(0, 1, 3, 2, 4).reshape(e, n, k)

    # Scales: [e, n_blocks, k_blocks], stored as FP32
    scales_fp32 = scales.to(torch.float32).contiguous()

    return fp8_q.contiguous(), scales_fp32

def dequantize_fp8_blockwise(fp8_weights: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
    """
    Dequantize FP8 weights back to BF16 for reference computation.

    Args:
        fp8_weights: [expert_num, n, k] uint8 tensor
        scales: [expert_num, n // group_size, k // group_size] FP32 tensor
        group_size: Block size

    Returns:
        dequantized: [expert_num, n, k] BF16 tensor
    """
    e, n, k = fp8_weights.shape
    n_blocks = n // group_size
    k_blocks = k // group_size

    # Convert FP8 to float via a 256-entry lookup table (FP8 E4M3 -> float)
    fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)
    fp8_float = fp8_lut[fp8_weights.to(torch.int64)]

    # Reshape for block-wise scaling
    fp8_reshaped = fp8_float.view(e, n_blocks, group_size, k_blocks, group_size)
    fp8_reshaped = fp8_reshaped.permute(0, 1, 3, 2, 4)  # [e, n_blocks, k_blocks, group_size, group_size]

    # Apply scales
    scales_f32 = scales.to(torch.float32).unsqueeze(-1).unsqueeze(-1)  # [e, n_blocks, k_blocks, 1, 1]
    dequantized = fp8_reshaped * scales_f32

    # Reshape back
    dequantized = dequantized.permute(0, 1, 3, 2, 4).reshape(e, n, k)

    return dequantized.to(torch.bfloat16).contiguous()

def build_random_fp8_weights():
    """
    Generate random BF16 weights and quantize to FP8.

    Returns:
        dict with fp8 weights, scales, and original bf16 for reference
    """
    torch.manual_seed(42)

    # Generate random BF16 weights with small values
    gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )
    up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )
    down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(
        torch.bfloat16
    )

    # Quantize to FP8
    gate_fp8, gate_scales = quantize_to_fp8_blockwise(gate_proj, fp8_group_size)
    up_fp8, up_scales = quantize_to_fp8_blockwise(up_proj, fp8_group_size)
    down_fp8, down_scales = quantize_to_fp8_blockwise(down_proj, fp8_group_size)

    # Dequantize for reference computation
    gate_deq = dequantize_fp8_blockwise(gate_fp8, gate_scales, fp8_group_size)
    up_deq = dequantize_fp8_blockwise(up_fp8, up_scales, fp8_group_size)
    down_deq = dequantize_fp8_blockwise(down_fp8, down_scales, fp8_group_size)

    print(f"FP8 weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}")
    print(f"Scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}")

    # Debug: Print FP8 weight and scale info for expert 0
    print("\n=== DEBUG: FP8 Weight and Scale Info (Expert 0) ===")
    print("gate_fp8[0] first 8x8 block:")
    for i in range(8):
        print(f"  row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
    print(f"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}")
    print(f"gate_scales[0] first 4x4 block:\n{gate_scales[0, :4, :4]}")
    print(f"gate_scales[0] stats: min={gate_scales[0].min()}, max={gate_scales[0].max()}")

    print("\nup_fp8[0] first 8x8 block:")
    for i in range(8):
        print(f"  row {i}: {up_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
    print(f"up_fp8[0] stats: min={up_fp8[0].min()}, max={up_fp8[0].max()}")
    print(f"up_scales[0] first 4x4 block:\n{up_scales[0, :4, :4]}")
    print(f"up_scales[0] stats: min={up_scales[0].min()}, max={up_scales[0].max()}")

    print("\ndown_fp8[0] first 8x8 block:")
    for i in range(8):
        print(f"  row {i}: {down_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
    print(f"down_fp8[0] stats: min={down_fp8[0].min()}, max={down_fp8[0].max()}")
    print(f"down_scales[0] first 4x4 block:\n{down_scales[0, :4, :4]}")
    print(f"down_scales[0] stats: min={down_scales[0].min()}, max={down_scales[0].max()}")

    return {
        "gate_fp8": gate_fp8.contiguous(),
        "up_fp8": up_fp8.contiguous(),
        "down_fp8": down_fp8.contiguous(),
        "gate_scales": gate_scales.contiguous(),
        "up_scales": up_scales.contiguous(),
        "down_scales": down_scales.contiguous(),
        "gate_deq": gate_deq.contiguous(),
        "up_deq": up_deq.contiguous(),
        "down_deq": down_deq.contiguous(),
    }

def build_moes_from_fp8_data(fp8_data: dict):
    """
    Build FP8 MoE modules from quantized data.
    """
    moes = []
    with torch.inference_mode(mode=True):
        for _ in range(layer_num):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = fp8_group_size
            config.quant_config.zero_point = False

            # Set FP8 weight pointers
            config.gate_proj = fp8_data["gate_fp8"].data_ptr()
            config.up_proj = fp8_data["up_fp8"].data_ptr()
            config.down_proj = fp8_data["down_fp8"].data_ptr()

            # Set scale pointers
            config.gate_scale = fp8_data["gate_scales"].data_ptr()
            config.up_scale = fp8_data["up_scales"].data_ptr()
            config.down_scale = fp8_data["down_scales"].data_ptr()
            config.pool = CPUInfer.backend_

            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)
    return moes

def run_fp8_moe_test():
    """
    Run FP8 MoE validation test.
    """
    print("\n" + "=" * 70)
    print("FP8 MoE Kernel Validation Test")
    print("=" * 70)

    # Build FP8 weights
    print("\nGenerating and quantizing weights...")
    fp8_data = build_random_fp8_weights()

    # Build MoE modules
    print("\nBuilding FP8 MoE modules...")
    moes = build_moes_from_fp8_data(fp8_data)

    # Get dequantized weights for reference
    gate_deq = fp8_data["gate_deq"]
    up_deq = fp8_data["up_deq"]
    down_deq = fp8_data["down_deq"]

    diffs = []
    with torch.inference_mode(mode=True):
        for i in range(validation_iter):
            torch.manual_seed(100 + i)
            bsz_tensor = torch.tensor([qlen], device="cpu")
            expert_ids = torch.stack(
                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
            ).contiguous()
            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()

            moe = moes[i % layer_num]
            CPUInfer.submit(
                moe.forward_task(
                    bsz_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids.data_ptr(),
                    weights.data_ptr(),
                    input_tensor.data_ptr(),
                    output.data_ptr(),
                    False,
                )
            )
            CPUInfer.sync()

            assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
            assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."

            # Reference computation using dequantized weights
            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)

            t_output_flat = t_output.flatten()
            output_flat = output.flatten()

            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
            diffs.append(diff.item())
            print(f"Iteration {i}: relative L1 diff = {diff:.6f}")

            if i < 3:  # Print detailed output for first few iterations
                print(f"  kernel output: {output_flat[:debug_print_count]}")
                print(f"  torch output: {t_output_flat[:debug_print_count]}")

    mean_diff = float(sum(diffs) / len(diffs))
    max_diff = float(max(diffs))
    min_diff = float(min(diffs))

    print("\n" + "=" * 70)
    print("FP8 MoE Test Results")
    print("=" * 70)
    print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
    print(f"Max relative L1 diff: {max_diff*100:.4f}%")
    print(f"Min relative L1 diff: {min_diff*100:.4f}%")

    # Pass/Fail criteria
    threshold = 15.0  # 15% relative error threshold for FP8
    if mean_diff * 100 < threshold:
        print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
    else:
        print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")

    return {"mean": mean_diff, "max": max_diff, "min": min_diff}


if __name__ == "__main__":
    run_fp8_moe_test()
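The pass/fail metric used by the test is a mean relative L1 error: mean |kernel - reference| normalized by mean |reference|. A standalone sketch of the metric in isolation:

```python
import torch

def rel_l1(a: torch.Tensor, b: torch.Tensor) -> float:
    # mean |a - b| normalized by mean |b|, as in the test's diff computation
    return (torch.mean(torch.abs(a - b)) / (torch.mean(torch.abs(b)) + 1e-12)).item()

b = torch.tensor([1.0, -2.0, 4.0])
a = b * 1.01  # a uniform 1% perturbation gives a 1% relative L1 diff
print(f"{rel_l1(a, b):.4f}")
```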
389 kt-kernel/examples/test_fp8_write_buffer.py Normal file
@@ -0,0 +1,389 @@
import os
import sys
import time

import torch
import numpy as np


from kt_kernel import kt_kernel_ext
from kt_kernel_ext import CPUInfer
from kt_kernel_ext.moe import AMXFP8_MOE


def make_cpu_infer(thread_num=80):
    return CPUInfer(thread_num)


def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    cfg.max_len = 1
    cfg.quant_config.bits = 8  # FP8
    cfg.quant_config.group_size = group_size
    cfg.quant_config.zero_point = False
    cfg.pool = cpuinfer.backend_
    return cfg

def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
    """Allocate FP8 weights and scales for testing"""
    # FP8 weights: 1 byte per element
    per_mat_weight_bytes = hidden_size * intermediate_size
    # FP8 scales: block-wise (group_size x group_size blocks), stored as float32
    n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
    n_blocks_k = (hidden_size + group_size - 1) // group_size
    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k

    # For down: n = hidden_size, k = intermediate_size
    n_blocks_n_down = n_blocks_k
    n_blocks_k_down = n_blocks_n_gate_up
    per_mat_scale_elems_down = n_blocks_n_down * n_blocks_k_down

    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
    up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)

    # FP8 scales are float32
    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)

    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    )
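The block-count arithmetic in `allocate_weights` can be checked independently. A standalone sketch using the same dimensions as the test below (hidden 3072, intermediate 1536, 128x128 blocks):

```python
def div_up(a: int, b: int) -> int:
    # Ceiling division, matching the div_up convention referenced in the comments.
    return (a + b - 1) // b

hidden_size, intermediate_size, group_size = 3072, 1536, 128

# gate/up: n = intermediate_size, k = hidden_size
scale_elems_gate_up = div_up(intermediate_size, group_size) * div_up(hidden_size, group_size)
# down: n = hidden_size, k = intermediate_size (same product, transposed blocks)
scale_elems_down = div_up(hidden_size, group_size) * div_up(intermediate_size, group_size)

print(scale_elems_gate_up, scale_elems_down)  # 288 288
```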

def test_with_tp(gpu_tp_count):
    """Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
    torch.manual_seed(123)

    expert_num = 256  # Reduced for debugging
    gpu_experts = expert_num  # Number of experts on GPU

    num_experts_per_tok = 8
    hidden_size = 3072
    intermediate_size = 1536  # Changed from 2048 to test non-aligned case
    group_size = 128  # FP8 uses 128x128 block-wise scales

    cpuinfer = make_cpu_infer()
    cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)

    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)

    cfg.gate_proj = gate_q.data_ptr()
    cfg.up_proj = up_q.data_ptr()
    cfg.down_proj = down_q.data_ptr()
    cfg.gate_scale = gate_scale.data_ptr()
    cfg.up_scale = up_scale.data_ptr()
    cfg.down_scale = down_scale.data_ptr()

    moe = AMXFP8_MOE(cfg)

    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    cpuinfer.sync()

    # TP configuration
    # Calculate sizes per TP part (per expert) - must match C++ code which uses div_up
    def div_up(a, b):
        return (a + b - 1) // b

    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count

    # For W13 (gate/up): n = intermediate_size / gpu_tp, k = hidden_size
    gpu_n_w13 = intermediate_size // gpu_tp_count
    gpu_k_w13 = hidden_size
    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)

    # For W2 (down): n = hidden_size, k = intermediate_size / gpu_tp
    gpu_n_w2 = hidden_size
    gpu_k_w2 = intermediate_size // gpu_tp_count
    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)

    # Total sizes for all gpu_experts
    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down

    # Create buffer lists for w13 (gate+up) and w2 (down)
    # These hold all experts' data for each GPU TP
    w13_weight_bufs = []
    w13_scale_bufs = []
    w2_weight_bufs = []
    w2_scale_bufs = []

    for tp_idx in range(gpu_tp_count):
        # w13 combines gate and up, so needs 2x the size
        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32))
        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
        w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32))

    print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
    print(f"GPU TP count: {gpu_tp_count}")
    print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
    print(f"Original per matrix scale elements (gate/up): {per_mat_scale_elems_gate_up}")
    print(f"Original per matrix scale elements (down): {per_mat_scale_elems_down}")
    print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
    print(f"Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
    print(f"Scale elements per expert per TP (down): {scale_elems_per_expert_per_tp_down}")
    print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
    print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")

    # Helper function to get pointers with expert offset
    # write_weights_to_buffer writes one expert at a time, so we need to pass
    # pointers that already point to the correct location for each expert
    def get_expert_ptrs(expert_id):
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []

        for tp_idx in range(gpu_tp_count):
            # Calculate byte offsets for this expert
            # w13: gate_weight + up_weight interleaved by expert
            # Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down

            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)  # float32 = 4 bytes
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)  # float32 = 4 bytes

        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
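The pointer arithmetic in `get_expert_ptrs` relies on an element living at `base + index * element_size` bytes, which is why float32 scale offsets are multiplied by 4. A standalone sketch of that rule using torch's `data_ptr`:

```python
import torch

buf = torch.empty(16, dtype=torch.float32)

# Element i lives at base + i * element_size bytes; the helper above applies
# the same rule when it multiplies float32 scale offsets by 4.
base = buf.data_ptr()
assert buf[3:].data_ptr() == base + 3 * buf.element_size()
assert buf.element_size() == 4
print("offset arithmetic checks out")
```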

    # Warm up
    for i in range(2):
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Timing
    begin_time = time.perf_counter_ns()
    for expert_id in range(gpu_experts):
        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                expert_id=expert_id,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
                w2_scale_ptrs=w2_scale_ptrs,
            )
        )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1000000

    # Calculate throughput
    total_weights = hidden_size * intermediate_size * gpu_experts * 3
    total_scale_bytes = (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4  # float32
    total_bytes = total_weights + total_scale_bytes
    print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
    def split_expert_tensor(tensor, chunk):
        """Split tensor by experts"""
        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]

    # Split by experts first
    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)

    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)

    # For down matrix
    n_blocks_n = (hidden_size + group_size - 1) // group_size
    n_blocks_k = (intermediate_size + group_size - 1) // group_size
    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count

    # Verify buffers for each TP part
    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
        expected_w13_scales = []
        expected_w2_weights = []
        expected_w2_scales = []

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count

        # Process each GPU expert
        for expert_id in range(gpu_experts):
            # For w13 (gate and up), the slicing is along intermediate_size (n direction)
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            # Gate
            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]

            # Up
            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]

            # Down matrix needs special handling because it's sliced column-wise
            # down is (hidden_size, intermediate_size) in n-major format
            down_weight_tp_parts = []
            down_scale_tp_parts = []

            # Iterate through each row to extract the corresponding parts
            for row_idx in range(hidden_size):
                row_weight_start = row_idx * intermediate_size

                # Direct mapping: each CPU TP corresponds to a GPU TP
                tp_slice_weight_size = intermediate_size // gpu_tp_count

                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size

                down_weight_tp_parts.append(
                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                )

            # For scale: only process at block boundaries
            for bn in range(n_blocks_n):
                row_scale_start = bn * n_blocks_k
                tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp
                down_scale_tp_parts.append(
                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]
                )

            # Concatenate all slices for this TP
            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
            expected_w13_scales.append(up_scale_tp)
            expected_w2_weights.append(down_weight_tp)
            expected_w2_scales.append(down_scale_tp)

        # Concatenate all experts for this TP part
        expected_w13_weight = torch.cat(expected_w13_weights)
        expected_w13_scale = torch.cat(expected_w13_scales)
        expected_w2_weight = torch.cat(expected_w2_weights)
        expected_w2_scale = torch.cat(expected_w2_scales)

        print(f"=== Checking TP part {tp_idx} ===")
        print(f"  w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
        print(f"  w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
        print(f"  w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
        print(f"  w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")

        # Assert all checks pass
        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            # Find first mismatch
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f"  w13 weight mismatch at index {first_diff_idx}")
            print(f"  actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f"  expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            diff = torch.abs(w13_scale_bufs[tp_idx] - expected_w13_scale)
            max_diff_idx = diff.argmax().item()
            print(f"  w13 scale mismatch, max diff at index {max_diff_idx}")
            print(f"  actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
            print(f"  expected: {expected_w13_scale[max_diff_idx]}")
            raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f"  w2 weight mismatch at index {first_diff_idx}")
            print(f"  actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f"  expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            diff = torch.abs(w2_scale_bufs[tp_idx] - expected_w2_scale)
            max_diff_idx = diff.argmax().item()
            print(f"  w2 scale mismatch, max diff at index {max_diff_idx}")
            print(f"  actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
            print(f"  expected: {expected_w2_scale[max_diff_idx]}")
            raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")

    print(
        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
    )
    return True


def main():
    """Run tests for gpu_tp_count values: 1, 2, 4"""
    tp_values = [1, 2, 4]
    all_passed = True
    results = {}

    print("=" * 60)
    print("Testing FP8 write_weight_scale_to_buffer for TP =", tp_values)
    print("=" * 60)

    for tp in tp_values:
        print(f"\n{'='*60}")
        print(f"Testing with gpu_tp_count = {tp}")
        print(f"{'='*60}")
        try:
            test_with_tp(tp)
            results[tp] = "PASSED"
            print(f"✓ TP={tp} PASSED")
        except Exception as e:
            results[tp] = f"FAILED: {e}"
            all_passed = False
            print(f"✗ TP={tp} FAILED: {e}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for tp, result in results.items():
        status = "✓" if "PASSED" in result else "✗"
        print(f"  {status} TP={tp}: {result}")

    if all_passed:
        print("\n✓ ALL TESTS PASSED")
    else:
        print("\n✗ SOME TESTS FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()
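The per-expert pointer arithmetic that `get_expert_ptrs` performs in the test above can be isolated into a small sketch. This is illustrative only (the helper name and plain-integer return values are not part of the kt-kernel API): w13 interleaves gate and up per expert, and bf16 scale element offsets are doubled to obtain byte offsets.

```python
def expert_offsets(expert_id, weight_bytes_per_expert_per_tp, scale_elems_per_expert_per_tp):
    """Byte offsets of one expert inside the per-TP staging buffers.

    Illustrative sketch of the layout assumed by the test above: w13
    interleaves gate and up per expert ([gate0, up0, gate1, up1, ...]),
    w2 holds one down matrix per expert, and scales are bf16, so element
    offsets are doubled to get byte offsets.
    """
    w13_weight_off = expert_id * 2 * weight_bytes_per_expert_per_tp
    w13_scale_off = expert_id * 2 * scale_elems_per_expert_per_tp * 2  # bf16 = 2 bytes
    w2_weight_off = expert_id * weight_bytes_per_expert_per_tp
    w2_scale_off = expert_id * scale_elems_per_expert_per_tp * 2  # bf16 = 2 bytes
    return w13_weight_off, w13_scale_off, w2_weight_off, w2_scale_off
```

Adding these offsets to each buffer's `data_ptr()` reproduces the per-TP pointer lists passed to `write_weight_scale_to_buffer_task`.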
@@ -6,11 +6,6 @@ import torch
import numpy as np


# Ensure we can import the local extension
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# if REPO_ROOT not in sys.path:
#     sys.path.insert(0, REPO_ROOT)

from kt_kernel import kt_kernel_ext
from kt_kernel_ext import CPUInfer

@@ -54,12 +49,12 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
    )


def main():
def test_with_tp(gpu_tp_count):
    """Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
    torch.manual_seed(123)

    expert_num = 256  # Total experts
    expert_num = 8  # Reduced for faster testing
    gpu_experts = expert_num  # Number of experts on GPU
    gpu_tp_count = 2  # Number of TP parts

    num_experts_per_tok = 8
    hidden_size = 7168
@@ -94,11 +89,7 @@ def main():
    cpuinfer.sync()

    # TP configuration

    # Since weights are col-major, we can directly divide the total size by tp_count
    # Each matrix is divided into gpu_tp_count parts in memory order

    # Calculate sizes per TP part (direct division since col-major)
    # Calculate sizes per TP part (per expert)
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count

@@ -107,24 +98,19 @@ def main():
    total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp

    # Create buffer lists for w13 (gate+up) and w2 (down)
    # These hold all experts' data for each GPU TP
    w13_weight_bufs = []
    w13_scale_bufs = []
    w2_weight_bufs = []
    w2_scale_bufs = []

    for tp_idx in range(gpu_tp_count):
        # w13 combines gate and up, so needs 2x the size
        # w13 combines gate and up, so needs 2x the size per expert
        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
        w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))

    # Get data pointers for all buffers
    w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
    w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
    w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
    w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]

    print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
    print(f"GPU TP count: {gpu_tp_count}")
    print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
@@ -133,14 +119,56 @@ def main():
    print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
    print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
    print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
    print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
    print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")

    for i in range(5):
    # Helper function to get pointers with expert offset
    # K2 write_weights_to_buffer writes one expert at a time, so we need to pass
    # pointers that already point to the correct location for each expert
    def get_expert_ptrs(expert_id):
        w13_weight_ptrs = []
        w13_scale_ptrs = []
        w2_weight_ptrs = []
        w2_scale_ptrs = []

        for tp_idx in range(gpu_tp_count):
            # Calculate byte offsets for this expert
            # w13: gate_weight + up_weight interleaved by expert
            # Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp
            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp

            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 2)  # bf16 = 2 bytes
            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 2)  # bf16 = 2 bytes

        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs

    # Warm up
    for i in range(2):
        for expert_id in range(gpu_experts):
            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
            cpuinfer.submit(
                moe.write_weight_scale_to_buffer_task(
                    gpu_tp_count=gpu_tp_count,
                    expert_id=expert_id,
                    w13_weight_ptrs=w13_weight_ptrs,
                    w13_scale_ptrs=w13_scale_ptrs,
                    w2_weight_ptrs=w2_weight_ptrs,
                    w2_scale_ptrs=w2_scale_ptrs,
                )
            )
        cpuinfer.sync()

    # Timing
    begin_time = time.perf_counter_ns()
    for expert_id in range(gpu_experts):
        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
        cpuinfer.submit(
            moe.write_weight_scale_to_buffer_task(
                gpu_tp_count=gpu_tp_count,
                gpu_experts_num=gpu_experts,
                expert_id=expert_id,
                w13_weight_ptrs=w13_weight_ptrs,
                w13_scale_ptrs=w13_scale_ptrs,
                w2_weight_ptrs=w2_weight_ptrs,
@@ -148,23 +176,10 @@ def main():
            )
        )
    cpuinfer.sync()

    begin_time = time.perf_counter_ns()
    cpuinfer.submit(
        moe.write_weight_scale_to_buffer_task(
            gpu_tp_count=gpu_tp_count,
            gpu_experts_num=gpu_experts,
            w13_weight_ptrs=w13_weight_ptrs,
            w13_scale_ptrs=w13_scale_ptrs,
            w2_weight_ptrs=w2_weight_ptrs,
            w2_scale_ptrs=w2_scale_ptrs,
        )
    )
    cpuinfer.sync()
    end_time = time.perf_counter_ns()
    elapsed_ms = (end_time - begin_time) / 1000000
    total_weights = hidden_size * intermediate_size * expert_num * 3
    total_bytes = total_weights // group_size + total_weights // 2
    total_weights = hidden_size * intermediate_size * gpu_experts * 3
    total_bytes = total_weights // group_size * 2 + total_weights // 2  # scale (bf16) + weight (int4)
    print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
    print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")

@@ -181,9 +196,6 @@ def main():
    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)

    # CPU TP count is always 2 in this test setup (one TP per NUMA node)
    cpu_tp_count = 2

    # Verify buffers for each TP part
    for tp_idx in range(gpu_tp_count):
        expected_w13_weights = []
@@ -193,22 +205,22 @@ def main():

        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
        scale13_per_tp = per_mat_scale_elems // gpu_tp_count
        # Process each GPU expert
        for expert_idx in range(gpu_experts):
            # For w13 (gate and up), the slicing is straightforward

        # Process each GPU expert
        for expert_id in range(gpu_experts):
            # For w13 (gate and up), the slicing is straightforward
            start_weight = tp_idx * weight13_per_tp
            end_weight = (tp_idx + 1) * weight13_per_tp
            start_scale = tp_idx * scale13_per_tp
            end_scale = (tp_idx + 1) * scale13_per_tp

            # Gate
            gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]
            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]

            # Up
            up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]
            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]

            # Down matrix needs special handling because it's sliced column-wise
            # We need to reconstruct it from column slices
@@ -228,16 +240,17 @@ def main():
                tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size

                down_weight_tp_parts.append(
                    down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
                )
                down_scale_tp_parts.append(
                    down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
                )

            # Concatenate all column slices for this TP
            down_weight_tp = torch.cat(down_weight_tp_parts)
            down_scale_tp = torch.cat(down_scale_tp_parts)

            # Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
            expected_w13_weights.append(gate_weight_tp)
            expected_w13_weights.append(up_weight_tp)
            expected_w13_scales.append(gate_scale_tp)
@@ -252,16 +265,85 @@ def main():
        expected_w2_scale = torch.cat(expected_w2_scales)

        print(f"=== Checking TP part {tp_idx} ===")
        print(f"  w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
        print(f"  w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
        print(f"  w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
        print(f"  w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")

        # Assert all checks pass
        assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
        assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
        assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
        assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f"  w13 weight mismatch at index {first_diff_idx}")
            print(f"  actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f"  expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
            diff = torch.abs(w13_scale_bufs[tp_idx].float() - expected_w13_scale.float())
            max_diff_idx = diff.argmax().item()
            print(f"  w13 scale mismatch, max diff at index {max_diff_idx}")
            print(f"  actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
            print(f"  expected: {expected_w13_scale[max_diff_idx]}")
            raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")

        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
            print(f"  w2 weight mismatch at index {first_diff_idx}")
            print(f"  actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
            print(f"  expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
            raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")

        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
            diff = torch.abs(w2_scale_bufs[tp_idx].float() - expected_w2_scale.float())
            max_diff_idx = diff.argmax().item()
            print(f"  w2 scale mismatch, max diff at index {max_diff_idx}")
            print(f"  actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
            print(f"  expected: {expected_w2_scale[max_diff_idx]}")
            raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")

    print(
        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts"
        f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
    )
    return True


def main():
    """Run tests for all gpu_tp_count values: 1, 2, 4, 8"""
    tp_values = [1, 2, 4, 8]
    all_passed = True
    results = {}

    print("=" * 60)
    print("Testing K2 write_weight_scale_to_buffer for TP = 1, 2, 4, 8")
    print("=" * 60)

    for tp in tp_values:
        print(f"\n{'='*60}")
        print(f"Testing with gpu_tp_count = {tp}")
        print(f"{'='*60}")
        try:
            test_with_tp(tp)
            results[tp] = "PASSED"
            print(f"✓ TP={tp} PASSED")
        except Exception as e:
            results[tp] = f"FAILED: {e}"
            all_passed = False
            print(f"✗ TP={tp} FAILED: {e}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for tp, result in results.items():
        status = "✓" if "PASSED" in result else "✗"
        print(f"  {status} TP={tp}: {result}")

    if all_passed:
        print("\n✓ ALL TESTS PASSED")
    else:
        print("\n✗ SOME TESTS FAILED")
        sys.exit(1)


if __name__ == "__main__":

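The verification code above rebuilds each TP rank's expected down-matrix bytes row by row. The reason is that a column split of a row-major ("n-major") matrix is not one contiguous slice: each rank owns one contiguous chunk per row. A standalone sketch of that extraction (names are illustrative, not kt-kernel API):

```python
def slice_down_for_tp(down_flat, hidden_size, intermediate_size, tp_idx, tp_count):
    """Column slice of a flat, row-major (hidden_size x intermediate_size) matrix.

    Illustrative sketch: equivalent to matrix[:, tp_idx*cols:(tp_idx+1)*cols]
    flattened, built the same way as the per-row loop in the test above.
    """
    cols = intermediate_size // tp_count
    parts = []
    for row in range(hidden_size):
        # Each row contributes one contiguous chunk of `cols` elements.
        start = row * intermediate_size + tp_idx * cols
        parts.extend(down_flat[start : start + cols])
    return parts
```

For example, with a 2x4 matrix stored as `range(8)` and two TP ranks, rank 1 owns columns 2-3 of each row, i.e. elements `[2, 3, 6, 7]`.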
@@ -36,6 +36,7 @@ static const bool _is_plain_ = false;

#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/fp8-moe.hpp"
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
@@ -255,7 +256,7 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
      CPUInfer* cpuinfer;
      MoeClass* moe;
      int gpu_tp_count;
      int gpu_experts_num;
      int expert_id;
      std::vector<uintptr_t> w13_weight_ptrs;
      std::vector<uintptr_t> w13_scale_ptrs;
      std::vector<uintptr_t> w2_weight_ptrs;
@@ -265,12 +266,12 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
    static void inner(void* args) {
      Args* args_ = (Args*)args;
      args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
                               args_->gpu_experts_num, args_->w13_weight_ptrs, args_->w13_scale_ptrs,
                               args_->w2_weight_ptrs, args_->w2_scale_ptrs);
                               args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
                               args_->w2_scale_ptrs);
    }

    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
                                                            int gpu_experts_num, py::list w13_weight_ptrs,
                                                            int expert_id, py::list w13_weight_ptrs,
                                                            py::list w13_scale_ptrs, py::list w2_weight_ptrs,
                                                            py::list w2_scale_ptrs) {
      // Convert Python lists to std::vector<uintptr_t>
@@ -281,15 +282,59 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
      for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
      for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));

      Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
      Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
                            w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
      return std::make_pair((intptr_t)&inner, (intptr_t)args);
    }
  };

  moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
              py::arg("gpu_tp_count"), py::arg("gpu_experts_num"), py::arg("w13_weight_ptrs"),
              py::arg("w13_scale_ptrs"), py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
              py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
              py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}

// FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
if constexpr (std::is_same_v<MoeTP, AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>) {
  struct WriteWeightScaleToBufferBindings {
    struct Args {
      CPUInfer* cpuinfer;
      MoeClass* moe;
      int gpu_tp_count;
      int expert_id;
      std::vector<uintptr_t> w13_weight_ptrs;
      std::vector<uintptr_t> w13_scale_ptrs;
      std::vector<uintptr_t> w2_weight_ptrs;
      std::vector<uintptr_t> w2_scale_ptrs;
    };

    static void inner(void* args) {
      Args* args_ = (Args*)args;
      args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
                               args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
                               args_->w2_scale_ptrs);
    }

    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
                                                            int expert_id, py::list w13_weight_ptrs,
                                                            py::list w13_scale_ptrs, py::list w2_weight_ptrs,
                                                            py::list w2_scale_ptrs) {
      // Convert Python lists to std::vector<uintptr_t>
      std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;

      for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
      for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
      for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
      for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));

      Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
                            w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
      return std::make_pair((intptr_t)&inner, (intptr_t)args);
    }
  };

  moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
              py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
              py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif
}
@@ -562,6 +607,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
  bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
  bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
  bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
  bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
#endif
#if defined(USE_MOE_KERNEL)
  bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");

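The awq-moe.hpp change below corrects the comment on `convert_zeros_to_mins_avx` to `mins = -(zeros * scales)`. The scalar arithmetic can be sketched in isolation; note this is hedged: the low-nibble-first unpack order used here is purely illustrative, and the real AVX kernel's int4 packing layout may differ.

```python
def convert_zeros_to_mins_ref(zeros_packed, scales):
    """Scalar reference for the arithmetic: mins[i] = -(zeros[i] * scales[i]).

    Assumes each uint32 packs eight 4-bit zero-points, low nibble first;
    this models only the math, not the real kernel's memory layout.
    """
    mins = []
    i = 0
    for word in zeros_packed:
        for nib in range(8):
            z = (word >> (4 * nib)) & 0xF  # unpack one 4-bit zero-point
            mins.append(-(z * scales[i]))
            i += 1
    return mins
```

Dequantization can then use `w * scale + min` per group, which is why the sign flip on `mins` matters.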
@@ -1,73 +1,49 @@
/**
 * @Description :
 * @Author : chenht2022
 * @Description : AWQ Int4 AMX MoE operator with KGroup quantization and zero-point support
 * @Author : chenht2022, oql
 * @Date : 2024-07-22 02:03:22
 * @Version : 1.0.0
 * @LastEditors : chenht2022
 * @LastEditTime : 2024-07-25 10:35:10
 * @Version : 2.0.0
 * @LastEditors : oql
 * @LastEditTime : 2025-12-10
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * This file implements AWQ Int4 MoE using CRTP pattern, inheriting from moe_base.hpp.
 * AWQ weights are stored with group-wise scales and zero-points (KGroup Int4 with zeros).
 **/
#ifndef CPUINFER_OPERATOR_AMX_AWQ_MOE_H
#define CPUINFER_OPERATOR_AMX_AWQ_MOE_H

// #define CHECK

#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT

#include <immintrin.h>

#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>

#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
#include "moe_base.hpp"

/**
 * @brief AWQ Int4 MoE operator using CRTP pattern
 * @tparam T Kernel type for AWQ quantization
 *
 * This class provides AWQ-specific implementations:
 * - do_gate_up_gemm: Int4 weight with KGroup scale + zeros + AMX GEMM
 * - do_down_gemm: Same Int4 KGroup GEMM
 * - load_weights: Load Int4 weights with group-wise scales and zero-points
 */
template <class T>
class AMX_AWQ_MOE_TP {
class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
 private:
  int tp_part_idx;
  using Base = AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>>;
  using Base::config_;
  using Base::tp_part_idx;
  using Base::gate_bb_;
  using Base::up_bb_;
  using Base::down_bb_;
  using Base::gate_up_ba_;
  using Base::gate_bc_;
  using Base::up_bc_;
  using Base::down_ba_;
  using Base::down_bc_;
  using Base::m_local_num_;

  std::filesystem::path prefix;

  void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if
                     // quantized)]
  void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if
                     // quantized)]
  void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if
                     // quantized)]

  ggml_bf16_t* m_local_input_;        // [num_experts_per_tok * max_len * hidden_size]
  ggml_bf16_t* m_local_gate_output_;  // [num_experts_per_tok * max_len * intermediate_size]
  ggml_bf16_t* m_local_up_output_;    // [num_experts_per_tok * max_len * intermediate_size]
  ggml_bf16_t* m_local_down_output_;  // [num_experts_per_tok * max_len * hidden_size]

  std::vector<std::vector<int>> m_local_pos_;          // [max_len, num_experts_per_tok]
  std::vector<int> m_local_num_;                       // [expert_num]
  std::vector<int> m_expert_id_map_;                   // [expert_num]
  std::vector<ggml_bf16_t*> m_local_input_ptr_;        // [expert_num]
  std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;  // [expert_num]
  std::vector<ggml_bf16_t*> m_local_up_output_ptr_;    // [expert_num]
  std::vector<ggml_bf16_t*> m_local_down_output_ptr_;  // [expert_num]

  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
  char verify_bb[100000000];
  char check_bb[100000000];
@@ -274,32 +250,35 @@ class AMX_AWQ_MOE_TP {
                       zeros_size / mat_split);
      zeros_file.close();
    }

#ifdef CHECK
  inline void load_check() {
    memcpy(check_bb, (char*)down_bb_[compare_expers]->b,
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));
  }

  void verify_load_right() {
    // printf("varify down bb_0 %d\n", tp_part_idx);
    memcpy(verify_bb, (char*)down_bb_[compare_expers]->b,
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
    // check if verify_bb_0 equal to check_bb_0
    if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) {
           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));
    if (memcmp(verify_bb, check_bb,
               T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
                                         config_.quant_config.group_size)) != 0) {
      printf("verify error\n");
      for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) {
      for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
                                                       config_.quant_config.group_size);
           ++i) {
        if (verify_bb[i] != check_bb[i]) {
          printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i,
                 (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]);
          break;  // find the first difference and exit
          break;
        }
      }
      assert(0);
    } else {
      printf("pass verify\n");
      // pick out the 100th~150th byte of scale to see
      printf("numa %d, verify_bb_%d:\n", tp_part_idx, compare_expers);
      size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
      size_t size =
          T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size);
      size_t scale_size = config_.hidden_size * sizeof(float);
      for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) {
        printf("%02x ", (unsigned char)verify_bb[i]);
@@ -392,7 +371,7 @@ class AMX_AWQ_MOE_TP {
|
||||
}
|
||||
|
||||
// AVX-optimized function to convert INT4 zeros to float mins
|
||||
// mins = zeros * scales (element-wise), where scales is float format
|
||||
// mins = -(zeros * scales) (element-wise), where scales is float format
|
||||
inline void convert_zeros_to_mins_avx(const uint32_t* zeros_int4_packed, const float* scales, float* mins,
|
||||
size_t num_elements) {
|
||||
constexpr size_t simd_width = 8; // 每次解 8 个 int4
|
||||
@@ -408,30 +387,25 @@ class AMX_AWQ_MOE_TP {
|
||||
}
|
||||
}
|
||||
|
||||
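The AVX routine above has a simple scalar equivalent that makes the transform explicit. The sketch below is illustrative only: the helper name and the low-nibble-first packing order are assumptions, not the kernel's actual layout. Each `uint32_t` carries eight packed 4-bit zero-points, and each min is the negated product of zero-point and scale.

```cpp
#include <cstddef>
#include <cstdint>

// Scalar sketch of the zeros->mins transform: mins[i] = -(zero[i] * scale[i]).
// Assumes eight 4-bit zero-points per uint32_t, low nibble first (the real
// packed layout may differ).
void convert_zeros_to_mins_scalar(const uint32_t* zeros_int4_packed, const float* scales, float* mins,
                                  size_t num_elements) {
  for (size_t i = 0; i < num_elements; ++i) {
    uint32_t word = zeros_int4_packed[i / 8];
    uint32_t zero = (word >> ((i % 8) * 4)) & 0xF;  // extract the i-th nibble
    mins[i] = -(static_cast<float>(zero) * scales[i]);
  }
}
```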
#ifdef FORWARD_TIME_REPORT
  std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif

 public:
  using input_t = ggml_bf16_t;
  using output_t = float;
  GeneralMOEConfig config_;
  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
  using typename Base::input_t;
  using typename Base::output_t;

  AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx) {
    auto& quant_config = config.quant_config;
    int& group_size = quant_config.group_size;
  AMX_AWQ_MOE_TP() = default;

  AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
    auto& quant_config = config_.quant_config;
    if (quant_config.group_size == 0 || !quant_config.zero_point) {
      throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1");
    }
    auto& load = config.load;
    auto& save = config.save;
    if (load && config.path == "") {
      load = false;
    }

    prefix = config.path;
    prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
    printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));

    auto& load = config_.load;
    auto& save = config_.save;

    prefix = config_.path;
    prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx_));
    if (save) {
      std::cout << "Creating " << prefix << std::endl;
      std::filesystem::create_directories(prefix);
@@ -443,77 +417,74 @@ class AMX_AWQ_MOE_TP {
        throw std::runtime_error("Path not found: " + prefix.string());
      }
    }

    this->tp_part_idx = tp_part_idx;
    config_ = config;
    gate_proj_ = config_.gate_proj;
    up_proj_ = config_.up_proj;
    down_proj_ = config_.down_proj;

    MemoryRequest mem_requests;
    mem_requests.append_pointer(
        &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
    mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                         config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.hidden_size);

    m_local_pos_.resize(config_.max_len);
    for (int i = 0; i < config_.max_len; i++) {
      m_local_pos_[i].resize(config_.num_experts_per_tok);
    }
    m_expert_id_map_.resize(config_.expert_num);
    m_local_num_.resize(config_.expert_num);
    m_local_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);

    for (size_t i = 0; i < config_.expert_num; i++) {
      gate_up_ba_.push_back(
          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
      gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
      down_ba_.push_back(
          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, group_size, nullptr));
      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));

      void* gate_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
      gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
                                                               group_size, gate_bb_ptr));

      void* up_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
      up_bb_.push_back(
          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr));

      void* down_bb_ptr =
          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size));
      down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
                                                               group_size, down_bb_ptr));
    }
    for (int i = 0; i < config_.expert_num; i++) {
      mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
                                   T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
      mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
                                   T::BufferC::required_size(config_.max_len, config_.intermediate_size));
      mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
                                   T::BufferC::required_size(config_.max_len, config_.intermediate_size));
      mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
                                   T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
      mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
                                   T::BufferC::required_size(config_.max_len, config_.hidden_size));
    }
    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
  }

  ~AMX_AWQ_MOE_TP() {
    // shared_mem_buffer_numa.dealloc(this);
  ~AMX_AWQ_MOE_TP() = default;

  // ============================================================================
  // CRTP buffer creation - with group_size (AWQ uses zero-point)
  // ============================================================================

  size_t buffer_a_required_size_impl(size_t m, size_t k) const {
    return T::BufferA::required_size(m, k, config_.quant_config.group_size);
  }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k, config_.quant_config.group_size);
  }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const {
    return T::BufferC::required_size(m, n);
  }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // ============================================================================
  // CRTP virtual points - GEMM dispatch (uses kgroup with zeros)
  // ============================================================================

  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    auto& group_size = config_.quant_config.group_size;
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];

    // Dispatch based on qlen threshold
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
    }
  }

  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    auto& group_size = config_.quant_config.group_size;
    int m = m_local_num_[expert_idx];

    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
    } else {
      amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
    }
  }

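Both GEMM helpers share the same dispatch rule: the tiled AMX matrix kernel pays off only once each expert receives enough rows, and with `num_experts_per_tok` experts routed per token, roughly `qlen * num_experts_per_tok / expert_num` rows land on an average expert. A standalone sketch of that condition (the function name is illustrative, not part of the kernel's API):

```cpp
// Sketch of the qlen-threshold dispatch used by the GEMM helpers: prefer the
// tiled matrix kernel once the average rows per expert exceed ~4, i.e.
// qlen > 4 * expert_num / num_experts_per_tok; otherwise take the vector path.
bool use_matrix_kernel(int qlen, int expert_num, int num_experts_per_tok) {
  return qlen > 4 * expert_num / num_experts_per_tok;
}
```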
  /**
   * @brief Load Int4 weights with scales and zero-points
   *
   * AWQ weights include:
   * - Quantized INT4 weights
   * - FP16 scales (converted to FP32)
   * - INT4 zeros (converted to FP32 mins = -scale * zero)
   */
  void load_weights() {
    auto& quant_config = config_.quant_config;
    int& group_size = quant_config.group_size;
@@ -524,15 +495,12 @@ class AMX_AWQ_MOE_TP {

    auto pool = config_.pool->get_subpool(tp_part_idx);
    if (config_.gate_projs.size()) {
      throw std::runtime_error("AMX load weights is not support");
      throw std::runtime_error("AMX load weights from gate_projs is not supported");
    } else {
      // AWQ load-from-file implementation
      int nth = T::recommended_nth(config_.intermediate_size);
      static uint8_t mat_type_all = 3, mat_split = 1;
      if (config_.load) {
        throw std::runtime_error("AMX load weights from file is not support");
        throw std::runtime_error("AMX load weights from file is not supported");
      }
      // check process: store the down matrix for later comparison
#ifdef CHECK
      load_check();
#endif
@@ -540,7 +508,7 @@ class AMX_AWQ_MOE_TP {
      else if (config_.gate_scale != nullptr)
#endif
      {
        // Loading quantized weights
        // Loading quantized weights with scales and zeros
        pool->do_work_stealing_job(
            nth * config_.expert_num, nullptr,
            [this, nth, physical_to_logical_map](int task_id) {
@@ -594,7 +562,7 @@ class AMX_AWQ_MOE_TP {
                     (ggml_fp16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),
                     scale_elem_count);

              // Convert INT4 zeros to FP32 mins
              // Convert INT4 zeros to FP32 mins: mins = -(scale * zero)
              convert_zeros_to_mins_avx(
                  (const uint32_t*)((uint8_t*)config_.gate_zero + ((logical_expert_id * scale_elem_count) >> 1)),
                  gate_bb_[expert_idx]->d, gate_bb_[expert_idx]->mins, scale_elem_count);
@@ -617,7 +585,7 @@ class AMX_AWQ_MOE_TP {
        }
      }
      else {
        // Online Quantization
        // Online Quantization from BF16
        assert(config_.gate_proj != nullptr);

        pool->do_work_stealing_job(
@@ -668,450 +636,21 @@ class AMX_AWQ_MOE_TP {
    }
  }

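The scale/min pair that `load_weights` prepares corresponds to the usual group-wise asymmetric dequantization. A minimal sketch with a hypothetical helper name: a 4-bit value `q` in [0, 15] is recovered as `scale * q + min`, and since `min = -(scale * zero)` this equals `scale * (q - zero)`.

```cpp
// Sketch of AWQ-style group-wise INT4 dequantization (illustrative helper,
// not the kernel's API): w = scale * q + min, with min = -(scale * zero).
float dequant_int4(int q, float scale, int zero) {
  float min = -(scale * static_cast<float>(zero));
  return scale * static_cast<float>(q) + min;  // == scale * (q - zero)
}
```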
  void warm_up() {
    int qlen = config_.max_len;
    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
    std::vector<float> weights(qlen * config_.num_experts_per_tok);
    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
      expert_ids[i] = i % config_.expert_num;
      weights[i] = 0.01;
    }
    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
  }

  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    if (qlen > 1) {
      forward_prefill(qlen, k, expert_ids, weights, input, output);
    } else {
      forward_decode(k, expert_ids, weights, input, output);
    }
  }

#define DIRECT_OR_POOL_BY_QLEN(var, fn)                            \
  do {                                                             \
    if (qlen < 10) {                                               \
      for (int i = 0; i < (var); i++) {                            \
        (fn)(i);                                                   \
      }                                                            \
    } else {                                                       \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr);   \
    }                                                              \
  } while (0)

#define MATMUL_OR_VECMUL_KGROUP_BY_QLEN(...)                            \
  do {                                                                  \
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {  \
      amx::mat_mul_kgroup(__VA_ARGS__);                                 \
    } else {                                                            \
      amx::vec_mul_kgroup(__VA_ARGS__);                                 \
    }                                                                   \
  } while (0)

  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
                       void* output) {
    auto pool = config_.pool->get_subpool(tp_part_idx);
    auto& quant_config = config_.quant_config;
    int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
    auto start_time = std::chrono::high_resolution_clock::now();
    auto last = start_time;
    // holds per-stage timings (in microseconds)
    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
    int max_local_num = 0;  // record the largest local num
#endif

    int activated_expert = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_num_[i] = 0;
    }
    for (int i = 0; i < qlen; i++) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
      }
    }

    for (int i = 0; i < config_.expert_num; i++) {
      if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
        max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
        m_expert_id_map_[activated_expert] = i;
        activated_expert++;
      }
    }

    // activated_expert has now been counted

    size_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
      offset += m_local_num_[i];
    }
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
      }
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
      int expert_idx = m_expert_id_map_[task_id];
      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int _) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int& group_size = config_.quant_config.group_size;
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];

          int ith = task_id % nth;
          if (do_up) {
            MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                            group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
                                            ith, nth);
            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                            group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
                                            gate_bc_[expert_idx], ith, nth);
            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    auto up_gate_fn = [this, nth](int task_id) {
      int expert_idx = m_expert_id_map_[task_id / nth];
      int ith = task_id % nth;
      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        for (int j = n_start; j < n_end; j += 32) {
          __m512 gate_val0, gate_val1, up_val0, up_val1;
          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
          __m512 result0 = amx::act_fn(gate_val0, up_val0);
          __m512 result1 = amx::act_fn(gate_val1, up_val1);
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
        }
      }
    };
    DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int _) { T::config(); },
        [this, nth, qlen](int task_id) {
          int& group_size = config_.quant_config.group_size;
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
                                          group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
                                          ith, nth);
          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        qlen, nullptr,
        [this, nth, output, k, expert_ids, weights](int i) {
          for (int e = 0; e < config_.hidden_size; e += 32) {
            __m512 x0 = _mm512_setzero_ps();
            __m512 x1 = _mm512_setzero_ps();
            for (int j = 0; j < k; j++) {
              if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
                continue;
              }
              __m512 weight = _mm512_set1_ps(weights[i * k + j]);
              __m512 down_output0, down_output1;
              avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                   m_local_pos_[i][j] * config_.hidden_size + e),
                                        &down_output0, &down_output1);
              x0 = _mm512_fmadd_ps(down_output0, weight, x0);
              x1 = _mm512_fmadd_ps(down_output1, weight, x1);
            }
            auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
            f32out[0] = x0;
            f32out[1] = x1;
          }
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    // print all stage timings once at the end of the function, along with max_local_num and qlen
    printf(
        "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
        "%d, qlen: %d\n",
        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
        down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
  }

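`forward_prefill` ends by combining each token's expert outputs with its routing weights via AVX-512 FMAs. A scalar equivalent of that combine step (names illustrative; the real loop also skips GPU-resident experts and reads BF16 values):

```cpp
// Scalar sketch of the weighted expert combine: out[e] = sum_j w[j] * y_j[e]
// over the k experts selected for one token.
void combine_expert_outputs(const float* const* expert_out, const float* weights, int k, int hidden_size,
                            float* out) {
  for (int e = 0; e < hidden_size; ++e) {
    float acc = 0.0f;
    for (int j = 0; j < k; ++j) {
      acc += weights[j] * expert_out[j][e];
    }
    out[e] = acc;
  }
}
```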
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
|
||||
int qlen = 1;
|
||||
auto pool = config_.pool->get_subpool(tp_part_idx);
|
||||
auto& quant_config = config_.quant_config;
|
||||
int& group_size = quant_config.group_size;
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
auto last = start_time;
|
||||
// 用于保存各阶段耗时(单位:微秒)
|
||||
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
|
||||
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
|
||||
int max_local_num = 0; // 记录最大的 local num
|
||||
#endif
|
||||
|
||||
int activated_expert = 0;
|
||||
for (int i = 0; i < k; i++) {
|
||||
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
|
||||
continue;
|
||||
}
|
||||
m_expert_id_map_[activated_expert] = expert_ids[i];
|
||||
activated_expert++;
|
||||
}
|
||||
|
||||
size_t offset = 0;
|
||||
for (int i = 0; i < activated_expert; i++) {
|
||||
auto expert_idx = m_expert_id_map_[i];
|
||||
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
|
||||
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
|
||||
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
|
||||
offset += qlen;
|
||||
}
|
||||
|
||||
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
int nth = T::recommended_nth(config_.intermediate_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * activated_expert * 2, [](int _) { T::config(); },
|
||||
[this, nth, qlen](int task_id2) {
|
||||
int& group_size = config_.quant_config.group_size;
|
||||
int task_id = task_id2 / 2;
|
||||
bool do_up = task_id2 % 2;
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
|
||||
int ith = task_id % nth;
|
||||
if (do_up) {
|
||||
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
|
||||
up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
|
||||
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
|
||||
} else {
|
||||
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
|
||||
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
|
||||
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
int ith = task_id % nth;
|
||||
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
|
||||
for (int i = 0; i < qlen; i++) {
|
||||
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
|
||||
for (int j = n_start; j < n_end; j += 32) {
|
||||
__m512 gate_val0, gate_val1, up_val0, up_val1;
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
|
||||
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
|
||||
__m512 result0 = amx::act_fn(gate_val0, up_val0);
|
||||
__m512 result1 = amx::act_fn(gate_val1, up_val1);
|
||||
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
pool->do_work_stealing_job(
|
||||
activated_expert, nullptr,
|
||||
[this, qlen](int task_id) {
|
||||
int expert_idx = m_expert_id_map_[task_id];
|
||||
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
|
||||
},
|
||||
nullptr);
|
||||
#ifdef FORWARD_TIME_PROFILE
|
||||
{
|
||||
auto now_time = std::chrono::high_resolution_clock::now();
|
||||
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
|
||||
last = now_time;
|
||||
}
|
||||
#endif
|
||||
|
||||
nth = T::recommended_nth(config_.hidden_size);
|
||||
pool->do_work_stealing_job(
|
||||
nth * activated_expert, [](int _) { T::config(); },
|
||||
[this, nth, qlen](int task_id) {
|
||||
int& group_size = config_.quant_config.group_size;
|
||||
int expert_idx = m_expert_id_map_[task_id / nth];
|
||||
int ith = task_id % nth;
|
||||
amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
|
||||
                             down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
          down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif
    for (int i = 0; i < qlen; i++) {
      for (int e = 0; e < config_.hidden_size; e += 32) {
        __m512 x0 = _mm512_setzero_ps();
        __m512 x1 = _mm512_setzero_ps();
        for (int j = 0; j < k; j++) {
          if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
            continue;
          }
          __m512 weight = _mm512_set1_ps(weights[i * k + j]);
          __m512 down_output0, down_output1;
          avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                               m_local_pos_[i][j] * config_.hidden_size + e),
                                    &down_output0, &down_output1);
          x0 = _mm512_fmadd_ps(down_output0, weight, x0);
          x1 = _mm512_fmadd_ps(down_output1, weight, x1);
        }
        auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
        f32out[0] = x0;
        f32out[1] = x1;
      }
    }

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    // Print all stage timings once at the end of the function, along with the activated expert count
    printf(
        "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
        tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
        forward_total_time);
#endif
  }
  // forward, forward_prefill, forward_decode, warm_up are inherited from Base
};

// ============================================================================
// TP_MOE specialization for AMX_AWQ_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse the merge_results implementation
// ============================================================================

template <typename K>
-class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
+class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>> {
 public:
-  using TP_MOE_Common<AMX_AWQ_MOE_TP<K>>::TP_MOE_Common;
-  void load_weights() {
+  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>;
+  using Base::Base;
+
+  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
@@ -1157,7 +696,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
               ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
               ((sizeof(uint8_t) * weight_elem_count) >> 1));

-      // zeros TP-slicing
+      // down scales and zeros TP-slicing
       memcpy((ggml_fp16_t*)tpc.down_scale + (expert_id * scales_elem_count),
              (ggml_fp16_t*)config.down_scale +
                  (expert_id * (config.intermediate_size / group_size) * config.hidden_size +
@@ -1172,7 +711,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
              (sizeof(uint8_t) * scales_elem_count) >> 1);

       for (size_t kg = 0; kg < config.hidden_size / group_size; kg++) {
-        // copy scale
+        // copy gate/up scales
         memcpy((ggml_fp16_t*)tpc.gate_scale + (expert_id * scales_elem_count) + kg * tpc.intermediate_size,
                (ggml_fp16_t*)config.gate_scale +
                    (expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +
@@ -1185,7 +724,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
                 kg * config.intermediate_size + i * tpc.intermediate_size),
                (sizeof(ggml_fp16_t) * tpc.intermediate_size));

-        // zeros TP-slicing
+        // copy gate/up zeros TP-slicing
         memcpy(
             (uint8_t*)tpc.gate_zero + (((expert_id * scales_elem_count) + kg * tpc.intermediate_size) >> 1),
             (uint8_t*)config.gate_zero +
@@ -1202,6 +741,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
             ((sizeof(uint8_t) * tpc.intermediate_size) >> 1));
       }

+      // down weights TP-slicing (column-wise)
       for (size_t col = 0; col < config.hidden_size; col++) {
         memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
                (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
@@ -1285,37 +825,7 @@ class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE_Common<AMX_AWQ_MOE_TP<K>> {
   }
  }

-  void merge_results(int qlen, void* output, bool incremental) {
-    auto pool = this->config.pool;
-    auto merge_fn = [this, output, incremental](int token_nth) {
-      auto& local_output_numa = this->local_output_numa;
-      auto& tp_configs = this->tp_configs;
-      auto& tp_count = this->tp_count;
-      auto& config = this->config;
-      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
-      if (incremental) {
-        for (int e = 0; e < config.hidden_size; e += 32) {
-          __m512 x0, x1;
-          avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
-          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
-          *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
-        }
-      }
-      for (int i = 1; i < tp_count; i++) {
-        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
-        for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
-          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
-        }
-      }
-      for (int e = 0; e < config.hidden_size; e += 32) {
-        __m512 x0 = *(__m512*)(merge_to + e);
-        __m512 x1 = *(__m512*)(merge_to + e + 16);
-        avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
-      }
-    };
-    DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn);
-  }
-  void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
+  // merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>
};

#endif

782 kt-kernel/operators/amx/fp8-moe.hpp Normal file
@@ -0,0 +1,782 @@
/**
 * @Description : FP8 AMX MoE operator for DeepSeek V3.2 native inference
 * @Author      : oql, Codex and Claude
 * @Date        : 2025-12-09
 * @Version     : 1.0.0
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 *
 * This file implements FP8 MoE using the CRTP pattern, inheriting from moe_base.hpp.
 * FP8 weights are stored with 128x128 block-wise scales.
 **/
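The 128x128 block-wise scale layout described in the header comment can be made concrete with a scalar sketch. This assumes the OCP FP8 E4M3 encoding (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, subnormals, no infinities); `fp8_e4m3_to_float` and `dequant` are illustrative helpers, not part of this header:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>

// Decode one FP8 E4M3 byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
// Note: in the E4M3FN variant used for ML weights, 0x7F/0xFF encode NaN and the
// largest finite magnitude is 448.
static float fp8_e4m3_to_float(uint8_t v) {
  int sign = (v >> 7) & 1;
  int exp = (v >> 3) & 0xF;
  int mant = v & 0x7;
  float mag;
  if (exp == 0) {
    mag = std::ldexp((float)mant / 8.0f, -6);  // subnormal
  } else {
    mag = std::ldexp(1.0f + (float)mant / 8.0f, exp - 7);
  }
  return sign ? -mag : mag;
}

// Dequantize weight[n][k] of an n x K matrix with a block-wise scale table:
// one float scale per 128x128 block, stored row-major over blocks.
static float dequant(const uint8_t* w, const float* scale, int K, int n, int k) {
  int blocks_k = (K + 127) / 128;
  float s = scale[(n / 128) * blocks_k + (k / 128)];
  return fp8_e4m3_to_float(w[(size_t)n * K + k]) * s;
}
```

The real kernels never materialize floats this way; the scales are folded into the grouped GEMM instead.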
#ifndef CPUINFER_OPERATOR_AMX_FP8_MOE_H
#define CPUINFER_OPERATOR_AMX_FP8_MOE_H

// #define DEBUG_FP8_MOE

#include <immintrin.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "la/amx_raw_buffers.hpp"
#include "la/amx_raw_kernels.hpp"
#include "moe_base.hpp"

/**
 * @brief FP8 MoE operator using CRTP pattern
 * @tparam T Kernel type, defaults to GemmKernel224FP8
 *
 * This class provides FP8-specific implementations:
 * - do_gate_up_gemm, do_down_gemm: FP8 weight -> BF16 conversion matmul
 * - load_weights: load FP8 weights with 128x128 block scales
 */
template <class T = amx::GemmKernel224FP8>
class AMX_FP8_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>> {
  using Base = AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>>;
  using Base::config_;
  using Base::down_ba_;
  using Base::down_bb_;
  using Base::down_bc_;
  using Base::gate_bb_;
  using Base::gate_bc_;
  using Base::gate_up_ba_;
  using Base::m_local_num_;
  using Base::tp_part_idx;
  using Base::up_bb_;
  using Base::up_bc_;

 public:
  using typename Base::input_t;
  using typename Base::output_t;

  AMX_FP8_MOE_TP() = default;

  AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
    auto& quant_config = config_.quant_config;
    if (quant_config.group_size == 0 || quant_config.zero_point) {
      // std::runtime_error has no printf-style constructor; format the message explicitly.
      throw std::runtime_error("KT-Kernel fp8 MoE only supports block-wise FP8. group_size = " +
                               std::to_string(quant_config.group_size) +
                               ", zero_point = " + std::to_string(quant_config.zero_point));
    }
    printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
  }

  ~AMX_FP8_MOE_TP() = default;

  // ============================================================================
  // CRTP buffer creation - with group_size
  // ============================================================================

  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k, config_.quant_config.group_size);
  }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // ============================================================================
  // CRTP virtual points - GEMM dispatch
  // ============================================================================

  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    auto& group_size = config_.quant_config.group_size;
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];

    amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
  }
  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    auto& group_size = config_.quant_config.group_size;
    int m = m_local_num_[expert_idx];

    amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
                        down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
  }

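`amx::vec_mul_kgroup` itself lives in la/amx_raw_kernels.hpp and is not shown here. As a rough scalar sketch of what a k-grouped contraction looks like, assuming a per-group scale applied to each group's partial sum (the real kernel is AMX-tiled and operates on packed FP8/BF16 buffers, so this is illustrative only):

```cpp
#include <algorithm>
#include <cstddef>

// Scalar sketch of a k-grouped scaled dot product: the K dimension is split
// into groups of group_size, and each group's partial sum is weighted by a
// per-group scale before accumulation, as block-wise FP8 dequantization needs.
float dot_kgroup(const float* a, const float* b, const float* group_scale, size_t K, size_t group_size) {
  float acc = 0.0f;
  for (size_t g = 0; g * group_size < K; g++) {
    float partial = 0.0f;
    size_t end = std::min(K, (g + 1) * group_size);
    for (size_t k = g * group_size; k < end; k++) partial += a[k] * b[k];
    acc += group_scale[g] * partial;
  }
  return acc;
}
```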
#ifdef DEBUG_FP8_MOE
  // Function to dump Buffer B data for debugging FP8 quantization results
  inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type,
                            typename T::BufferB* buffer) {
    auto& quant_config = config_.quant_config;
    int& group_size = quant_config.group_size;

    printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx,
           matrix_type.c_str());

    // Calculate dimensions based on matrix type
    int rows, cols;
    size_t scale_elem_count;
    if (matrix_type == "gate" || matrix_type == "up") {
      rows = config_.intermediate_size;
      cols = config_.hidden_size;
    } else {  // down
      rows = config_.hidden_size;
      cols = config_.intermediate_size;
    }
    int n_blocks_n = (rows + group_size - 1) / group_size;
    int n_blocks_k = (cols + group_size - 1) / group_size;
    scale_elem_count = n_blocks_n * n_blocks_k;

    // Dump scales (as BF16 converted to float)
    printf("  Scales[first 16]: ");
    for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {
      printf("%.6f ", buffer->d[i]);
    }
    printf("\n");

    if (scale_elem_count > 16) {
      printf("  Scales[last 16]: ");
      int start_idx = std::max(0, (int)scale_elem_count - 16);
      for (int i = start_idx; i < (int)scale_elem_count; i++) {
        printf("%.6f ", buffer->d[i]);
      }
      printf("\n");
    }

    // Dump FP8 weights (as hex uint8)
    size_t weight_size = (size_t)rows * cols;  // FP8 is 1 byte per element
    uint8_t* weight_ptr = (uint8_t*)buffer->b;

    printf("  FP8 Weights[first 32 bytes]: ");
    for (int i = 0; i < std::min(32, (int)weight_size); i++) {
      printf("%02x ", weight_ptr[i]);
    }
    printf("\n");

    if (weight_size > 32) {
      printf("  FP8 Weights[last 32 bytes]: ");
      int start_idx = std::max(32, (int)weight_size - 32);
      for (int i = start_idx; i < (int)weight_size; i++) {
        printf("%02x ", weight_ptr[i]);
      }
      printf("\n");
    }

    printf("  Matrix dimensions: %dx%d (n x k), Scale blocks: %dx%d, Group size: %d, Scale elements: %zu\n", rows,
           cols, n_blocks_n, n_blocks_k, group_size, scale_elem_count);
  }
#endif

  /**
   * @brief Load FP8 weights from contiguous memory layout
   *
   * Loads weights from config_.gate_proj, up_proj, down_proj with scales
   * from config_.gate_scale, up_scale, down_scale.
   */
  void load_weights() {
    auto& quant_config = config_.quant_config;
    int& group_size = quant_config.group_size;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
    auto pool = config_.pool->get_subpool(tp_part_idx);

    if (config_.gate_scale == nullptr) {
      throw std::runtime_error("FP8 AMX MoE only supports native (pre-quantized FP8) weights.");
    }

    // load weights
    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map, group_size](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;
          // gate part
          gate_bb_[expert_idx]->from_mat(
              (uint8_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              (float*)config_.gate_scale +
                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
              ith, nth);
          // up part
          up_bb_[expert_idx]->from_mat(
              (uint8_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              (float*)config_.up_scale +
                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
              ith, nth);
        },
        nullptr);

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * config_.expert_num, nullptr,
        [this, nth, physical_to_logical_map, group_size](int task_id) {
          uint64_t expert_idx = task_id / nth;
          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
          int ith = task_id % nth;
          // down part
          down_bb_[expert_idx]->from_mat(
              (uint8_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
              (float*)config_.down_scale +
                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
              ith, nth);
        },
        nullptr);
#ifdef DEBUG_FP8_MOE
    dump_buffer_b("Native FP8", 0, "gate", gate_bb_[0].get());
    dump_buffer_b("Native FP8", 0, "down", down_bb_[0].get());
#endif
  }

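The per-expert offsets used in `load_weights` reduce to simple arithmetic: FP8 weights are one byte per element, and scales are one float per group_size x group_size block. A minimal sketch with illustrative dimensions (helper names here are not from this header):

```cpp
#include <cstddef>

// Byte offset of expert `e` in a contiguous FP8 weight tensor of e x rows x cols
// (FP8 = 1 byte per element).
constexpr size_t fp8_weight_offset(size_t expert, size_t rows, size_t cols) {
  return expert * rows * cols;
}

// Element offset of expert `e` in the scale tensor: one float per
// (group x group) block of the rows x cols matrix.
constexpr size_t fp8_scale_offset(size_t expert, size_t rows, size_t cols, size_t group) {
  return expert * (rows / group) * (cols / group);
}
```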
  // Fast 64-byte (512-bit) memcpy using AVX512
  static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {
    __m512i data = _mm512_loadu_si512(src);
    _mm512_storeu_si512(dst, data);
  }

  // Fast memcpy for arbitrary sizes using AVX512
  static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
    uint8_t* d = (uint8_t*)dst;
    const uint8_t* s = (const uint8_t*)src;
    size_t chunks = bytes / 64;
    for (size_t i = 0; i < chunks; i++) {
      fast_memcpy_64(d, s);
      d += 64;
      s += 64;
    }
    bytes -= chunks * 64;
    if (bytes > 0) {
      std::memcpy(d, s, bytes);
    }
  }

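The chunk-plus-tail structure of `fast_memcpy` can be exercised portably by substituting `std::memcpy` for the AVX512 load/store pair; this sketch (name illustrative) behaves identically to a plain `memcpy`:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Portable stand-in for the AVX512 path: copy in 64-byte chunks, then the tail.
static void chunked_memcpy(void* dst, const void* src, size_t bytes) {
  uint8_t* d = (uint8_t*)dst;
  const uint8_t* s = (const uint8_t*)src;
  size_t chunks = bytes / 64;
  for (size_t i = 0; i < chunks; i++, d += 64, s += 64) {
    std::memcpy(d, s, 64);  // the real code uses _mm512_loadu_si512 / _mm512_storeu_si512 here
  }
  std::memcpy(d, s, bytes - chunks * 64);  // scalar tail (< 64 bytes)
}
```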
  /**
   * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format
   *
   * This is the inverse of the packing done in BufferBFP8Impl::from_mat.
   * Optimized with AVX512 gather for efficient non-contiguous reads.
   *
   * @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout)
   * @param dst Pointer to destination in n-major layout
   * @param dst_row_stride Row stride in destination buffer (number of columns in full matrix)
   */
  static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) {
    // row_map[packed_i] gives the base row for packed index packed_i
    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
    const uint64_t* src64 = reinterpret_cast<const uint64_t*>(src);

    // Gather indices: src64[8*j + packed_i] for j = 0..7
    // Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group)
    const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0);

    // Process each packed group (8 groups of 4 rows each = 32 rows total)
    for (int packed_i = 0; packed_i < 8; packed_i++) {
      const int base_row = row_map[packed_i];
      const uint64_t* base_src = src64 + packed_i;

      // Gather 8 values for j=0..7 and j=8..15
      __m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8);
      __m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8);

      // Extract 4 rows from each set of 8 values
      // Row 0: bits 0-15
      __m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF)));
      __m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF)));
      // Row 1: bits 16-31
      __m128i row1_lo =
          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF)));
      __m128i row1_hi =
          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF)));
      // Row 2: bits 32-47
      __m128i row2_lo =
          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF)));
      __m128i row2_hi =
          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF)));
      // Row 3: bits 48-63
      __m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48));
      __m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48));

      // Store 32 bytes (16 x uint16) to each row
      uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride;
      uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride;
      uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride;
      uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride;

      // Combine lo and hi halves into 256-bit registers for more efficient stores
      __m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo);
      __m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo);
      __m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo);
      __m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo);

      _mm256_storeu_si256((__m256i*)row0_dst, row0_256);
      _mm256_storeu_si256((__m256i*)row1_dst, row1_256);
      _mm256_storeu_si256((__m256i*)row2_dst, row2_256);
      _mm256_storeu_si256((__m256i*)row3_dst, row3_256);
    }
  }

  /**
   * @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization
   *
   * Processing 4 blocks together means each row write is 128 bytes = 2 cache lines,
   * which greatly improves write efficiency compared to 32 bytes per row.
   *
   * @param src Array of 4 source pointers (each pointing to a 32x32 packed block)
   * @param dst Destination pointer in n-major layout
   * @param dst_row_stride Row stride in destination buffer
   */
  static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {
    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
    constexpr int K_STEP = T::K_STEP;  // 32

    // Reinterpret as uint64 arrays for efficient access
    const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);
    const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);
    const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);
    const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);

    // Process all 32 rows, writing 128 bytes (4 x 32) per row
    for (int packed_i = 0; packed_i < 8; packed_i++) {
      const int base_row = row_map[packed_i];

      // Process 4 rows at a time
      for (int r = 0; r < 4; r++) {
        uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);
        const int shift = r * 16;

        // Unroll: process all 4 blocks x 16 columns = 64 uint16 values
        // Block 0: columns 0-15
        for (int j = 0; j < 16; j++) {
          row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);
        }
        // Block 1: columns 16-31
        for (int j = 0; j < 16; j++) {
          row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);
        }
        // Block 2: columns 32-47
        for (int j = 0; j < 16; j++) {
          row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);
        }
        // Block 3: columns 48-63
        for (int j = 0; j < 16; j++) {
          row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);
        }
      }
    }
  }

  /**
   * @brief Reconstruct weights for a single expert to the output buffers (no temp buffer version)
   *
   * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.
   * Optimized version with coarse-grained task splitting for better cache utilization.
   *
   * Key optimizations:
   * - Reduced task count (~40 vs ~350) to minimize scheduling overhead
   * - Larger chunks per task for better cache line utilization
   * - Process multiple N_STEPs per task for better write locality
   *
   * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)
   * @param cpu_tp_count Number of CPU TP parts
   * @param expert_id Expert index to process
   * @param full_config Full configuration (before CPU TP split)
   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)
   * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)
   * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)
   * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)
   */
  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
                               const std::vector<uintptr_t>& w13_scale_ptrs,
                               const std::vector<uintptr_t>& w2_weight_ptrs,
                               const std::vector<uintptr_t>& w2_scale_ptrs) const {
    auto& config = config_;
    const int group_size = config.quant_config.group_size;
    auto pool = config.pool->get_subpool(tp_part_idx);

    constexpr int N_STEP = T::N_STEP;
    constexpr int K_STEP = T::K_STEP;
    constexpr int N_BLOCK = T::N_BLOCK;
    constexpr int K_BLOCK = T::K_BLOCK;

    // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========
    const int cpu_n_w13 = config.intermediate_size;
    const int cpu_k_w13 = config.hidden_size;
    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
    const int gpu_k_w13 = full_config.hidden_size;
    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;

    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
    const size_t gpu_w13_scale_per_mat = (size_t)div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size);
    const int cpu_scale_k_blocks_w13 = div_up(cpu_k_w13, group_size);
    const int gpu_scale_k_blocks_w13 = div_up(gpu_k_w13, group_size);

    // ========= W2 (down): Shape [hidden, intermediate], split by K =========
    const int cpu_n_w2 = config.hidden_size;
    const int cpu_k_w2 = config.intermediate_size;
    const int gpu_n_w2 = full_config.hidden_size;
    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;

    const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2;
    const size_t gpu_w2_scale_per_mat = (size_t)div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size);
    const int cpu_scale_k_blocks_w2 = div_up(cpu_k_w2, group_size);
    const int gpu_scale_k_blocks_w2 = div_up(gpu_k_w2, group_size);

    // ========= Scale dimensions =========
    const int cpu_scale_n_blocks_w13 = div_up(cpu_n_w13, group_size);
    const int gpu_scale_n_blocks_w13 = div_up(gpu_n_w13, group_size);
    const int cpu_scale_n_blocks_w2 = div_up(cpu_n_w2, group_size);

    // ========= Optimized job layout =========
    // Use a task count slightly above the CPU core count for good work stealing.
    // For an 80-core system, ~100 tasks provides a good balance.
    constexpr int NUM_W13_TASKS = 32;  // Per matrix (gate or up), total 64 for w13
    constexpr int NUM_W2_TASKS = 32;   // For the down matrix
    constexpr int SCALE_TASKS = 3;     // gate_scale, up_scale, down_scale

    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS;

    // Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing)
    const int w13_n_steps = div_up(cpu_n_w13, N_STEP);
    const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);
    const int w2_n_steps = div_up(cpu_n_w2, N_STEP);
    const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);

    pool->do_work_stealing_job(
        total_tasks, nullptr,
        [=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {
          if (task_id < NUM_W13_TASKS * 2) {
            // ========= W13 weight task: process chunk of rows x full K =========
            const bool is_up = task_id >= NUM_W13_TASKS;
            const int chunk_idx = task_id % NUM_W13_TASKS;
            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];

            // Calculate row range for this task (N_STEP aligned)
            const int step_start = chunk_idx * w13_steps_per_task;
            const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);
            if (step_start >= w13_n_steps) return;
            const int chunk_n_start = step_start * N_STEP;
            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);

            // Process each N_STEP within this chunk
            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
              // Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries)
              const int global_n = global_n_offset_w13 + local_n_start;
              const int target_gpu = global_n / gpu_n_w13;
              const int n_in_gpu = global_n % gpu_n_w13;

              uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu];
              // Pointer already points to the current expert's location; only add an offset for the up matrix
              const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;

              // Calculate N_BLOCK info for source addressing
              const int n_block_idx = local_n_start / N_BLOCK;
              const int n_block_begin = n_block_idx * N_BLOCK;
              const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);
              const int n_in_block = local_n_start - n_block_begin;

              // Process all K in groups of 4 K_STEPs when possible for cache efficiency
              for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {
                const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);

                // Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row)
                int k_begin = 0;
                for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) {
                  const uint8_t* src_ptrs[4];
                  for (int i = 0; i < 4; i++) {
                    src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size +
                                  (size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP;
                  }
                  uint8_t* dst =
                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
                  unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13);
                }

                // Handle remaining K_STEPs one by one
                for (; k_begin < k_block_size; k_begin += K_STEP) {
                  const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +
                                       (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
                                       (size_t)k_begin * N_STEP;
                  uint8_t* dst =
                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
                  unpack_nk_block(src, dst, gpu_k_w13);
                }
              }
            }

          } else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) {
            // ========= W2 weight task: process chunk of rows x all K slices =========
            const int chunk_idx = task_id - NUM_W13_TASKS * 2;
            const auto& bb = down_bb_[expert_id];

            // Calculate row range for this task (N_STEP aligned)
            const int step_start = chunk_idx * w2_steps_per_task;
            const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);
            if (step_start >= w2_n_steps) return;
            const int chunk_n_start = step_start * N_STEP;
            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);

            // Process each N_STEP within this chunk
            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
              // Calculate N_BLOCK info for source addressing
              const int n_block_idx = local_n_start / N_BLOCK;
              const int n_block_begin = n_block_idx * N_BLOCK;
              const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);
              const int n_in_block = local_n_start - n_block_begin;

              // Process all K slices (each slice goes to a different GPU TP)
              for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {
                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);

                const int global_k_start = global_k_offset_w2 + k_slice_start;
                const int target_gpu = global_k_start / gpu_k_w2;
                const int k_in_gpu_base = global_k_start % gpu_k_w2;

                uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu];
                // Pointer already points to the current expert's location
                const size_t expert_weight_off = 0;

                // Process K within this slice, trying 4 K_STEPs at once when aligned
                for (int k_abs = k_slice_start; k_abs < k_slice_end;) {
                  const int k_block_idx = k_abs / K_BLOCK;
                  const int k_block_begin = k_block_idx * K_BLOCK;
                  const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);
                  const int k_in_block = k_abs - k_block_begin;
                  const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);

                  // Check if we can process 4 K_STEPs at once
                  const int remaining_in_block = k_block_size - k_in_block;
                  const int remaining_in_slice = k_slice_end - k_abs;

                  if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) {
                    const uint8_t* src_ptrs[4];
                    for (int i = 0; i < 4; i++) {
                      src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size +
                                    (size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP;
                    }
                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
                    unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2);
                    k_abs += 4 * K_STEP;
                  } else {
                    const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +
                                         (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
                                         (size_t)k_in_block * N_STEP;
                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
                    unpack_nk_block(src, dst, gpu_k_w2);
                    k_abs += K_STEP;
                  }
                }
              }
            }

          } else {
            // ========= Scale copy task: simple linear copy with fast_memcpy =========
            const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS;

            if (scale_task_id < 2) {
              // Gate (0) or Up (1) scale copy
              const bool is_up = scale_task_id == 1;
              const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];

              // W13 scales: copy the N blocks corresponding to this CPU TP
              // Note: when gpu_tp > cpu_tp, scale blocks may span multiple GPU TPs
              const int bn_start_global = global_n_offset_w13 / group_size;

              for (int bn = 0; bn < cpu_scale_n_blocks_w13; bn++) {
                const int global_bn = bn_start_global + bn;
                const int target_gpu = global_bn / gpu_scale_n_blocks_w13;
                const int gpu_bn = global_bn % gpu_scale_n_blocks_w13;

                float* scale_dst = (float*)w13_scale_ptrs[target_gpu];
                // Pointer already points to the current expert's location; only add an offset for the up matrix
                const size_t expert_scale_off = is_up ? gpu_w13_scale_per_mat : 0;

                fast_memcpy(scale_dst + expert_scale_off + (size_t)gpu_bn * gpu_scale_k_blocks_w13,
                            bb->d + (size_t)bn * cpu_scale_k_blocks_w13, cpu_scale_k_blocks_w13 * sizeof(float));
              }
            } else {
              // Down scale copy (scale_task_id == 2)
              const auto& bb = down_bb_[expert_id];

              // W2 scales: K dimension is split, copy to each GPU TP
              for (int k_slice_idx = 0; k_slice_idx < div_up(cpu_k_w2, gpu_k_w2); k_slice_idx++) {
|
||||
const int k_slice_start = k_slice_idx * gpu_k_w2;
|
||||
const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);
|
||||
|
||||
const int global_k_start = global_k_offset_w2 + k_slice_start;
|
||||
const int target_gpu = global_k_start / gpu_k_w2;
|
||||
const int bk_gpu_base = (global_k_start % gpu_k_w2) / group_size;
|
||||
|
||||
float* scale_dst = (float*)w2_scale_ptrs[target_gpu];
|
||||
// Pointer already points to current expert's location
|
||||
const size_t expert_scale_off = 0;
|
||||
|
||||
const int bk_start = k_slice_start / group_size;
|
||||
const int bk_end = div_up(k_slice_end, group_size);
|
||||
const int bk_count = bk_end - bk_start;
|
||||
|
||||
for (int bn = 0; bn < cpu_scale_n_blocks_w2; bn++) {
|
||||
fast_memcpy(scale_dst + expert_scale_off + (size_t)bn * gpu_scale_k_blocks_w2 + bk_gpu_base,
|
||||
bb->d + (size_t)bn * cpu_scale_k_blocks_w2 + bk_start, bk_count * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename K>
class TP_MOE<AMX_FP8_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>> {
 public:
  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>>;
  using Base::Base;

  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
    auto pool = config.pool;
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;

    const int group_size = config.quant_config.group_size;
    if (group_size == 0 || config.quant_config.zero_point) {
      throw std::runtime_error("FP8 MoE requires a nonzero group_size and zero_point == false");
    }

    if (config.gate_projs.empty() && config.gate_proj == nullptr) {
      throw std::runtime_error("no weight source");
    }
    const bool use_per_expert_ptrs = !config.gate_projs.empty();

    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
    const size_t full_scale_elems =
        (size_t)div_up(config.hidden_size, group_size) * div_up(config.intermediate_size, group_size);

    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
      const size_t tp_scale_elems =
          (size_t)div_up(tpc.intermediate_size, group_size) * div_up(tpc.hidden_size, group_size);

      tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
      tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
      tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];

      tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems];
      tpc.up_scale = new float[tpc.expert_num * tp_scale_elems];
      tpc.down_scale = new float[tpc.expert_num * tp_scale_elems];

      const size_t tp_idx = (size_t)i;
      const size_t gate_up_weight_src_offset = i * tp_weight_elems;
      const size_t gate_up_scale_src_offset = i * tp_scale_elems;

      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
      const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size;

      pool->get_subpool(i)->do_work_stealing_job(
          tpc.expert_num, nullptr,
          [&, &tpc](int expert_id_) {
            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);

            uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;
            uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;
            uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;

            float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems;
            float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems;
            float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems;

            const uint8_t* gate_src;
            const uint8_t* up_src;
            const uint8_t* down_src;
            const float* gate_scale_src;
            const float* up_scale_src;
            const float* down_scale_src;

            if (use_per_expert_ptrs) {
              gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
              up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
              down_src = (const uint8_t*)config.down_projs[0][expert_id];

              gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;
              up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;
              down_scale_src = (const float*)config.down_scales[0][expert_id];
            } else {
              gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
              down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;

              gate_scale_src =
                  (const float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
              up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
              down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems;
            }

            std::memcpy(gate_dst, gate_src, tp_weight_elems);
            std::memcpy(up_dst, up_src, tp_weight_elems);
            std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems);
            std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems);

            for (int row = 0; row < config.hidden_size; row++) {
              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);
            }
            const int n_blocks_n = div_up(config.hidden_size, group_size);
            const int full_n_blocks_k = div_up(config.intermediate_size, group_size);
            const int tp_n_blocks_k = div_up(tpc.intermediate_size, group_size);
            for (int bn = 0; bn < n_blocks_n; bn++) {
              const float* src = down_scale_src + (size_t)bn * (size_t)full_n_blocks_k + down_scale_src_block_k_offset;
              float* dst = down_scale_dst + (size_t)bn * (size_t)tp_n_blocks_k;
              std::memcpy(dst, src, sizeof(float) * (size_t)tp_n_blocks_k);
            }
          },
          nullptr);
    });

    DO_TPS_LOAD_WEIGHTS(pool);

    pool->dispense_backend()->do_numa_job([&, this](int i) {
      auto& tpc = tps[i]->config_;
      delete[] (uint8_t*)tpc.gate_proj;
      delete[] (uint8_t*)tpc.up_proj;
      delete[] (uint8_t*)tpc.down_proj;
      delete[] (float*)tpc.gate_scale;
      delete[] (float*)tpc.up_scale;
      delete[] (float*)tpc.down_scale;
    });

    this->weights_loaded = true;
  }

  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    if (this->weights_loaded == false) {
      throw std::runtime_error("Not Loaded");
    }
    if (this->tps.empty()) {
      throw std::runtime_error("No TP parts initialized");
    }
    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count ||
        (int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) {
      throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
    }

    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
    });
  }
};

#endif  // CPUINFER_OPERATOR_AMX_FP8_MOE_H
File diff suppressed because it is too large
@@ -46,6 +46,9 @@ static inline __m512 exp_avx512(__m512 x) {

static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
  __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
  // Clamp neg_gate_val to avoid exp overflow (expf overflows float32 for inputs above ~88.72)
  const __m512 max_exp_input = _mm512_set1_ps(88.0f);
  neg_gate_val = _mm512_min_ps(neg_gate_val, max_exp_input);
  __m512 exp_neg_gate = exp_avx512(neg_gate_val);
  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
  __m512 act_val = _mm512_div_ps(gate_val, denom);
@@ -762,6 +762,16 @@ struct GemmKernel224BF {
struct BufferC {
  float* c;
  int max_m, n;
  // Physical layout (in float elements):
  // The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,
  // and n is tiled by N_BLOCK.
  // Storage order:
  // n_block (N_BLOCK cols) -> m_block (M_STEP rows) -> n_step (N_STEP cols) -> row-major (M_STEP×N_STEP) tile.
  // It can therefore be viewed as a 5D array:
  // c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP],
  // with n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,
  // n_steps = N_BLOCK / N_STEP (the tail block may be smaller).
  // get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP×N_STEP) tile.

  static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }

488	kt-kernel/operators/amx/la/amx_raw_buffers.hpp	Normal file
@@ -0,0 +1,488 @@
#ifndef AMX_RAW_BUFFERS_HPP
#define AMX_RAW_BUFFERS_HPP

/**
 * @file amx_raw_buffers.hpp
 * @brief Raw data format buffer management (FP8, BF16, etc.)
 *
 * This file implements buffer management for raw-precision formats, used for
 * native-precision inference such as DeepSeek V3.2.
 *
 * Buffer types:
 * - BufferAFP8Impl: input activation buffer with dynamic FP8 quantization
 * - BufferBFP8Impl: weight buffer, FP8 format + 128x128 block scales
 * - BufferBFP8BlockImpl: optimized block-quantized weight buffer
 *
 * Memory layout:
 * - FP8 data: 1 byte per element
 * - Scales: 4 bytes per block (one per 128x128 block for BufferB, one per 128 rows for BufferA)
 */

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
#include <vector>

#include "amx_config.hpp"
#include "amx_utils.hpp"
#include "llama.cpp/ggml-impl.h"
#include "pack.hpp"
#include "utils.hpp"

namespace amx {

// ============================================================================
// BufferAFP8Impl: FP8 activation buffer (with dynamic quantization)
// ============================================================================
/* Physical layout (in bf16 elements):
 * The logical matrix A is (m, k) row-major; m is padded to max_m (= m_block_size, a multiple of M_STEP).
 * Storage order:
 * k_block (K_BLOCK cols) -> m_block (M_STEP rows) -> k_step (K_STEP cols) -> row-major (M_STEP×K_STEP) tile.
 * It can therefore be viewed as a 5D array:
 * a[k_blocks][m_blocks][k_steps][M_STEP][K_STEP],
 * with k_blocks = ceil(k / K_BLOCK), m_blocks = max_m / M_STEP,
 * k_steps = K_BLOCK / K_STEP (the last k_block may be smaller).
 * get_submat(m_begin, k_begin) returns a contiguous (M_STEP×K_STEP) tile.
 */
template <typename K>
struct BufferABF16Impl {
  ggml_bf16_t* a;
  int max_m, k;
  static constexpr int M_STEP = K::M_STEP;
  static constexpr int K_STEP = K::K_STEP;
  static constexpr int K_BLOCK = K::K_BLOCK;

  static size_t required_size(int max_m, int k) { return sizeof(ggml_bf16_t) * max_m * k; }

  BufferABF16Impl(int max_m, int k, void* ptr) : max_m(max_m), k(k) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    assert(max_m % M_STEP == 0);
    assert(k % K_STEP == 0);
    a = reinterpret_cast<ggml_bf16_t*>(ptr);
  }

  void set_data(void* new_ptr) { a = reinterpret_cast<ggml_bf16_t*>(new_ptr); }

  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {
    assert(m <= max_m);
    assert(ith == 0 && nth == 1);
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
        int k_block_size = std::min(K_BLOCK, k - k_block_begin);
        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
            __m512i* s = (__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin);
            __m512i* d =
                (__m512i*)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP);
            avx512_copy_32xbf16(s, d);
          }
        }
      }
    }
  }

  ggml_bf16_t* get_submat(int m, int k, int m_begin, int k_begin) {
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
    k_begin -= k_block_begin;
    int k_block_size = std::min(K_BLOCK, k - k_block_begin);
    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;
  }
};

// ============================================================================
// BufferB
// ============================================================================

/**
 * @brief BF16 BufferB
 * Physical layout (in bf16 elements):
 * The logical matrix B is (n, k) row-major (for NT GEMM); n is tiled by N_BLOCK.
 * Storage order:
 * n_block (N_BLOCK rows) -> k_block (K_BLOCK cols) -> n_step (N_STEP rows) -> k_step (K_STEP cols)
 * -> (N_STEP×K_STEP) tile; inside each tile the two 16×16 sub-blocks are additionally transposed
 * to match the VNNI layout of the AMX BTile (TILE_K/VNNI_BLK × TILE_N*VNNI_BLK).
 * It can therefore be viewed as a 6D array:
 * b[n_blocks][k_blocks][n_steps][k_steps][N_STEP][K_STEP],
 * with n_blocks = ceil(n / N_BLOCK), k_blocks = ceil(k / K_BLOCK),
 * n_steps = N_BLOCK / N_STEP, k_steps = K_BLOCK / K_STEP (tail blocks may be smaller).
 * get_submat(n_begin, k_begin) returns the start address of a contiguous (N_STEP×K_STEP) tile.
 * @tparam K Kernel type
 */

template <typename K>
struct BufferBBF16Impl {
  ggml_bf16_t* b;
  int n, k;
  static constexpr bool SCALE = false;
  static constexpr int N_STEP = K::N_STEP;
  static constexpr int K_STEP = K::K_STEP;
  static constexpr int N_BLOCK = K::N_BLOCK;
  static constexpr int K_BLOCK = K::K_BLOCK;
  static constexpr int TILE_N = K::TILE_N;
  static size_t required_size(int n, int k) { return sizeof(ggml_bf16_t) * n * k; }

  BufferBBF16Impl(int n, int k, void* ptr) : n(n), k(k) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    assert(n % N_STEP == 0);
    assert(k % K_STEP == 0);
    b = reinterpret_cast<ggml_bf16_t*>(ptr);
  }
  void set_data(void* new_ptr) { b = reinterpret_cast<ggml_bf16_t*>(new_ptr); }

  void from_mat(ggml_bf16_t* src, int ith, int nth) {
    auto [n_start, n_end] = K::split_range_n(n, ith, nth);
    int n_block_begin = n_start;
    int n_block_size = n_end - n_block_begin;
    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
        int k_block_size = std::min(K_BLOCK, k - k_block_begin);
        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
          for (int i = 0; i < N_STEP; i++) {
            __m512i* s = (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);
            __m512i* d = (__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +
                                    k_begin * N_STEP + i * K_STEP);
            avx512_copy_32xbf16(s, d);
          }
          transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +
                                           n_begin * k_block_size + k_begin * N_STEP));
          transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +
                                           n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));
        }
      }
    }
  }
  ggml_bf16_t* get_submat(int n, int k, int n_begin, int k_begin) {
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
    n_begin -= n_block_begin;
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
    k_begin -= k_block_begin;
    int k_block_size = std::min(K_BLOCK, k - k_block_begin);
    return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;
  }
};

/**
 * @brief FP8 weight buffer
 *
 * Stores the weight matrix in FP8 format with one scale factor per 128x128 block.
 * This matches the native-precision format of DeepSeek V3.2.
 *
 * @tparam K Kernel type
 */
template <typename K>
struct BufferBFP8Impl {
  uint8_t* b;              // FP8 weight
  float* d;                // scale_inv [n / k_group_size, k / k_group_size]
  int n, k, k_group_size;  // k_group_size = 128 in DeepSeek

  static constexpr int N_STEP = K::N_STEP;
  static constexpr int K_STEP = K::K_STEP;
  static constexpr int N_BLOCK = K::N_BLOCK;
  static constexpr int K_BLOCK = K::K_BLOCK;
  static constexpr bool SCALE = true;

  /**
   * @brief Compute the required memory size
   */
  static size_t required_size(int n, int k, int k_group_size) {
    int n_blocks_n = (n + k_group_size - 1) / k_group_size;
    int n_blocks_k = (k + k_group_size - 1) / k_group_size;
    return sizeof(uint8_t) * n * k + sizeof(float) * n_blocks_n * n_blocks_k;
  }

  /**
   * @brief Constructor
   */
  BufferBFP8Impl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) { set_data(ptr); }

  void set_data(void* ptr) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    b = reinterpret_cast<uint8_t*>(ptr);
    d = reinterpret_cast<float*>(b + (size_t)n * k);
  }

  static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7};  // fp8 matrix offset for reordering
  /**
   * @brief Load from raw FP8 weights (already in quantized format)
   *
   * @param b_src FP8 weight source data (n-major, n×k)
   * @param d_src FP32 scale_inv source data (n-major, ceil(n/128)×ceil(k/128))
   */
  void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {
    assert(b != nullptr && d != nullptr);
    assert(N_STEP == 32 && K_STEP == 32);  // the from_mat block copy assumes this

    // Copy scales (per 128x128 block). Each thread copies its own n-block range.
    const int n_blocks_k = (k + k_group_size - 1) / k_group_size;
    if (d_src != nullptr) {
      auto [n_start, n_end] = K::split_range_n(n, ith, nth);
      int bn_start = n_start / k_group_size;
      int bn_end = (n_end + k_group_size - 1) / k_group_size;
      memcpy(d + bn_start * n_blocks_k, d_src + bn_start * n_blocks_k,
             sizeof(float) * (bn_end - bn_start) * n_blocks_k);
    }

    // Reorder FP8 weights into the KT block-major layout (same panel->tile order as BF16 BufferB).
    auto [n_start, n_end] = K::split_range_n(n, ith, nth);
    int n_block_begin = n_start;
    int n_block_size = n_end - n_block_begin;
    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
      int n_step_size = std::min(N_STEP, n_block_size - n_begin);
      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
        int k_block_size = std::min(K_BLOCK, k - k_block_begin);
        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
          int k_step_size = std::min(K_STEP, k_block_size - k_begin);
          // [k_step_size, n_step_size] block copy
          const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;
          uint64_t* block_b_dst =
              reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
                                          (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);
          for (int i = 0; i < 8; i++) {
            const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);
            for (int j = 0; j < 16; j++) {
              uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |
                             (((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);
              block_b_dst[8 * j + mat_offset[i]] = val;
            }
          }
        }
      }
    }
  }

  /**
   * @brief Get scale_inv
   */
  float* get_scale(int n, int n_begin, int k, int k_begin) {
    int n_blocks_k = (k + k_group_size - 1) / k_group_size;
    int bn = n_begin / k_group_size;
    int bk = k_begin / k_group_size;
    return d + bn * n_blocks_k + bk;
  }

  /**
   * @brief Get a sub-matrix pointer
   */
  uint8_t* get_submat(int n, int k, int n_begin, int k_begin) {
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
    n_begin -= n_block_begin;
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
    k_begin -= k_block_begin;
    int k_block_size = std::min(K_BLOCK, k - k_block_begin);
    return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size +
           (size_t)k_begin * N_STEP;
  }

  /**
   * @brief Inverse mapping for mat_offset used in to_mat
   * mat_offset = {0, 2, 4, 6, 1, 3, 5, 7}
   * inv_mat_offset[mat_offset[i]] = i
   */
  static constexpr int inv_mat_offset[8] = {0, 4, 1, 5, 2, 6, 3, 7};

  /**
   * @brief Unpack FP8 weights from the KT block-major layout back to the n-major layout
   *
   * This is the inverse operation of from_mat.
   *
   * @param b_dst FP8 output buffer (n-major, n×k)
   * @param d_dst FP32 scale_inv output buffer (n-major, ceil(n/128)×ceil(k/128))
   * @param ith Thread index
   * @param nth Total number of threads
   */
  void to_mat(uint8_t* b_dst, float* d_dst, int ith, int nth) const {
    assert(b != nullptr && d != nullptr);
    assert(N_STEP == 32 && K_STEP == 32);

    // Calculate the N_BLOCK range for this thread.
    // Unlike split_range_n, which gives one N_BLOCK per thread, we need to handle
    // the case where nth < n/N_BLOCK (fewer threads than blocks).
    int total_n_blocks = (n + N_BLOCK - 1) / N_BLOCK;
    int blocks_per_thread = (total_n_blocks + nth - 1) / nth;
    int start_n_block_idx = ith * blocks_per_thread;
    int end_n_block_idx = std::min((ith + 1) * blocks_per_thread, total_n_blocks);

    // Copy scales (per 128x128 block). Each thread copies its own n-block range.
    const int n_blocks_k = (k + k_group_size - 1) / k_group_size;
    if (d_dst != nullptr) {
      int bn_start = start_n_block_idx;
      int bn_end = end_n_block_idx;
      memcpy(d_dst + bn_start * n_blocks_k, d + bn_start * n_blocks_k,
             sizeof(float) * (bn_end - bn_start) * n_blocks_k);
    }

    // Reorder FP8 weights back to the n-major layout (inverse of from_mat).
    // Process each N_BLOCK assigned to this thread.
    for (int n_block_idx = start_n_block_idx; n_block_idx < end_n_block_idx; n_block_idx++) {
      int n_block_begin = n_block_idx * N_BLOCK;
      int n_block_size = std::min(N_BLOCK, n - n_block_begin);

      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
          int k_block_size = std::min(K_BLOCK, k - k_block_begin);
          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
            // Source: packed layout (KT block-major)
            const uint64_t* block_b_src =
                reinterpret_cast<const uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
                                                  (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);

            // Destination: n-major layout
            uint8_t* block_b_dst = b_dst + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;

            // Inverse of the from_mat transformation
            for (int packed_i = 0; packed_i < 8; packed_i++) {
              int i = inv_mat_offset[packed_i];
              uint16_t* d_row = reinterpret_cast<uint16_t*>(block_b_dst + (size_t)i * k * 4);
              for (int j = 0; j < 16; j++) {
                uint64_t val = block_b_src[8 * j + packed_i];
                d_row[j] = (uint16_t)(val & 0xFFFF);
                d_row[j + (k / 2) * 1] = (uint16_t)((val >> 16) & 0xFFFF);
                d_row[j + (k / 2) * 2] = (uint16_t)((val >> 32) & 0xFFFF);
                d_row[j + (k / 2) * 3] = (uint16_t)((val >> 48) & 0xFFFF);
              }
            }
          }
        }
      }
    }
  }
};

// ============================================================================
// BufferCFP8Impl: FP32 output buffer
// ============================================================================

/**
 * @brief FP32 output buffer
 *
 * Stores the accumulator in FP32 format and supports conversion to BF16 output.
 *
 * @tparam K Kernel type
 */
template <typename K>
struct BufferCFP32Impl {
  float* c;
  int max_m, n;
  static constexpr int M_STEP = K::M_STEP;
  static constexpr int N_STEP = K::N_STEP;
  static constexpr int N_BLOCK = K::N_BLOCK;
  // Physical layout (in float elements):
  // The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,
  // and n is tiled by N_BLOCK.
  // Storage order:
  // n_block (N_BLOCK cols) -> m_block (M_STEP rows) -> n_step (N_STEP cols) -> row-major (M_STEP×N_STEP) tile.
  // It can therefore be viewed as a 5D array:
  // c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP],
  // with n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,
  // n_steps = N_BLOCK / N_STEP (the tail block may be smaller).
  // get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP×N_STEP) tile.

  static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }

  BufferCFP32Impl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    assert(max_m % M_STEP == 0);
    assert(n % N_STEP == 0);
    c = reinterpret_cast<float*>(ptr);
  }

  void set_data(void* new_ptr) { c = reinterpret_cast<float*>(new_ptr); }

  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
    assert(m <= max_m);
    auto [n_start, n_end] = K::split_range_n(n, ith, nth);
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int n_block_begin = n_start;
    int n_block_size = n_end - n_block_begin;
    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
          __m512* x0 =
              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
          __m512* x1 =
              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);
          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));
        }
      }
    }
  }

  float* get_submat(int m, int n, int m_begin, int n_begin) {
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    n_begin -= n_block_begin;
    return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;
  }
};

template <typename K>
struct BufferCFP32ReduceImpl {
  float* c;
  float* reduce_buf;
  int max_m, n;

  static constexpr int M_STEP = K::M_STEP;
  static constexpr int N_STEP = K::N_STEP;
  static constexpr int N_BLOCK = K::N_BLOCK;

  static size_t required_size(int max_m, int n) { return sizeof(float) * (size_t)max_m * n * 2; }

  BufferCFP32ReduceImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {
    assert(max_m % M_STEP == 0);
    assert(n % N_STEP == 0);
    set_data(ptr);
  }

  void set_data(void* ptr) {
    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
    c = reinterpret_cast<float*>(ptr);
    reduce_buf = c + (size_t)max_m * n;
  }

  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
    assert(m <= max_m);
    auto [n_start, n_end] = K::split_range_n(n, ith, nth);
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int n_block_begin = n_start;
    int n_block_size = n_end - n_block_begin;
    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
          __m512* x0 =
              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
          __m512* x1 =
              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);
          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));
        }
      }
    }
  }

  float* get_submat(int m, int n, int m_begin, int n_begin) {
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    n_begin -= n_block_begin;
    return c + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size + (size_t)n_begin * M_STEP;
  }

  float* get_reduce_submat(int m, int n, int m_begin, int n_begin) {
    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    n_begin -= n_block_begin;
    return reduce_buf + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size +
           (size_t)n_begin * M_STEP;
  }
};

}  // namespace amx

#endif  // AMX_RAW_BUFFERS_HPP
464
kt-kernel/operators/amx/la/amx_raw_kernels.hpp
Normal file
@@ -0,0 +1,464 @@
#ifndef AMX_RAW_KERNELS_HPP
#define AMX_RAW_KERNELS_HPP

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>

#include "amx_config.hpp"
#include "amx_raw_buffers.hpp"
#include "amx_utils.hpp"
#include "llama.cpp/ggml-impl.h"

namespace amx {

struct GemmKernel224BF16 {
  using dt = ggml_bf16_t;
  using output_t = float;
  static constexpr double ELEMENT_SIZE = 2;
  static const int TILE_M = 16;
  static const int TILE_K = 32;
  static const int TILE_N = 16;
  static const int VNNI_BLK = 2;

  static const int M_STEP = TILE_M * 2;
  static const int N_STEP = TILE_N * 2;
  static const int K_STEP = TILE_K;

  static inline const int N_BLOCK = 256;
  static inline const int K_BLOCK = 1792;

  static std::string name() { return "BF16"; }

  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }

  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    int n_start = N_BLOCK * ith;
    int n_end = std::min(n, N_BLOCK * (ith + 1));
    return {n_start, n_end};
  }

  static void config() {
#ifdef HAVE_AMX
    enable_amx();
    TileConfig tile_config;

    // A tiles (0-1): 16 rows x 32 bf16 elements
    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));

    // B tiles (2-3): 16 rows x 32 bf16 elements (VNNI-packed)
    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));

    // C tiles (4-7): 16 rows x 16 fp32 elements
    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));

    tile_config.set_config();
#endif
  }

  static void load_a(dt* a, size_t lda) {
#ifdef HAVE_AMX
    _tile_loadd(0, a, lda);
    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
#else
    (void)a;
    (void)lda;
#endif
  }

  static void load_b(dt* b, size_t ldb) {
#ifdef HAVE_AMX
    _tile_loadd(2, b, ldb);
    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);
#else
    (void)b;
    (void)ldb;
#endif
  }

  static void clean_c() {
#ifdef HAVE_AMX
    _tile_zero(4);
    _tile_zero(5);
    _tile_zero(6);
    _tile_zero(7);
#endif
  }

  static void load_c(output_t* c, size_t ldc) {
#ifdef HAVE_AMX
    _tile_loadd(4, c, ldc);
    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);
    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
#else
    (void)c;
    (void)ldc;
#endif
  }

  static void store_c(output_t* c, size_t ldc) {
#ifdef HAVE_AMX
    _tile_stored(4, c, ldc);
    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);
    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
#else
    (void)c;
    (void)ldc;
#endif
  }

  static void run_tile() {
#ifdef HAVE_AMX
    _tile_dpbf16ps(4, 0, 2);
    _tile_dpbf16ps(5, 0, 3);
    _tile_dpbf16ps(6, 1, 2);
    _tile_dpbf16ps(7, 1, 3);
#endif
  }

  using BufferA = BufferABF16Impl<GemmKernel224BF16>;
  using BufferB = BufferBBF16Impl<GemmKernel224BF16>;
  using BufferC = BufferCFP32Impl<GemmKernel224BF16>;
};
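The `recommended_nth`/`split_range_n` pair implements a fixed-stride column partition: worker `ith` owns the `N_BLOCK`-wide slice of the N dimension starting at `N_BLOCK * ith`, clamped to `n`. A standalone sketch of that scheme (illustrative only; the constant mirrors `GemmKernel224BF16::N_BLOCK`):

```cpp
#include <algorithm>
#include <cassert>
#include <utility>

// Sketch of the per-worker column partitioning used by the GEMM kernels:
// worker `ith` owns columns [N_BLOCK * ith, min(n, N_BLOCK * (ith + 1))).
constexpr int N_BLOCK = 256;

int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }

std::pair<int, int> split_range_n(int n, int ith) {
  return {N_BLOCK * ith, std::min(n, N_BLOCK * (ith + 1))};
}
```

Note that the split ignores `nth` entirely: the caller is expected to launch exactly `recommended_nth(n)` workers, so the last worker simply gets the (possibly short) tail slice.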

// FP8 (e4m3) AMX kernel that mirrors the GemmKernel224BF16 interface.
struct GemmKernel224FP8 {
  using fp8_t = uint8_t;
  using output_t = float;

  static constexpr double ELEMENT_SIZE = 1.0;
  static const int TILE_M = 16;
  static const int TILE_K = 32;
  static const int TILE_N = 16;
  static const int VNNI_BLK = 2;

  static const int M_STEP = TILE_M * 2;
  static const int N_STEP = TILE_N * 2;
  static const int K_STEP = TILE_K;

  static inline const int BLOCK_SIZE = 128;  // 128 x 128 block quantization
  static inline const int N_BLOCK = 128;
  static inline const int K_BLOCK = 7168;

  static std::string name() { return "FP8"; }

  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }

  static std::pair<int, int> split_range_n(int n, int ith, int nth) {
    int n_start = N_BLOCK * ith;
    int n_end = std::min(n, N_BLOCK * (ith + 1));
    return {n_start, n_end};
  }

  static void config() {}

 private:
  alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = {
      0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c,
      0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d,
      0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e,
      0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
  };
  alignas(64) static constexpr uint8_t bf16_hi_1_val[64] = {
      0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
      0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
      0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42,
      0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43,
  };
  alignas(64) static constexpr uint8_t bf16_lo_0_val[64] = {
      0x00, 0x00, 0x80, 0xc0, 0x00, 0x20, 0x40, 0x60, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
  };
  alignas(64) static constexpr uint8_t bf16_lo_1_val[64] = {
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
  };
  // _mm512_set1_epi8 is not constexpr; keep it as a static cached value
  alignas(64) static const __m512i sign_mask_val;
  static inline __m512i bf16_hi_0_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_0_val); }
  static inline __m512i bf16_hi_1_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_1_val); }
  static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); }
  static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); }
  static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); }

 public:
  using BufferA = BufferABF16Impl<GemmKernel224FP8>;
  using BufferB = BufferBFP8Impl<GemmKernel224FP8>;
  using BufferC = BufferCFP32ReduceImpl<GemmKernel224FP8>;

  static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {
    // fp8 -> bf16: table-lookup the high/low result bytes, then re-attach the sign bit
    __m512i b_hi = _mm512_permutex2var_epi8(bf16_hi_0_mask(), bfp8_512, bf16_hi_1_mask());
    __m512i b_lo = _mm512_permutex2var_epi8(bf16_lo_0_mask(), bfp8_512, bf16_lo_1_mask());
    b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask(), bfp8_512), b_hi);
    __m512i bbf16_0 = _mm512_unpacklo_epi8(b_lo, b_hi);
    __m512i bbf16_1 = _mm512_unpackhi_epi8(b_lo, b_hi);
    return {bbf16_0, bbf16_1};
  }
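The two permute tables encode a plain e4m3-to-bf16 widening of the low 7 bits (sign is OR-ed back separately). As an illustrative scalar reference of what one table lookup computes (a sketch, not part of the kernel; like the tables, it does not special-case the e4m3 NaN encoding `0x7f`):

```cpp
#include <cassert>
#include <cstdint>

// Scalar reference for the byte-table FP8 (e4m3, bias 7) -> BF16 widening
// performed by fp8x64_to_bf16x64. Illustrative only.
static inline uint16_t fp8_e4m3_to_bf16(uint8_t v) {
  uint16_t sign = (uint16_t)(v & 0x80) << 8;
  int exp = (v >> 3) & 0xf;  // 4-bit exponent
  int man = v & 0x7;         // 3-bit mantissa
  if (exp == 0) {            // subnormal: value = man * 2^-9
    if (man == 0) return sign;  // signed zero
    int e = 120;                // bf16 biased exponent of 2^-7
    while (!(man & 0x4)) { man <<= 1; e--; }  // normalize the implicit 1
    return sign | (uint16_t)(e << 7) | (uint16_t)((man & 0x3) << 5);
  }
  // normal: rebias 7 -> 127 and widen the 3-bit mantissa to 7 bits
  return sign | (uint16_t)((exp + 120) << 7) | (uint16_t)(man << 4);
}
```

The high/low bytes of this 16-bit result are exactly the entries stored in `bf16_hi_*_val` and `bf16_lo_*_val` for the corresponding 7-bit index.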

  // Optimized AVX kernel: process the entire k_group_size.
  // Load all data first, then convert all, then compute all;
  // this gives the compiler more freedom to schedule instructions.
  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,
                         BufferB* bb, int k_group_size) {
    const __m512i bf16_hi_0_val = bf16_hi_0_mask();
    const __m512i bf16_hi_1_val = bf16_hi_1_mask();
    const __m512i bf16_lo_0_val = bf16_lo_0_mask();
    const __m512i bf16_lo_1_val = bf16_lo_1_mask();
    const __m512i sign_mask_val = sign_mask();

    __m512* c512 = (__m512*)c;
    int m_block_end = std::min(m - m_begin, M_STEP);

    // Zero out the accumulator at the start
    for (int m_i = 0; m_i < m_block_end; m_i++) {
      c512[m_i * 2] = _mm512_setzero_ps();
      c512[m_i * 2 + 1] = _mm512_setzero_ps();
    }

    // Process the entire k_group_size
    for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {
      ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);
      __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);

      for (int m_i = 0; m_i < m_block_end; m_i++) {
        // Process 2 k_i per iteration
        for (int k_i = 0; k_i < 16; k_i += 2) {
          // Load A vectors
          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);

          // Load B matrices
          __m512i bfp8_0 = bfp8_512[k_i];
          __m512i bfp8_1 = bfp8_512[k_i + 1];

          // Convert FP8 -> BF16 for all
          __m512i b_hi_0 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_0, bf16_hi_1_val);
          __m512i b_lo_0 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_0, bf16_lo_1_val);
          b_hi_0 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_0), b_hi_0);

          __m512i b_hi_1 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_1, bf16_hi_1_val);
          __m512i b_lo_1 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_1, bf16_lo_1_val);
          b_hi_1 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_1), b_hi_1);

          // Compute dpbf16 for all
          __m512bh bbf16_0_0 = (__m512bh)_mm512_unpacklo_epi8(b_lo_0, b_hi_0);
          __m512bh bbf16_1_0 = (__m512bh)_mm512_unpackhi_epi8(b_lo_0, b_hi_0);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_0);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_1_0);

          __m512bh bbf16_0_1 = (__m512bh)_mm512_unpacklo_epi8(b_lo_1, b_hi_1);
          __m512bh bbf16_1_1 = (__m512bh)_mm512_unpackhi_epi8(b_lo_1, b_hi_1);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_0_1);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_1);
        }
      }
    }
  }

  // Optimized AVX kernel: process 4 k_i at once, convert B once and reuse it for all m rows.
  // This version achieved ~493 GB/s - restoring as baseline for further optimization.
  static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,
                           BufferB* bb, int k_group_size) {
    const __m512i bf16_hi_0 = bf16_hi_0_mask();
    const __m512i bf16_hi_1 = bf16_hi_1_mask();
    const __m512i bf16_lo_0 = bf16_lo_0_mask();
    const __m512i bf16_lo_1 = bf16_lo_1_mask();
    const __m512i sign_mask_v = sign_mask();

    __m512* c512 = (__m512*)c;
    int m_block_end = std::min(m - m_begin, M_STEP);

    // Zero out the accumulator
    for (int m_i = 0; m_i < m_block_end; m_i++) {
      c512[m_i * 2] = _mm512_setzero_ps();
      c512[m_i * 2 + 1] = _mm512_setzero_ps();
    }

    // Process the entire k_group_size
    for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {
      ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);
      __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);

      // Process 4 k_i at once - convert B and reuse it across all m rows
      for (int k_i = 0; k_i < 16; k_i += 4) {
        // Load 4 B vectors
        __m512i bfp8_0 = bfp8_512[k_i];
        __m512i bfp8_1 = bfp8_512[k_i + 1];
        __m512i bfp8_2 = bfp8_512[k_i + 2];
        __m512i bfp8_3 = bfp8_512[k_i + 3];

        // Convert all 4 FP8 -> BF16
        __m512i b_hi, b_lo;

        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0),
                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1));
        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1);
        __m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
        __m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);

        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1),
                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1));
        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1);
        __m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
        __m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);

        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2),
                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1));
        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1);
        __m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
        __m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);

        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3),
                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1));
        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1);
        __m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
        __m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);

        // Process m rows - unroll by 2 for better ILP
        int m_i = 0;
        for (; m_i + 1 < m_block_end; m_i += 2) {
          // Load A values for 2 rows
          __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
          __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
          __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
          __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);
          __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]);
          __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]);
          __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]);
          __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]);

          // Process row 0, then row 1 - sequential to avoid dependencies
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi);

          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo);
          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi);
          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo);
          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi);
          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo);
          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi);
          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo);
          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi);
        }
        // Handle the remaining row
        for (; m_i < m_block_end; m_i++) {
          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
          __m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
          __m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);

          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi);
          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo);
          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi);
        }
      }
    }
  }

  static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_block_begin, float* c, float* reduce_c,
                                 BufferA* ba, BufferB* bb, int k, int k_group_size) {
    using K = GemmKernel224FP8;
    int to = std::min(m - m_begin, K::M_STEP);

    for (int i = 0; i < to; i++) {
      // Get the scale for this k_group
      __m512 bs = _mm512_set1_ps(*bb->get_scale(n, n_begin, k, k_block_begin));
      __m512 now = _mm512_load_ps(reduce_c + i * K::N_STEP);
      __m512 result = _mm512_mul_ps(now, bs);
      __m512 existing = _mm512_load_ps(c + i * K::N_STEP);
      result = _mm512_add_ps(result, existing);
      _mm512_store_ps(c + i * K::N_STEP, result);

      now = _mm512_load_ps(reduce_c + i * K::N_STEP + K::TILE_N);
      result = _mm512_mul_ps(now, bs);
      existing = _mm512_load_ps(c + i * K::N_STEP + K::TILE_N);
      result = _mm512_add_ps(result, existing);
      _mm512_store_ps(c + i * K::N_STEP + K::TILE_N, result);
    }
  }
};
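Because the FP8 weights are block-quantized (128 x 128 scale blocks), the raw partial sums accumulated over one k-group must be rescaled before they are folded into the persistent FP32 accumulator. A scalar sketch of the per-row operation that `apply_scale_kgroup` vectorizes (illustrative helper name, not part of the kernel):

```cpp
#include <cassert>

// Scalar sketch of the per-k-group scaled accumulation: the unscaled FP8
// partial sums in `reduce_c` are multiplied by the block's weight scale
// and added into the persistent FP32 accumulator `c`.
void apply_scale_row(float* c, const float* reduce_c, float scale, int n_step) {
  for (int j = 0; j < n_step; j++) c[j] += reduce_c[j] * scale;
}
```

This is why `BufferCFP32ReduceImpl` carries two planes: `reduce_buf` holds the unscaled per-group sums, `c` the running scaled total.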

// all step = 32
template <typename K, bool amx_or_avx = false>
void float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,
                          typename K::BufferC* bc, int ith, int nth) {
  assert(n % K::N_STEP == 0);
  assert(k % k_group_size == 0);
  assert(k_group_size % K::K_STEP == 0);

  auto [n_start, n_end] = K::split_range_n(n, ith, nth);

  // Process by k_groups
  for (int k_group_begin = 0; k_group_begin < k; k_group_begin += k_group_size) {
    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
        float* c = bc->get_submat(m, n, m_begin, n_begin);
        float* reduce_c = bc->get_reduce_submat(m, n, m_begin, n_begin);

        if (k_group_begin == 0) {
          for (int i = 0; i < K::M_STEP && m_begin + i < m; i++) {
            for (int j = 0; j < K::N_STEP; j++) {
              c[i * K::N_STEP + j] = 0.0f;
            }
          }
        }

        // avx_kernel_4 now processes the entire k_group_size internally (like INT8's avx_kernel)
        if constexpr (amx_or_avx && AMX_AVAILABLE) {
          for (int k_begin = k_group_begin; k_begin < std::min(k, k_group_begin + k_group_size); k_begin += K::K_STEP) {
            K::amx_kernel(m, n, k, m_begin, n_begin, k_begin, reduce_c, ba, bb, k_group_size);
          }
        } else {
          // A single call processes the entire k_group
          K::avx_kernel_4(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size);
        }
        K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, reduce_c, ba, bb, k, k_group_size);
      }
    }
  }
}

// inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
//                            std::shared_ptr<GemmKernel224BF16::BufferB> bb,
//                            std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
//   float_mat_mul_kgroup<GemmKernel224BF16, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
// }

// inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
//                            std::shared_ptr<GemmKernel224BF16::BufferB> bb,
//                            std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
//   float_mat_mul_kgroup<GemmKernel224BF16, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
// }

inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,
                           std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,
                           int ith, int nth) {
  float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}

inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,
                           std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,
                           int ith, int nth) {
  float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}

}  // namespace amx

#endif  // AMX_RAW_KERNELS_HPP
@@ -11,30 +11,27 @@
#define CPUINFER_OPERATOR_AMX_MOE_H

// #define CHECK

#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT

#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>

#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
#include "moe_base.hpp"

template <class T>
class AMX_MOE_TP {
class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
 private:
  int tp_part_idx;
  using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;
  using Base::config_;
  using Base::tp_part_idx;
  using Base::gate_bb_;
  using Base::up_bb_;
  using Base::down_bb_;
  using Base::gate_up_ba_;
  using Base::gate_bc_;
  using Base::up_bc_;
  using Base::down_ba_;
  using Base::down_bc_;
  using Base::m_local_num_;

  std::filesystem::path prefix;

  void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if
@@ -44,27 +41,6 @@ class AMX_MOE_TP {
  void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if
                     // quantized)]

  ggml_bf16_t* m_local_input_;        // [num_experts_per_tok * max_len * hidden_size]
  ggml_bf16_t* m_local_gate_output_;  // [num_experts_per_tok * max_len * intermediate_size]
  ggml_bf16_t* m_local_up_output_;    // [num_experts_per_tok * max_len * intermediate_size]
  ggml_bf16_t* m_local_down_output_;  // [num_experts_per_tok * max_len * hidden_size]

  std::vector<std::vector<int>> m_local_pos_;         // [max_len, num_experts_per_tok]
  std::vector<int> m_local_num_;                      // [expert_num]
  std::vector<int> m_expert_id_map_;                  // [expert_num]
  std::vector<ggml_bf16_t*> m_local_input_ptr_;       // [expert_num]
  std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
  std::vector<ggml_bf16_t*> m_local_up_output_ptr_;   // [expert_num]
  std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]

  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
  char verify_bb[100000000];
  char check_bb[100000000];
@@ -161,21 +137,15 @@ class AMX_MOE_TP {
#endif

 public:
  using input_t = ggml_bf16_t;
  using output_t = float;
  GeneralMOEConfig config_;
  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
  AMX_MOE_TP() = default;

  AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx) {
  AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) {
    printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
    auto& load = config.load;
    auto& save = config.save;
    if (load && config.path == "") {
      load = false;
    }
    auto& load = config_.load;
    auto& save = config_.save;

    prefix = config.path;
    prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
    prefix = config_.path;
    prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
    if (save) {
      std::cout << "Creating " << prefix << std::endl;
      std::filesystem::create_directories(prefix);
@@ -188,78 +158,65 @@ class AMX_MOE_TP {
|
||||
}
|
||||
}
|
||||
|
||||
this->tp_part_idx = tp_part_idx;
|
||||
config_ = config;
|
||||
gate_proj_ = config_.gate_proj;
|
||||
up_proj_ = config_.up_proj;
|
||||
down_proj_ = config_.down_proj;
|
||||
|
||||
MemoryRequest mem_requests;
|
||||
mem_requests.append_pointer(
|
||||
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
|
||||
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||
config_.max_len * config_.intermediate_size);
|
||||
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||
config_.max_len * config_.intermediate_size);
|
||||
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
|
||||
config_.max_len * config_.hidden_size);
|
||||
|
||||
m_local_pos_.resize(config_.max_len);
|
||||
for (int i = 0; i < config_.max_len; i++) {
|
||||
m_local_pos_[i].resize(config_.num_experts_per_tok);
|
||||
}
|
||||
m_expert_id_map_.resize(config_.expert_num);
|
||||
m_local_num_.resize(config_.expert_num);
|
||||
m_local_input_ptr_.resize(config_.expert_num);
|
||||
m_local_gate_output_ptr_.resize(config_.expert_num);
|
||||
m_local_up_output_ptr_.resize(config_.expert_num);
|
||||
m_local_down_output_ptr_.resize(config_.expert_num);
|
||||
|
||||
// printf("tp part %d alloc layer %d, %f GB, on numa %d\n", tp_part_idx, config_.layer_idx,
|
||||
// 1e-9 * config_.expert_num *
|
||||
// (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 +
|
||||
// T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)),
|
||||
// numa_node_of_cpu(sched_getcpu()));
|
||||
|
||||
for (size_t i = 0; i < config_.expert_num; i++) {
|
||||
gate_up_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, nullptr));
|
||||
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
|
||||
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
|
||||
down_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, nullptr));
|
||||
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
|
||||
|
||||
void* gate_bb_ptr =
|
||||
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
|
||||
gate_bb_.push_back(
|
||||
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
|
||||
|
||||
void* up_bb_ptr =
|
||||
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
|
||||
up_bb_.push_back(
|
||||
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
|
||||
|
||||
void* down_bb_ptr =
|
||||
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
|
||||
down_bb_.push_back(
|
||||
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
|
||||
}
|
||||
for (int i = 0; i < config_.expert_num; i++) {
|
||||
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
|
||||
T::BufferA::required_size(config_.max_len, config_.hidden_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
|
||||
T::BufferA::required_size(config_.max_len, config_.intermediate_size));
|
||||
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
|
||||
T::BufferC::required_size(config_.max_len, config_.hidden_size));
|
||||
}
|
||||
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
|
||||
}
|
||||
|
||||
~AMX_MOE_TP() {
|
||||
// shared_mem_buffer_numa.dealloc(this);
|
||||
~AMX_MOE_TP() = default;
|
||||
|
||||
  // ============================================================================
  // CRTP buffer creation - no group_size
  // ============================================================================

  size_t buffer_a_required_size_impl(size_t m, size_t k) const {
    return T::BufferA::required_size(m, k);
  }
  size_t buffer_b_required_size_impl(size_t n, size_t k) const {
    return T::BufferB::required_size(n, k);
  }
  size_t buffer_c_required_size_impl(size_t m, size_t n) const {
    return T::BufferC::required_size(m, n);
  }

  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
    return std::make_shared<typename T::BufferA>(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
    return std::make_shared<typename T::BufferB>(n, k, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
    return std::make_shared<typename T::BufferC>(m, n, data);
  }

  // ============================================================================
  // CRTP virtual points - GEMM dispatch
  // ============================================================================

  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = gate_up_ba_[expert_idx];
    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];

    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
    }
  }

  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
    int m = m_local_num_[expert_idx];
    auto& ba = down_ba_[expert_idx];
    auto& bb = down_bb_[expert_idx];
    auto& bc = down_bc_[expert_idx];

    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
      amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
    } else {
      amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
    }
  }

  void load_weights() {
    auto pool = config_.pool->get_subpool(tp_part_idx);
@@ -401,434 +358,21 @@ class AMX_MOE_TP {
    }
  }

  void warm_up() {
    int qlen = config_.max_len;
    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
    std::vector<float> weights(qlen * config_.num_experts_per_tok);
    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
      expert_ids[i] = i % config_.expert_num;
      weights[i] = 0.01;
    }
    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
  }

  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    if (qlen > 1) {
      forward_prefill(qlen, k, expert_ids, weights, input, output);
    } else {
      forward_decode(k, expert_ids, weights, input, output);
    }
  }

#define DIRECT_OR_POOL_BY_QLEN(var, fn)                            \
  do {                                                             \
    if (qlen < 10) {                                               \
      for (int i = 0; i < (var); i++) {                            \
        (fn)(i);                                                   \
      }                                                            \
    } else {                                                       \
      pool->do_work_stealing_job((var), nullptr, (fn), nullptr);   \
    }                                                              \
  } while (0)

#define MATMUL_OR_VECMUL_BY_QLEN(...)                                  \
  do {                                                                 \
    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { \
      amx::mat_mul(__VA_ARGS__);                                       \
    } else {                                                           \
      amx::vec_mul(__VA_ARGS__);                                       \
    }                                                                  \
  } while (0)

  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
                       void* output) {
    auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
    auto start_time = std::chrono::high_resolution_clock::now();
    auto last = start_time;
    // Per-stage timings (in microseconds)
    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
    int max_local_num = 0;  // track the largest per-expert row count
#endif

    int activated_expert = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_num_[i] = 0;
    }
    for (int i = 0; i < qlen; i++) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
      }
    }

    for (int i = 0; i < config_.expert_num; i++) {
      if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
        max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
        m_expert_id_map_[activated_expert] = i;
        activated_expert++;
      }
    }

    // activated_expert is now fully counted

    size_t offset = 0;
    for (int i = 0; i < config_.expert_num; i++) {
      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
      offset += m_local_num_[i];
    }
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
      }
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
      int expert_idx = m_expert_id_map_[task_id];
      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int _) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];

          int ith = task_id % nth;
          if (do_up) {
            MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                     gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
                                     gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    auto up_gate_fn = [this, nth](int task_id) {
      int expert_idx = m_expert_id_map_[task_id / nth];
      int ith = task_id % nth;
      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        for (int j = n_start; j < n_end; j += 32) {
          __m512 gate_val0, gate_val1, up_val0, up_val1;
          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
          __m512 result0 = amx::act_fn(gate_val0, up_val0);
          __m512 result1 = amx::act_fn(gate_val1, up_val1);
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
        }
      }
    };
    DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int _) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
                                   down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        qlen, nullptr,
        [this, nth, output, k, expert_ids, weights](int i) {
          for (int e = 0; e < config_.hidden_size; e += 32) {
            __m512 x0 = _mm512_setzero_ps();
            __m512 x1 = _mm512_setzero_ps();
            for (int j = 0; j < k; j++) {
              if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
                continue;
              }
              __m512 weight = _mm512_set1_ps(weights[i * k + j]);
              __m512 down_output0, down_output1;
              avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                   m_local_pos_[i][j] * config_.hidden_size + e),
                                        &down_output0, &down_output1);
              x0 = _mm512_fmadd_ps(down_output0, weight, x0);
              x1 = _mm512_fmadd_ps(down_output1, weight, x1);
            }
            auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
            f32out[0] = x0;
            f32out[1] = x1;
          }
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    // Print all stage timings once at the end of the function, together with max_local_num and qlen
    printf(
        "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
        "%d, qlen: %d\n",
        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
        down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
  }

  void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    int qlen = 1;
    auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
    auto start_time = std::chrono::high_resolution_clock::now();
    auto last = start_time;
    // Per-stage timings (in microseconds)
    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
    int max_local_num = 0;  // track the largest per-expert row count
#endif

    int activated_expert = 0;
    for (int i = 0; i < k; i++) {
      if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
        continue;
      }
      m_expert_id_map_[activated_expert] = expert_ids[i];
      activated_expert++;
    }

    size_t offset = 0;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
      offset += qlen;
    }

    gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int _) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];

          int ith = task_id % nth;
          if (do_up) {
            amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], up_bb_[expert_idx],
                         up_bc_[expert_idx], ith, nth);
            up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], gate_bb_[expert_idx],
                         gate_bc_[expert_idx], ith, nth);
            gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
      int expert_idx = m_expert_id_map_[task_id / nth];
      int ith = task_id % nth;
      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
      for (int i = 0; i < qlen; i++) {
        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        for (int j = n_start; j < n_end; j += 32) {
          __m512 gate_val0, gate_val1, up_val0, up_val1;
          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
          __m512 result0 = amx::act_fn(gate_val0, up_val0);
          __m512 result1 = amx::act_fn(gate_val1, up_val1);
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
        }
      }
    }

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int _) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          amx::vec_mul(qlen, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],
                       down_bc_[expert_idx], ith, nth);
          down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);
#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif
    for (int i = 0; i < qlen; i++) {
      for (int e = 0; e < config_.hidden_size; e += 32) {
        __m512 x0 = _mm512_setzero_ps();
        __m512 x1 = _mm512_setzero_ps();
        for (int j = 0; j < k; j++) {
          if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
            continue;
          }
          __m512 weight = _mm512_set1_ps(weights[i * k + j]);
          __m512 down_output0, down_output1;
          avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                               m_local_pos_[i][j] * config_.hidden_size + e),
                                    &down_output0, &down_output1);
          x0 = _mm512_fmadd_ps(down_output0, weight, x0);
          x1 = _mm512_fmadd_ps(down_output1, weight, x1);
        }
        auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
        f32out[0] = x0;
        f32out[1] = x1;
      }
    }

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    // Print all stage timings once at the end of the function
    printf(
        "Profiling Results (numa[%d]) decode: activated_expert: %d, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
        tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
        forward_total_time);
#endif
  }
  // forward, forward_prefill, forward_decode, warm_up are inherited from Base
};

// ============================================================================
// TP_MOE specialization for AMX_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation
// ============================================================================

template <typename K>
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {
 public:
  using TP_MOE_Common<AMX_MOE_TP<K>>::TP_MOE_Common;
  void load_weights() {
  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>;
  using Base::Base;

  void load_weights() override {
    auto& config = this->config;
    auto& tps = this->tps;
    auto& tp_count = this->tp_count;
@@ -836,7 +380,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
    if (config.gate_projs.empty() == false) {
      printf("TP Load from loader\n");
      // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
      DO_TPS_LOAD_WEIGHTS(pool);
      this->weights_loaded = true;
    } else if (config.gate_proj != nullptr) {
@@ -872,7 +415,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
        }
      }

      // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
      DO_TPS_LOAD_WEIGHTS(pool);

      for (auto i = 0; i < tp_count; i++) {
@@ -885,7 +427,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
      this->weights_loaded = true;
    } else if (config.path != "") {
      printf("TP Load from file\n");
      // pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
      DO_TPS_LOAD_WEIGHTS(pool);
      this->weights_loaded = true;
    } else {
@@ -893,37 +434,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
    }
  }

  void merge_results(int qlen, void* output, bool incremental) {
    auto pool = this->config.pool;
    auto merge_fn = [this, output, incremental](int token_nth) {
      auto& local_output_numa = this->local_output_numa;
      auto& tp_configs = this->tp_configs;
      auto& tp_count = this->tp_count;
      auto& config = this->config;
      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
      if (incremental) {
        for (int e = 0; e < config.hidden_size; e += 32) {
          __m512 x0, x1;
          avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
          *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
        }
      }
      for (int i = 1; i < tp_count; i++) {
        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
        for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
        }
      }
      for (int e = 0; e < config.hidden_size; e += 32) {
        __m512 x0 = *(__m512*)(merge_to + e);
        __m512 x1 = *(__m512*)(merge_to + e + 16);
        avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
      }
    };
    DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn);
  }
  void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
  // merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>
};

#endif
763  kt-kernel/operators/amx/moe_base.hpp  Normal file
@@ -0,0 +1,763 @@

/**
 * @Description : Common AMX MoE base class extracted from K2 implementation.
 * @Author      : oql, Codex and Claude
 * @Date        : 2025-12-09
 * @Version     : 0.1.0
 * @LastEditors : oql, Codex and Claude
 * @LastEditTime: 2025-12-09
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
#ifndef CPUINFER_OPERATOR_AMX_MOE_BASE_H
#define CPUINFER_OPERATOR_AMX_MOE_BASE_H

// #define FORWARD_TIME_PROFILE

#include <immintrin.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"

template <class T, class Derived>
class AMX_MOE_BASE {
 public:
  int tp_part_idx = 0;

  ggml_bf16_t* m_local_input_ = nullptr;
  ggml_bf16_t* m_local_gate_output_ = nullptr;
  ggml_bf16_t* m_local_up_output_ = nullptr;
  ggml_bf16_t* m_local_down_output_ = nullptr;

  std::vector<std::vector<int>> m_local_pos_;
  std::vector<int> m_local_num_;
  std::vector<int> m_expert_id_map_;
  std::vector<ggml_bf16_t*> m_local_input_ptr_;
  std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;
  std::vector<ggml_bf16_t*> m_local_up_output_ptr_;
  std::vector<ggml_bf16_t*> m_local_down_output_ptr_;

  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;

  size_t pool_count_ = 0;
  size_t gate_up_ba_pool_bytes_ = 0;
  size_t gate_bc_pool_bytes_ = 0;
  size_t up_bc_pool_bytes_ = 0;
  size_t down_ba_pool_bytes_ = 0;
  size_t down_bc_pool_bytes_ = 0;
  void* gate_up_ba_pool_ = nullptr;
  void* gate_bc_pool_ = nullptr;
  void* up_bc_pool_ = nullptr;
  void* down_ba_pool_ = nullptr;
  void* down_bc_pool_ = nullptr;

  GeneralMOEConfig config_;
  using input_t = ggml_bf16_t;
  using output_t = float;
  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;

  AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); }

  void init() {
    if (config_.load && config_.path == "") {
      config_.load = false;
    }

    MemoryRequest mem_requests;
    mem_requests.append_pointer(
        &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
    mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                         config_.max_len * config_.intermediate_size);
    mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
                                                           config_.max_len * config_.hidden_size);

    m_local_pos_.resize(config_.max_len);
    for (int i = 0; i < config_.max_len; i++) {
      m_local_pos_[i].resize(config_.num_experts_per_tok);
    }
    m_expert_id_map_.resize(config_.expert_num);
    m_local_num_.resize(config_.expert_num);
    m_local_input_ptr_.resize(config_.expert_num);
    m_local_gate_output_ptr_.resize(config_.expert_num);
    m_local_up_output_ptr_.resize(config_.expert_num);
    m_local_down_output_ptr_.resize(config_.expert_num);

    for (size_t i = 0; i < config_.expert_num; i++) {
      gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));
      gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
      up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
      down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));
      down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));

      void* gate_bb_ptr =
          std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
      gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));

      void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
      up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));

      void* down_bb_ptr =
          std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));
      down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
    }
    // TODO: need to update all *.hpp files
    // The extra (config_.expert_num * T::M_STEP) term in pool_count_ guarantees padding room for each expert.
    pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;

    gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
    gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
    down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;

    mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
    mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
    mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
    mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
    mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);

    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
  }
|
||||
|
||||
~AMX_MOE_BASE() = default;
|
||||
|
||||
  void warm_up() {
    int qlen = config_.max_len;
    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
    std::vector<float> weights(qlen * config_.num_experts_per_tok);
    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
      expert_ids[i] = i % config_.expert_num;
      weights[i] = 0.01;
    }
    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
  }

  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    if (qlen > 1) {
      forward_prefill(qlen, k, expert_ids, weights, input, output);
    } else {
      forward_decode(k, expert_ids, weights, input, output);
    }
  }

  template <typename... Args>
  void load_weights(Args&&... args) {
    derived()->load_weights(std::forward<Args>(args)...);
  }

  template <typename... Args>
  void write_weights_to_buffer(Args&&... args) const {
    derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);
  }

  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
                       void* output) {
    auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
    auto start_time = std::chrono::high_resolution_clock::now();
    auto last = start_time;
    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
    int max_local_num = 0;
#endif

    int activated_expert = 0;
    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
    for (int i = 0; i < qlen; i++) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
      }
    }

    for (int i = 0; i < config_.expert_num; i++) {
      if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
        max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
        m_expert_id_map_[activated_expert] = i;
        activated_expert++;
      }
    }

    size_t offset = 0;
    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
    void* gate_bc_pool_ptr = gate_bc_pool_;
    void* up_bc_pool_ptr = up_bc_pool_;
    void* down_ba_pool_ptr = down_ba_pool_;
    void* down_bc_pool_ptr = down_bc_pool_;
    constexpr size_t M_STEP = T::M_STEP;
    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
    size_t used_pool_m = 0;
    size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
           used_pool_bytes_bc_down = 0;

    for (int i = 0; i < config_.expert_num; i++) {
      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
      offset += m_local_num_[i];

      if (m_local_num_[i] == 0) {
        continue;
      }

      size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
      gate_up_ba_[i]->max_m = max_m;
      gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
      size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
      gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);

      gate_bc_[i]->max_m = max_m;
      gate_bc_[i]->set_data(gate_bc_pool_ptr);
      size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
      gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);

      up_bc_[i]->max_m = max_m;
      up_bc_[i]->set_data(up_bc_pool_ptr);
      size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
      up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);

      down_ba_[i]->max_m = max_m;
      down_ba_[i]->set_data(down_ba_pool_ptr);
      size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
      down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);

      down_bc_[i]->max_m = max_m;
      down_bc_[i]->set_data(down_bc_pool_ptr);
      size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
      down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);

      used_pool_m += max_m;
      used_pool_bytes_a += ba_size;
      used_pool_bytes_bc_gate += bc_gate_size;
      used_pool_bytes_bc_up += bc_up_size;
      used_pool_bytes_ba_down += ba_down_size;
      used_pool_bytes_bc_down += bc_down_size;
    }

    assert(used_pool_m <= pool_count_);
    assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
    assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
    assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
    assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
    assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    auto direct_or_pool = [&](int count, auto&& fn) {
      if (qlen < 10) {
        for (int i = 0; i < count; i++) {
          fn(i);
        }
      } else {
        pool->do_work_stealing_job(count, nullptr, fn, nullptr);
      }
    };

    direct_or_pool(qlen, [&](int i) {
      for (int j = 0; j < k; j++) {
        if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
          continue;
        }
        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
      }
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    direct_or_pool(activated_expert, [this](int task_id) {
      int expert_idx = m_expert_id_map_[task_id];
      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
    });

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int _) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];

          int ith = task_id % nth;
          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
          if (do_up) {
            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    apply_activation(activated_expert, nth, qlen);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int _) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_down_gemm(expert_idx, ith, nth, qlen);
          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        qlen, nullptr,
        [this, output, k, expert_ids, weights](int i) {
          for (int e = 0; e < config_.hidden_size; e += 32) {
            __m512 x0 = _mm512_setzero_ps();
            __m512 x1 = _mm512_setzero_ps();
            for (int j = 0; j < k; j++) {
              if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
                continue;
              }
              __m512 weight = _mm512_set1_ps(weights[i * k + j]);
              __m512 down_output0, down_output1;
              avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
                                                   m_local_pos_[i][j] * config_.hidden_size + e),
                                        &down_output0, &down_output1);
              x0 = _mm512_fmadd_ps(down_output0, weight, x0);
              x1 = _mm512_fmadd_ps(down_output1, weight, x1);
            }
            auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
            f32out[0] = x0;
            f32out[1] = x1;
          }
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    printf(
        "Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
        "%d, qlen: %d\n",
        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
        down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
  }

  void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
    int qlen = 1;
    auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
    auto start_time = std::chrono::high_resolution_clock::now();
    auto last = start_time;
    long q_input_time = 0, up_gate_time = 0, act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
#endif

    int activated_expert = 0;
    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
    for (int i = 0; i < k; i++) {
      if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
        continue;
      }
      m_expert_id_map_[activated_expert] = expert_ids[i];
      m_local_pos_[0][i] = 0;
      m_local_num_[expert_ids[i]] = qlen;
      activated_expert++;
    }

    size_t offset = 0;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
      m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
      m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
      offset += qlen;
    }

    void* gate_bc_pool_ptr = gate_bc_pool_;
    void* up_bc_pool_ptr = up_bc_pool_;
    void* down_ba_pool_ptr = down_ba_pool_;
    void* down_bc_pool_ptr = down_bc_pool_;
    constexpr size_t M_STEP = T::M_STEP;
    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
    size_t used_pool_m = 0;
    size_t used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
           used_pool_bytes_bc_down = 0;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;

      gate_bc_[expert_idx]->max_m = max_m;
      gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
      size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
      gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);

      up_bc_[expert_idx]->max_m = max_m;
      up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
      size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
      up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);

      down_ba_[expert_idx]->max_m = max_m;
      down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
      size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
      down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);

      down_bc_[expert_idx]->max_m = max_m;
      down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
      size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
      down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);

      used_pool_m += max_m;
      used_pool_bytes_bc_gate += bc_gate_size;
      used_pool_bytes_bc_up += bc_up_size;
      used_pool_bytes_ba_down += ba_down_size;
      used_pool_bytes_bc_down += bc_down_size;
    }
    assert(used_pool_m <= pool_count_);
    assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
    assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
    assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
    assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);

    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
    for (int i = 0; i < activated_expert; i++) {
      auto expert_idx = m_expert_id_map_[i];
      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
      gate_up_ba_[expert_idx]->max_m = max_m;
      gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
      size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
      gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
      gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
    }

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    int nth = T::recommended_nth(config_.intermediate_size);
    pool->do_work_stealing_job(
        nth * activated_expert * 2, [](int _) { T::config(); },
        [this, nth, qlen](int task_id2) {
          int task_id = task_id2 / 2;
          bool do_up = task_id2 % 2;
          int expert_idx = m_expert_id_map_[task_id / nth];

          int ith = task_id % nth;
          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
          if (do_up) {
            up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
          } else {
            gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
          }
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    apply_activation(activated_expert, nth, qlen);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    pool->do_work_stealing_job(
        activated_expert, nullptr,
        [this, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id];
          down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    nth = T::recommended_nth(config_.hidden_size);
    pool->do_work_stealing_job(
        nth * activated_expert, [](int _) { T::config(); },
        [this, nth, qlen](int task_id) {
          int expert_idx = m_expert_id_map_[task_id / nth];
          int ith = task_id % nth;
          derived()->do_down_gemm(expert_idx, ith, nth, qlen);
          down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
        },
        nullptr);

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
#endif

    for (int e = 0; e < config_.hidden_size; e += 32) {
      __m512 x0 = _mm512_setzero_ps();
      __m512 x1 = _mm512_setzero_ps();
      for (int j = 0; j < k; j++) {
        if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
          continue;
        }
        __m512 weight = _mm512_set1_ps(weights[j]);
        __m512 down_output0, down_output1;
        avx512_32xbf16_to_32xfp32(
            (__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e),
            &down_output0, &down_output1);
        x0 = _mm512_fmadd_ps(down_output0, weight, x0);
        x1 = _mm512_fmadd_ps(down_output1, weight, x1);
      }
      auto f32out = (__m512*)((float*)output + e);
      f32out[0] = x0;
      f32out[1] = x1;
    }

#ifdef FORWARD_TIME_PROFILE
    {
      auto now_time = std::chrono::high_resolution_clock::now();
      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
      last = now_time;
    }
    auto end_time = std::chrono::high_resolution_clock::now();
    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
    printf(
        "Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
        "up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
        tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
        forward_total_time);
#endif
  }

 protected:
  Derived* derived() { return static_cast<Derived*>(this); }
  const Derived* derived_const() const { return static_cast<const Derived*>(this); }

  // ============================================================================
  // Customization points for buffer creation and size calculation.
  // Default implementations use group_size (for KGroup quantization like K2).
  // Derived classes (like moe.hpp) can override these to ignore group_size.
  // ============================================================================

  size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }
  size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }
  size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }

  std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {
    return derived_const()->make_buffer_a_impl(m, k, data);
  }
  std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {
    return derived_const()->make_buffer_b_impl(n, k, data);
  }
  std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {
    return derived_const()->make_buffer_c_impl(m, n, data);
  }

  void apply_activation(int activated_expert, int nth, int qlen) {
    auto pool = config_.pool->get_subpool(tp_part_idx);
    auto fn = [this, nth](int task_id) {
      int expert_idx = m_expert_id_map_[task_id / nth];
      int ith = task_id % nth;
      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
      for (int i = 0; i < m_local_num_[expert_idx]; i++) {
        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
        for (int j = n_start; j < n_end; j += 32) {
          __m512 gate_val0, gate_val1, up_val0, up_val1;
          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
          __m512 result0 = amx::act_fn(gate_val0, up_val0);
          __m512 result1 = amx::act_fn(gate_val1, up_val1);
          avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
        }
      }
    };

    if (activated_expert == 0) {
      return;
    }

    if (qlen < 10) {
      for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
        fn(task_id);
      }
    } else {
      pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
    }
  }
};

// ============================================================================
// TP_MOE specialization for AMX_MOE_BASE derived classes
// ============================================================================

template <class T, class Derived>
class TP_MOE<AMX_MOE_BASE<T, Derived>> : public TP_MOE_Common<AMX_MOE_BASE<T, Derived>> {
 public:
  using TP_MOE_Common<AMX_MOE_BASE<T, Derived>>::TP_MOE_Common;

  // Default load_weights implementation - can be overridden by derived TP_MOE classes
  void load_weights() override { throw std::runtime_error("Not Implemented"); }

  void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
                                    const std::vector<uintptr_t>& w13_weight_ptrs,
                                    const std::vector<uintptr_t>& w13_scale_ptrs,
                                    const std::vector<uintptr_t>& w2_weight_ptrs,
                                    const std::vector<uintptr_t>& w2_scale_ptrs) {
    throw std::runtime_error("Not Implemented");
  }

  void merge_results(int qlen, void* output, bool incremental) override {
    auto& config = this->config;
    auto& tp_count = this->tp_count;
    auto& local_output_numa = this->local_output_numa;
    auto& tp_configs = this->tp_configs;

    auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {
      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
      if (incremental) {
        for (int e = 0; e < config.hidden_size; e += 32) {
          __m512 x0, x1;
          avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
          *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
        }
      }
      for (int i = 1; i < tp_count; i++) {
        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
        for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
        }
      }
      for (int e = 0; e < config.hidden_size; e += 32) {
        __m512 x0 = *(__m512*)(merge_to + e);
        __m512 x1 = *(__m512*)(merge_to + e + 16);
        avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
      }
    };

    auto pool = config.pool;

    auto direct_or_pool = [&](int count, auto&& fn) {
      if (qlen < 10) {
        for (int i = 0; i < count; i++) {
          fn(i);
        }
      } else {
        pool->do_work_stealing_job(count, nullptr, fn, nullptr);
      }
    };

    direct_or_pool(qlen, merge_fn);
  }

  void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }
};

#endif  // CPUINFER_OPERATOR_AMX_MOE_BASE_H
@@ -27,6 +27,12 @@ dependencies = [
    "numpy>=1.24.0",
    "triton>=2.0.0",
    "gguf>=0.17.0",
    # CLI dependencies
    "typer[all]>=0.9.0",
    "rich>=13.0.0",
    "pyyaml>=6.0",
    "httpx>=0.25.0",
    "packaging>=23.0",
    # Development dependencies
    "black>=25.9.0",
]
@@ -37,19 +43,35 @@ test = [
    "psutil>=5.9.0",
]

[project.scripts]
kt = "kt_kernel.cli.main:main"

[project.urls]
Homepage = "https://github.com/kvcache-ai"

[tool.setuptools]
packages = ["kt_kernel", "kt_kernel.utils"]
packages = [
    "kt_kernel",
    "kt_kernel.utils",
    "kt_kernel.cli",
    "kt_kernel.cli.commands",
    "kt_kernel.cli.config",
    "kt_kernel.cli.utils",
    "kt_kernel.cli.completions",
]
include-package-data = true

[tool.setuptools.package-dir]
kt_kernel = "python"
"kt_kernel.utils" = "python/utils"
"kt_kernel.cli" = "python/cli"
"kt_kernel.cli.commands" = "python/cli/commands"
"kt_kernel.cli.config" = "python/cli/config"
"kt_kernel.cli.utils" = "python/cli/utils"
"kt_kernel.cli.completions" = "python/cli/completions"

[tool.setuptools.package-data]
# (empty) placeholder if you later add resources
"kt_kernel.cli.completions" = ["*.bash", "*.fish", "_kt"]

[tool.setuptools.exclude-package-data]
# (empty)

@@ -37,11 +37,13 @@ from __future__ import annotations

# Detect CPU and load optimal extension variant
from ._cpu_detect import initialize as _initialize_cpu

_kt_kernel_ext, __cpu_variant__ = _initialize_cpu()

# Make the extension module available to other modules in this package
import sys
sys.modules['kt_kernel_ext'] = _kt_kernel_ext

sys.modules["kt_kernel_ext"] = _kt_kernel_ext

# Also expose kt_kernel_ext as an attribute for backward compatibility
kt_kernel_ext = _kt_kernel_ext
@@ -53,25 +55,28 @@ from .experts import KTMoEWrapper
try:
    # Try to get version from installed package metadata (works in installed environment)
    from importlib.metadata import version, PackageNotFoundError

    try:
        __version__ = version('kt-kernel')
        __version__ = version("kt-kernel")
    except PackageNotFoundError:
        # Package not installed, try to read from source tree version.py
        import os
        _root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'version.py')

        _root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "version.py")
        if os.path.exists(_root_version_file):
            _version_ns = {}
            with open(_root_version_file, 'r', encoding='utf-8') as f:
            with open(_root_version_file, "r", encoding="utf-8") as f:
                exec(f.read(), _version_ns)
            __version__ = _version_ns.get('__version__', '0.4.3')
            __version__ = _version_ns.get("__version__", "0.4.3")
        else:
            __version__ = "0.4.3"
except ImportError:
    # Python < 3.8, fallback to pkg_resources or hardcoded version
    try:
        from pkg_resources import get_distribution, DistributionNotFound

        try:
            __version__ = get_distribution('kt-kernel').version
            __version__ = get_distribution("kt-kernel").version
        except DistributionNotFound:
            __version__ = "0.4.3"
    except ImportError:

@@ -17,6 +17,7 @@ Example:
|
||||
>>> os.environ['KT_KERNEL_CPU_VARIANT'] = 'avx2'
|
||||
>>> import kt_kernel # Will use AVX2 variant
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -35,82 +36,82 @@ def detect_cpu_features():
|
||||
str: 'amx', 'avx512', or 'avx2'
|
||||
"""
|
||||
# Check environment override
|
||||
variant = os.environ.get('KT_KERNEL_CPU_VARIANT', '').lower()
|
||||
if variant in ['amx', 'avx512', 'avx2']:
|
||||
```python
    variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
    if variant in ["amx", "avx512", "avx2"]:
        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print(f"[kt-kernel] Using environment override: {variant}")
        return variant

    # Try to read /proc/cpuinfo on Linux
    try:
        with open("/proc/cpuinfo", "r") as f:
            cpuinfo = f.read().lower()

        # Check for AMX support (Intel Sapphire Rapids+)
        # AMX requires amx_tile, amx_int8, and amx_bf16
        amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
        has_amx = all(flag in cpuinfo for flag in amx_flags)

        if has_amx:
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
            return "amx"

        # Check for AVX512 support
        # AVX512F is the foundation for all AVX512 variants
        if "avx512f" in cpuinfo:
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
            return "avx512"

        # Check for AVX2 support
        if "avx2" in cpuinfo:
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
            return "avx2"

        # Fallback to AVX2 (should be rare on modern CPUs)
        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
        return "avx2"

    except FileNotFoundError:
        # /proc/cpuinfo doesn't exist (not Linux or in container)
        # Try cpufeature package as fallback
        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print("[kt-kernel] /proc/cpuinfo not found, trying cpufeature package")

        try:
            import cpufeature

            # Check for AMX
            if cpufeature.CPUFeature.get("AMX_TILE", False):
                if os.environ.get("KT_KERNEL_DEBUG") == "1":
                    print("[kt-kernel] Detected AMX support via cpufeature")
                return "amx"

            # Check for AVX512
            if cpufeature.CPUFeature.get("AVX512F", False):
                if os.environ.get("KT_KERNEL_DEBUG") == "1":
                    print("[kt-kernel] Detected AVX512 support via cpufeature")
                return "avx512"

            # Fallback to AVX2
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Using AVX2 fallback via cpufeature")
            return "avx2"

        except ImportError:
            # cpufeature not available - ultimate fallback
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] cpufeature not available, using AVX2 fallback")
            return "avx2"

    except Exception as e:
        # Any other error - safe fallback
        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print(f"[kt-kernel] Error during CPU detection: {e}, using AVX2 fallback")
        return "avx2"
```
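The selection logic above reduces to a small pure function over the CPU flag string. The sketch below (a hypothetical `pick_variant` helper, not part of kt-kernel) mirrors the precedence: explicit `KT_KERNEL_CPU_VARIANT` override first, then AMX (which requires all three `amx_*` flags), then AVX512F, then AVX2 as the catch-all:

```python
def pick_variant(cpuinfo: str, env: dict) -> str:
    """Mirror kt-kernel's selection order: env override > AMX > AVX512 > AVX2."""
    override = env.get("KT_KERNEL_CPU_VARIANT", "").lower()
    if override in ("amx", "avx512", "avx2"):
        return override
    flags = cpuinfo.lower()
    # AMX needs all of amx_tile, amx_int8, amx_bf16
    if all(f in flags for f in ("amx_tile", "amx_int8", "amx_bf16")):
        return "amx"
    if "avx512f" in flags:
        return "avx512"
    return "avx2"

print(pick_variant("flags: avx2 avx512f amx_tile amx_int8 amx_bf16", {}))  # amx
print(pick_variant("flags: avx2 avx512f", {}))  # avx512
print(pick_variant("flags: avx2 avx512f", {"KT_KERNEL_CPU_VARIANT": "avx2"}))  # avx2
```

Factoring the decision this way makes the precedence testable without touching `/proc/cpuinfo`.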
```python
def load_extension(variant):
```

@@ -148,51 +149,53 @@ def load_extension(variant):

```python
        kt_kernel_dir = os.path.dirname(os.path.abspath(__file__))

        # Try multi-variant naming first
        pattern = os.path.join(kt_kernel_dir, f"_kt_kernel_ext_{variant}.*.so")
        so_files = glob.glob(pattern)

        if not so_files:
            # Try single-variant naming (fallback for builds without CPUINFER_BUILD_ALL_VARIANTS)
            pattern = os.path.join(kt_kernel_dir, "kt_kernel_ext.*.so")
            so_files = glob.glob(pattern)

            if so_files:
                if os.environ.get("KT_KERNEL_DEBUG") == "1":
                    print(f"[kt-kernel] Multi-variant {variant} not found, using single-variant build")
            else:
                raise ImportError(
                    f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)"
                )

        so_file = so_files[0]

        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print(f"[kt-kernel] Loading {variant} from: {so_file}")

        # Load the module manually
        # The module exports PyInit_kt_kernel_ext, so we use that as the module name
        spec = importlib.util.spec_from_file_location("kt_kernel_ext", so_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Failed to create spec for {so_file}")

        ext = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(ext)

        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print(f"[kt-kernel] Successfully loaded {variant.upper()} variant")
        return ext

    except (ImportError, ModuleNotFoundError, FileNotFoundError) as e:
        if os.environ.get("KT_KERNEL_DEBUG") == "1":
            print(f"[kt-kernel] Failed to load {variant} variant: {e}")

        # Automatic fallback to next best variant
        if variant == "amx":
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Falling back from AMX to AVX512")
            return load_extension("avx512")
        elif variant == "avx512":
            if os.environ.get("KT_KERNEL_DEBUG") == "1":
                print("[kt-kernel] Falling back from AVX512 to AVX2")
            return load_extension("avx2")
        else:
            # AVX2 is the last fallback - if this fails, we can't continue
            raise ImportError(
```

@@ -221,13 +224,13 @@ def initialize():

```python
    # Detect CPU features
    variant = detect_cpu_features()

    if os.environ.get("KT_KERNEL_DEBUG") == "1":
        print(f"[kt-kernel] Selected CPU variant: {variant}")

    # Load the appropriate extension
    ext = load_extension(variant)

    if os.environ.get("KT_KERNEL_DEBUG") == "1":
        print(f"[kt-kernel] Extension module loaded: {ext.__name__}")

    return ext, variant
```
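`load_extension()` uses importlib's file-location machinery instead of a normal `import`, because the shared object's filename varies per CPU variant. The self-contained sketch below demonstrates the same `spec_from_file_location` / `module_from_spec` / `exec_module` sequence on a plain `.py` stand-in (a real build loads a compiled `_kt_kernel_ext_<variant>.*.so`; the module contents here are illustrative):

```python
import importlib.util
import tempfile
from pathlib import Path

# Write a stand-in module to disk (in kt-kernel this is a compiled .so).
tmp = Path(tempfile.mkdtemp())
mod_path = tmp / "kt_kernel_ext_demo.py"
mod_path.write_text("VARIANT = 'avx512'\n")

# Same loading mechanism as load_extension(): build a spec from the file
# location, create a module object from it, then execute it.
spec = importlib.util.spec_from_file_location("kt_kernel_ext", mod_path)
ext = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ext)

print(ext.VARIANT)  # avx512
```

The module name passed to `spec_from_file_location` must match the extension's `PyInit_` symbol, which is why every variant file is loaded under the same `kt_kernel_ext` name.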
kt-kernel/python/cli/__init__.py (Normal file, 8 lines)
@@ -0,0 +1,8 @@

```python
"""
KTransformers CLI - A unified command-line interface for KTransformers.

This CLI provides a user-friendly interface to all KTransformers functionality,
including model inference, fine-tuning, benchmarking, and more.
"""

__version__ = "0.1.0"
```
kt-kernel/python/cli/commands/__init__.py (Normal file, 3 lines)
@@ -0,0 +1,3 @@

```python
"""
Command modules for kt-cli.
"""
```
kt-kernel/python/cli/commands/bench.py (Normal file, 274 lines)
@@ -0,0 +1,274 @@

```python
"""
Bench commands for kt-cli.

Runs benchmarks for performance testing.
"""

import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    console,
    print_error,
    print_info,
    print_step,
    print_success,
)


class BenchType(str, Enum):
    """Benchmark type."""

    INFERENCE = "inference"
    MLA = "mla"
    MOE = "moe"
    LINEAR = "linear"
    ATTENTION = "attention"
    ALL = "all"


def bench(
    type: BenchType = typer.Option(
        BenchType.ALL,
        "--type",
        "-t",
        help="Benchmark type",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model to benchmark",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output file for results (JSON)",
    ),
    iterations: int = typer.Option(
        10,
        "--iterations",
        "-n",
        help="Number of iterations",
    ),
) -> None:
    """Run full benchmark suite."""
    console.print()
    print_step(t("bench_starting"))
    print_info(t("bench_type", type=type.value))
    console.print()

    if type == BenchType.ALL:
        _run_all_benchmarks(model, output, iterations)
    elif type == BenchType.INFERENCE:
        _run_inference_benchmark(model, output, iterations)
    elif type == BenchType.MLA:
        _run_component_benchmark("mla", output, iterations)
    elif type == BenchType.MOE:
        _run_component_benchmark("moe", output, iterations)
    elif type == BenchType.LINEAR:
        _run_component_benchmark("linear", output, iterations)
    elif type == BenchType.ATTENTION:
        _run_component_benchmark("attention", output, iterations)

    console.print()
    print_success(t("bench_complete"))
    if output:
        console.print(f" Results saved to: {output}")
    console.print()


def microbench(
    component: str = typer.Argument(
        "moe",
        help="Component to benchmark (moe, mla, linear, attention)",
    ),
    batch_size: int = typer.Option(
        1,
        "--batch-size",
        "-b",
        help="Batch size",
    ),
    seq_len: int = typer.Option(
        1,
        "--seq-len",
        "-s",
        help="Sequence length",
    ),
    iterations: int = typer.Option(
        100,
        "--iterations",
        "-n",
        help="Number of iterations",
    ),
    warmup: int = typer.Option(
        10,
        "--warmup",
        "-w",
        help="Warmup iterations",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output file for results (JSON)",
    ),
) -> None:
    """Run micro-benchmark for specific components."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)

    # NOTE: the code below is unreachable until the "coming soon" gate above is removed.
    # Try to find the benchmark script
    kt_kernel_path = _find_kt_kernel_path()

    if kt_kernel_path is None:
        print_error("kt-kernel not found. Install with: kt install inference")
        raise typer.Exit(1)

    bench_dir = kt_kernel_path / "bench"

    # Map component to script
    component_scripts = {
        "moe": "bench_moe.py",
        "mla": "bench_mla.py",
        "linear": "bench_linear.py",
        "attention": "bench_attention.py",
        "mlp": "bench_mlp.py",
    }

    script_name = component_scripts.get(component.lower())
    if script_name is None:
        print_error(f"Unknown component: {component}")
        console.print(f"Available: {', '.join(component_scripts.keys())}")
        raise typer.Exit(1)

    script_path = bench_dir / script_name
    if not script_path.exists():
        print_error(f"Benchmark script not found: {script_path}")
        raise typer.Exit(1)

    # Run benchmark
    cmd = [
        sys.executable,
        str(script_path),
        "--batch-size",
        str(batch_size),
        "--seq-len",
        str(seq_len),
        "--iterations",
        str(iterations),
        "--warmup",
        str(warmup),
    ]

    if output:
        cmd.extend(["--output", str(output)])

    console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
    console.print()

    try:
        process = subprocess.run(cmd)

        if process.returncode == 0:
            console.print()
            print_success(t("bench_complete"))
            if output:
                console.print(f" Results saved to: {output}")
        else:
            print_error(f"Benchmark failed with exit code {process.returncode}")
            raise typer.Exit(process.returncode)

    except FileNotFoundError as e:
        print_error(f"Failed to run benchmark: {e}")
        raise typer.Exit(1)


def _find_kt_kernel_path() -> Optional[Path]:
    """Find the kt-kernel installation path."""
    try:
        import kt_kernel

        return Path(kt_kernel.__file__).parent.parent
    except ImportError:
        pass

    # Check common locations
    possible_paths = [
        Path.home() / "Projects" / "ktransformers" / "kt-kernel",
        Path("/opt/ktransformers/kt-kernel"),
        Path.cwd() / "kt-kernel",
    ]

    for path in possible_paths:
        if path.exists() and (path / "bench").exists():
            return path

    return None


def _run_all_benchmarks(model: Optional[str], output: Optional[Path], iterations: int) -> None:
    """Run all benchmarks."""
    components = ["moe", "mla", "linear", "attention"]

    for component in components:
        console.print(f"\n[bold]Running {component} benchmark...[/bold]")
        _run_component_benchmark(component, None, iterations)


def _run_inference_benchmark(model: Optional[str], output: Optional[Path], iterations: int) -> None:
    """Run inference benchmark."""
    if model is None:
        print_error("Model required for inference benchmark. Use --model flag.")
        raise typer.Exit(1)

    print_info(f"Running inference benchmark on {model}...")
    console.print()
    console.print("[dim]This will start the server and run test requests.[/dim]")
    console.print()

    # TODO: Implement actual inference benchmarking
    print_error("Inference benchmarking not yet implemented.")


def _run_component_benchmark(component: str, output: Optional[Path], iterations: int) -> None:
    """Run a component benchmark."""
    kt_kernel_path = _find_kt_kernel_path()

    if kt_kernel_path is None:
        print_error("kt-kernel not found.")
        return

    bench_dir = kt_kernel_path / "bench"
    script_map = {
        "moe": "bench_moe.py",
        "mla": "bench_mla.py",
        "linear": "bench_linear.py",
        "attention": "bench_attention.py",
    }

    script_name = script_map.get(component)
    if script_name is None:
        print_error(f"Unknown component: {component}")
        return

    script_path = bench_dir / script_name
    if not script_path.exists():
        print_error(f"Script not found: {script_path}")
        return

    cmd = [sys.executable, str(script_path), "--iterations", str(iterations)]

    try:
        subprocess.run(cmd)
    except Exception as e:
        print_error(f"Benchmark failed: {e}")
```
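The command assembly in `microbench` can be factored into a pure helper, which makes the argument stringification easy to check. The `build_bench_cmd` function below is a hypothetical refactor for illustration, not part of bench.py; every numeric option is stringified before being handed to `subprocess.run`:

```python
import sys

def build_bench_cmd(script_path, batch_size, seq_len, iterations, warmup, output=None):
    """Assemble the benchmark subprocess command the way microbench does."""
    cmd = [
        sys.executable, str(script_path),
        "--batch-size", str(batch_size),
        "--seq-len", str(seq_len),
        "--iterations", str(iterations),
        "--warmup", str(warmup),
    ]
    if output:
        cmd.extend(["--output", str(output)])
    return cmd

cmd = build_bench_cmd("bench/bench_moe.py", 1, 1, 100, 10, "results.json")
print(cmd[2:4])  # ['--batch-size', '1']
```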
kt-kernel/python/cli/commands/chat.py (Normal file, 437 lines)
@@ -0,0 +1,437 @@

```python
"""
Chat command for kt-cli.

Provides interactive chat interface with running model server.
"""

import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt, Confirm

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    console,
    print_error,
    print_info,
    print_success,
    print_warning,
)

# Try to import OpenAI SDK
try:
    from openai import OpenAI

    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False


def chat(
    host: Optional[str] = typer.Option(
        None,
        "--host",
        "-H",
        help="Server host address",
    ),
    port: Optional[int] = typer.Option(
        None,
        "--port",
        "-p",
        help="Server port",
    ),
    model: Optional[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model name (if server hosts multiple models)",
    ),
    temperature: float = typer.Option(
        0.7,
        "--temperature",
        "-t",
        help="Sampling temperature (0.0 to 2.0)",
    ),
    max_tokens: int = typer.Option(
        2048,
        "--max-tokens",
        help="Maximum tokens to generate",
    ),
    system_prompt: Optional[str] = typer.Option(
        None,
        "--system",
        "-s",
        help="System prompt",
    ),
    save_history: bool = typer.Option(
        True,
        "--save-history/--no-save-history",
        help="Save conversation history",
    ),
    history_file: Optional[Path] = typer.Option(
        None,
        "--history-file",
        help="Path to save conversation history",
    ),
    stream: bool = typer.Option(
        True,
        "--stream/--no-stream",
        help="Enable streaming output",
    ),
) -> None:
    """Start interactive chat with a running model server.

    Examples:
        kt chat                           # Connect to default server
        kt chat --host 127.0.0.1 -p 8080  # Connect to specific server
        kt chat -t 0.9 --max-tokens 4096  # Adjust generation parameters
    """
    if not HAS_OPENAI:
        print_error("OpenAI Python SDK is required for chat functionality.")
        console.print()
        console.print("Install it with:")
        console.print("  pip install openai")
        raise typer.Exit(1)

    settings = get_settings()

    # Resolve server connection
    final_host = host or settings.get("server.host", "127.0.0.1")
    final_port = port or settings.get("server.port", 30000)

    # Construct base URL for OpenAI-compatible API
    base_url = f"http://{final_host}:{final_port}/v1"

    console.print()
    console.print(
        Panel.fit(
            f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
            f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
            f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
            f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
            border_style="cyan",
        )
    )
    console.print()

    # Check for proxy environment variables
    proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy", "ALL_PROXY", "all_proxy"]
    detected_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)}

    if detected_proxies:
        proxy_info = ", ".join(f"{k}={v}" for k, v in detected_proxies.items())
        console.print()
        print_warning(t("chat_proxy_detected"))
        console.print(f" [dim]{proxy_info}[/dim]")
        console.print()

        use_proxy = Confirm.ask(t("chat_proxy_confirm"), default=False)

        if not use_proxy:
            # Temporarily disable proxy for this connection
            for var in proxy_vars:
                if var in os.environ:
                    del os.environ[var]
            print_info(t("chat_proxy_disabled"))
            console.print()

    # Initialize OpenAI client
    try:
        client = OpenAI(
            base_url=base_url,
            api_key="EMPTY",  # SGLang doesn't require API key
        )

        # Test connection
        print_info("Connecting to server...")
        models = client.models.list()
        available_models = [m.id for m in models.data]

        if not available_models:
            print_error("No models available on server")
            raise typer.Exit(1)

        # Select model
        if model:
            if model not in available_models:
                print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
                selected_model = available_models[0]
            else:
                selected_model = model
        else:
            selected_model = available_models[0]

        print_success(f"Connected to model: {selected_model}")
        console.print()

    except Exception as e:
        print_error(f"Failed to connect to server: {e}")
        console.print()
        console.print("Make sure the model server is running:")
        console.print("  kt run <model>")
        raise typer.Exit(1)

    # Initialize conversation history
    messages = []

    # Add system prompt if provided
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Setup history file
    if save_history:
        if history_file is None:
            history_dir = settings.config_dir / "chat_history"
            history_dir.mkdir(parents=True, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            history_file = history_dir / f"chat_{timestamp}.json"
        else:
            history_file = Path(history_file)
            history_file.parent.mkdir(parents=True, exist_ok=True)

    # Main chat loop
    try:
        while True:
            # Get user input
            try:
                user_input = Prompt.ask("[bold green]You[/bold green]")
            except (EOFError, KeyboardInterrupt):
                console.print()
                print_info("Goodbye!")
                break

            if not user_input.strip():
                continue

            # Handle special commands
            if user_input.startswith("/"):
                if _handle_command(user_input, messages, temperature, max_tokens):
                    continue
                else:
                    break  # Exit command

            # Add user message to history
            messages.append({"role": "user", "content": user_input})

            # Generate response
            console.print()
            console.print("[bold cyan]Assistant[/bold cyan]")

            try:
                if stream:
                    # Streaming response
                    response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
                else:
                    # Non-streaming response
                    response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)

                # Add assistant response to history
                messages.append({"role": "assistant", "content": response_content})

                console.print()

            except Exception as e:
                print_error(f"Error generating response: {e}")
                # Remove the user message that caused the error
                messages.pop()
                continue

            # Save history if enabled
            if save_history:
                _save_history(history_file, messages, selected_model)

    except KeyboardInterrupt:
        console.print()
        console.print()
        print_info("Chat interrupted. Goodbye!")

    # Final history save
    if save_history and messages:
        _save_history(history_file, messages, selected_model)
        console.print(f"[dim]History saved to: {history_file}[/dim]")
    console.print()


def _stream_response(
    client: "OpenAI",
    model: str,
    messages: list,
    temperature: float,
    max_tokens: int,
) -> str:
    """Generate streaming response and display in real-time."""
    response_content = ""

    try:
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True,
        )

        for chunk in stream:
            if chunk.choices[0].delta.content:
                content = chunk.choices[0].delta.content
                response_content += content
                console.print(content, end="")

        console.print()  # Newline after streaming

    except Exception as e:
        raise Exception(f"Streaming error: {e}")

    return response_content


def _generate_response(
    client: "OpenAI",
    model: str,
    messages: list,
    temperature: float,
    max_tokens: int,
) -> str:
    """Generate non-streaming response."""
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
        )

        content = response.choices[0].message.content

        # Display as markdown
        md = Markdown(content)
        console.print(md)

        return content

    except Exception as e:
        raise Exception(f"Generation error: {e}")


def _handle_command(command: str, messages: list, temperature: float, max_tokens: int) -> bool:
    """Handle special commands. Returns True to continue chat, False to exit."""
    cmd = command.lower().strip()

    if cmd in ["/quit", "/exit", "/q"]:
        console.print()
        print_info("Goodbye!")
        return False

    elif cmd in ["/help", "/h"]:
        console.print()
        console.print(
            Panel(
                "[bold]Available Commands:[/bold]\n\n"
                "/help, /h - Show this help message\n"
                "/quit, /exit, /q - Exit chat\n"
                "/clear, /c - Clear conversation history\n"
                "/history, /hist - Show conversation history\n"
                "/info, /i - Show current settings\n"
                "/retry, /r - Regenerate last response",
                title="Help",
                border_style="cyan",
            )
        )
        console.print()
        return True

    elif cmd in ["/clear", "/c"]:
        messages.clear()
        console.print()
        print_success("Conversation history cleared")
        console.print()
        return True

    elif cmd in ["/history", "/hist"]:
        console.print()
        if not messages:
            print_info("No conversation history")
        else:
            console.print(
                Panel(
                    _format_history(messages),
                    title=f"History ({len(messages)} messages)",
                    border_style="cyan",
                )
            )
        console.print()
        return True

    elif cmd in ["/info", "/i"]:
        console.print()
        console.print(
            Panel(
                f"[bold]Current Settings:[/bold]\n\n"
                f"Temperature: [cyan]{temperature}[/cyan]\n"
                f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
                f"Messages: [cyan]{len(messages)}[/cyan]",
                title="Info",
                border_style="cyan",
            )
        )
        console.print()
        return True

    elif cmd in ["/retry", "/r"]:
        if len(messages) >= 2 and messages[-1]["role"] == "assistant":
            # Remove last assistant response
            messages.pop()
            print_info("Retrying last response...")
            console.print()
        else:
            print_warning("No previous response to retry")
            console.print()
        return True

    else:
        print_warning(f"Unknown command: {command}")
        console.print("[dim]Type /help for available commands[/dim]")
        console.print()
        return True


def _format_history(messages: list) -> str:
    """Format conversation history for display."""
    lines = []
    for i, msg in enumerate(messages, 1):
        role = msg["role"].capitalize()
        content = msg["content"]

        # Truncate long messages
        if len(content) > 200:
            content = content[:200] + "..."

        lines.append(f"[bold]{i}. {role}:[/bold] {content}")

    return "\n\n".join(lines)


def _save_history(file_path: Path, messages: list, model: str) -> None:
    """Save conversation history to file."""
    try:
        history_data = {
            "model": model,
            "timestamp": datetime.now().isoformat(),
            "messages": messages,
        }

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(history_data, f, indent=2, ensure_ascii=False)

    except Exception as e:
        print_warning(f"Failed to save history: {e}")
```
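`_save_history()` writes a small JSON envelope: the model name, an ISO timestamp, and the message list. The round-trip below shows that on-disk format using a temporary directory (the file name and model name are illustrative):

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

# The same envelope _save_history() produces: model, timestamp, messages.
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi there"},
]
history_file = Path(tempfile.mkdtemp()) / "chat_demo.json"
history_file.write_text(
    json.dumps(
        {"model": "demo-model", "timestamp": datetime.now().isoformat(), "messages": messages},
        indent=2,
        ensure_ascii=False,
    ),
    encoding="utf-8",
)

restored = json.loads(history_file.read_text(encoding="utf-8"))
print(restored["model"])          # demo-model
print(len(restored["messages"]))  # 2
```

Because the envelope keeps the raw `messages` list, a saved file can be replayed directly as the `messages` argument of a chat-completions request.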
167
kt-kernel/python/cli/commands/config.py
Normal file
167
kt-kernel/python/cli/commands/config.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Config command for kt-cli.
|
||||
|
||||
Manages kt-cli configuration.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.syntax import Syntax
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
from kt_kernel.cli.i18n import t
|
||||
from kt_kernel.cli.utils.console import confirm, console, print_error, print_success
|
||||
|
||||
app = typer.Typer(help="Manage kt-cli configuration")
|
||||
|
||||
|
||||
@app.command(name="init")
|
||||
def init() -> None:
|
||||
"""Initialize or re-run the first-time setup wizard."""
|
||||
from kt_kernel.cli.main import _show_first_run_setup
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
_show_first_run_setup(settings)
|
||||
|
||||
|
||||
@app.command(name="show")
|
||||
def show(
|
||||
key: Optional[str] = typer.Argument(None, help="Configuration key to show (e.g., server.port)"),
|
||||
) -> None:
|
||||
"""Show current configuration."""
|
||||
settings = get_settings()
|
||||
|
||||
if key:
|
||||
value = settings.get(key)
|
||||
if value is not None:
|
||||
if isinstance(value, (dict, list)):
|
||||
console.print(yaml.dump({key: value}, default_flow_style=False, allow_unicode=True))
|
||||
else:
|
||||
console.print(t("config_get_value", key=key, value=value))
|
||||
else:
|
||||
print_error(t("config_get_not_found", key=key))
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
console.print(f"\n[bold]{t('config_show_title')}[/bold]\n")
|
||||
console.print(f"[dim]{t('config_file_location', path=str(settings.config_path))}[/dim]\n")
|
||||
|
||||
config_yaml = yaml.dump(settings.get_all(), default_flow_style=False, allow_unicode=True)
|
||||
syntax = Syntax(config_yaml, "yaml", theme="monokai", line_numbers=False)
|
||||
console.print(syntax)
|
||||
|
||||
|
||||
@app.command(name="set")
|
||||
def set_config(
|
||||
key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
|
||||
value: str = typer.Argument(..., help="Value to set"),
|
||||
) -> None:
|
||||
"""Set a configuration value."""
|
||||
settings = get_settings()
|
||||
|
||||
# Try to parse value as JSON/YAML for complex types
|
||||
parsed_value = _parse_value(value)
|
||||
|
||||
settings.set(key, parsed_value)
|
||||
print_success(t("config_set_success", key=key, value=parsed_value))
|
||||
|
||||
|
||||
@app.command(name="get")
|
||||
def get_config(
|
||||
    key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
) -> None:
    """Get a configuration value."""
    settings = get_settings()
    value = settings.get(key)

    if value is not None:
        if isinstance(value, (dict, list)):
            console.print(yaml.dump(value, default_flow_style=False, allow_unicode=True))
        else:
            console.print(str(value))
    else:
        print_error(t("config_get_not_found", key=key))
        raise typer.Exit(1)


@app.command(name="reset")
def reset(
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"),
) -> None:
    """Reset configuration to defaults."""
    if not yes:
        if not confirm(t("config_reset_confirm"), default=False):
            raise typer.Abort()

    settings = get_settings()
    settings.reset()
    print_success(t("config_reset_success"))


@app.command(name="path")
def path() -> None:
    """Show configuration file path."""
    settings = get_settings()
    console.print(str(settings.config_path))


@app.command(name="model-path-list", deprecated=True, hidden=True)
def model_path_list() -> None:
    """[Deprecated] Use 'kt model path-list' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-list' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-list"])


@app.command(name="model-path-add", deprecated=True, hidden=True)
def model_path_add(
    path: str = typer.Argument(..., help="Path to add"),
) -> None:
    """[Deprecated] Use 'kt model path-add' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-add' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-add", path])


@app.command(name="model-path-remove", deprecated=True, hidden=True)
def model_path_remove(
    path: str = typer.Argument(..., help="Path to remove"),
) -> None:
    """[Deprecated] Use 'kt model path-remove' instead."""
    console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-remove' instead.[/yellow]\n")
    import subprocess

    subprocess.run(["kt", "model", "path-remove", path])


def _parse_value(value: str):
    """Parse a string value into the appropriate Python type."""
    # Try boolean
    if value.lower() in ("true", "yes", "on", "1"):
        return True
    if value.lower() in ("false", "no", "off", "0"):
        return False

    # Try integer
    try:
        return int(value)
    except ValueError:
        pass

    # Try float
    try:
        return float(value)
    except ValueError:
        pass

    # Try YAML/JSON parsing for lists/dicts
    try:
        parsed = yaml.safe_load(value)
        if isinstance(parsed, (dict, list)):
            return parsed
    except yaml.YAMLError:
        pass

    # Fall back to the raw string
    return value
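The coercion order in `_parse_value` matters: the boolean branch runs before the integer branch, so `"1"` and `"0"` come back as booleans rather than ints. A minimal standalone sketch of the same cascade (with `json.loads` standing in for `yaml.safe_load`, so the example needs no third-party package):

```python
import json


def parse_value(value: str):
    """Coerce a CLI string: bool -> int -> float -> list/dict -> str.

    Standalone sketch of _parse_value above; json.loads stands in for
    yaml.safe_load to keep the example dependency-free.
    """
    if value.lower() in ("true", "yes", "on", "1"):
        return True
    if value.lower() in ("false", "no", "off", "0"):
        return False
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        pass
    try:
        parsed = json.loads(value)
        if isinstance(parsed, (dict, list)):
            return parsed
    except json.JSONDecodeError:
        pass
    return value


# The boolean branch wins first, so "1" becomes True, not the int 1.
print(parse_value("1"))      # True
print(parse_value("42"))
print(parse_value("[1, 2]"))
```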
kt-kernel/python/cli/commands/doctor.py (new file, 394 lines)
@@ -0,0 +1,394 @@
"""
Doctor command for kt-cli.

Diagnoses environment issues and provides recommendations.
"""

import platform
import shutil
from pathlib import Path
from typing import Optional

import typer
from rich.table import Table

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_error, print_info, print_success, print_warning
from kt_kernel.cli.utils.environment import (
    check_docker,
    detect_available_ram_gb,
    detect_cpu_info,
    detect_cuda_version,
    detect_disk_space_gb,
    detect_env_managers,
    detect_gpus,
    detect_memory_info,
    detect_ram_gb,
    get_installed_package_version,
)


def doctor(
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed diagnostics"),
) -> None:
    """Diagnose environment issues."""
    console.print(f"\n[bold]{t('doctor_title')}[/bold]\n")

    issues_found = False
    checks = []

    # 1. Python version
    python_version = platform.python_version()
    python_ok = _check_python_version(python_version)
    checks.append(
        {
            "name": t("doctor_check_python"),
            "status": "ok" if python_ok else "error",
            "value": python_version,
            "hint": "Python 3.10+ required" if not python_ok else None,
        }
    )
    if not python_ok:
        issues_found = True

    # 2. CUDA availability
    cuda_version = detect_cuda_version()
    checks.append(
        {
            "name": t("doctor_check_cuda"),
            "status": "ok" if cuda_version else "warning",
            "value": cuda_version or t("version_cuda_not_found"),
            "hint": "CUDA is optional but recommended for GPU acceleration" if not cuda_version else None,
        }
    )

    # 3. GPU detection
    gpus = detect_gpus()
    if gpus:
        gpu_names = ", ".join(g.name for g in gpus)
        total_vram = sum(g.vram_gb for g in gpus)
        checks.append(
            {
                "name": t("doctor_check_gpu"),
                "status": "ok",
                "value": t("doctor_gpu_found", count=len(gpus), names=gpu_names),
                "hint": f"Total VRAM: {total_vram}GB",
            }
        )
    else:
        checks.append(
            {
                "name": t("doctor_check_gpu"),
                "status": "warning",
                "value": t("doctor_gpu_not_found"),
                "hint": "GPU recommended for best performance",
            }
        )

    # 4. CPU information
    cpu_info = detect_cpu_info()
    checks.append(
        {
            "name": t("doctor_check_cpu"),
            "status": "ok",
            "value": t("doctor_cpu_info", name=cpu_info.name, cores=cpu_info.cores, threads=cpu_info.threads),
            "hint": None,
        }
    )

    # 5. CPU instruction sets (critical for kt-kernel)
    isa_list = cpu_info.instruction_sets
    # Check for recommended instruction sets
    recommended_isa = {"AVX2", "AVX512F", "AMX-INT8"}
    has_recommended = bool(set(isa_list) & recommended_isa)
    has_avx2 = "AVX2" in isa_list
    has_avx512 = any(isa.startswith("AVX512") for isa in isa_list)
    has_amx = any(isa.startswith("AMX") for isa in isa_list)

    # Determine status and build display string
    if has_amx:
        isa_status = "ok"
        isa_hint = "AMX available - best performance for INT4/INT8"
    elif has_avx512:
        isa_status = "ok"
        isa_hint = "AVX512 available - good performance"
    elif has_avx2:
        isa_status = "warning"
        isa_hint = "AVX2 only - consider upgrading CPU for better performance"
    else:
        isa_status = "error"
        isa_hint = "AVX2 required for kt-kernel"

    # Show top instruction sets (prioritize important ones)
    display_isa = isa_list[:8] if len(isa_list) > 8 else isa_list
    isa_display = ", ".join(display_isa)
    if len(isa_list) > 8:
        isa_display += f" (+{len(isa_list) - 8} more)"

    checks.append(
        {
            "name": t("doctor_check_cpu_isa"),
            "status": isa_status,
            "value": isa_display if isa_display else "None detected",
            "hint": isa_hint,
        }
    )

    # 6. NUMA topology
    numa_detail = []
    for node, cpus in sorted(cpu_info.numa_info.items()):
        if len(cpus) > 6:
            cpu_str = f"{cpus[0]}-{cpus[-1]}"
        else:
            cpu_str = ",".join(str(c) for c in cpus)
        numa_detail.append(f"{node}: {cpu_str}")

    numa_value = t("doctor_numa_info", nodes=cpu_info.numa_nodes)
    if verbose and numa_detail:
        numa_value += " (" + "; ".join(numa_detail) + ")"

    checks.append(
        {
            "name": t("doctor_check_numa"),
            "status": "ok",
            "value": numa_value,
            "hint": f"{cpu_info.threads // cpu_info.numa_nodes} threads per node" if cpu_info.numa_nodes > 1 else None,
        }
    )

    # 7. System memory (with frequency if available)
    mem_info = detect_memory_info()
    if mem_info.frequency_mhz and mem_info.type:
        mem_value = t(
            "doctor_memory_freq",
            available=f"{mem_info.available_gb}GB",
            total=f"{mem_info.total_gb}GB",
            freq=mem_info.frequency_mhz,
            type=mem_info.type,
        )
    else:
        mem_value = t("doctor_memory_info", available=f"{mem_info.available_gb}GB", total=f"{mem_info.total_gb}GB")

    ram_ok = mem_info.total_gb >= 32
    checks.append(
        {
            "name": t("doctor_check_memory"),
            "status": "ok" if ram_ok else "warning",
            "value": mem_value,
            "hint": "32GB+ RAM recommended for large models" if not ram_ok else None,
        }
    )

    # 8. Disk space - check all configured model paths
    settings = get_settings()
    model_paths = settings.get_model_paths()

    for i, disk_path in enumerate(model_paths):
        available_disk, total_disk = detect_disk_space_gb(str(disk_path))
        disk_ok = available_disk >= 100

        # For multiple paths, add an index to the name
        path_label = f"Model Path {i+1}" if len(model_paths) > 1 else t("doctor_check_disk")

        checks.append(
            {
                "name": path_label,
                "status": "ok" if disk_ok else "warning",
                "value": t("doctor_disk_info", available=f"{available_disk}GB", path=str(disk_path)),
                "hint": "100GB+ free space recommended for model storage" if not disk_ok else None,
            }
        )

    # 9. Required packages
    packages = [
        ("kt-kernel", ">=0.4.0", False),  # name, version_req, required
        ("ktransformers", ">=0.4.0", False),
        ("sglang", ">=0.4.0", False),
        ("torch", ">=2.4.0", True),
        ("transformers", ">=4.45.0", True),
    ]

    package_issues = []
    for pkg_name, version_req, required in packages:
        version = get_installed_package_version(pkg_name)
        if version:
            package_issues.append((pkg_name, version, "ok"))
        elif required:
            package_issues.append((pkg_name, t("version_not_installed"), "error"))
            issues_found = True
        else:
            package_issues.append((pkg_name, t("version_not_installed"), "warning"))

    if verbose:
        checks.append(
            {
                "name": t("doctor_check_packages"),
                "status": "ok" if not any(p[2] == "error" for p in package_issues) else "error",
                "value": f"{sum(1 for p in package_issues if p[2] == 'ok')}/{len(package_issues)} installed",
                "packages": package_issues,
            }
        )

    # 10. SGLang installation source check
    from kt_kernel.cli.utils.sglang_checker import check_sglang_installation, check_sglang_kt_kernel_support

    sglang_info = check_sglang_installation()

    if sglang_info["installed"]:
        if sglang_info["from_source"]:
            if sglang_info["git_info"]:
                git_remote = sglang_info["git_info"].get("remote", "unknown")
                git_branch = sglang_info["git_info"].get("branch", "unknown")
                sglang_source_value = f"Source (GitHub: {git_remote}, branch: {git_branch})"
                sglang_source_status = "ok"
                sglang_source_hint = None
            else:
                sglang_source_value = "Source (editable)"
                sglang_source_status = "ok"
                sglang_source_hint = None
        else:
            sglang_source_value = "PyPI (not recommended)"
            sglang_source_status = "warning"
            sglang_source_hint = t("sglang_pypi_hint")
    else:
        sglang_source_value = "Not installed"
        sglang_source_status = "warning"
        sglang_source_hint = t("sglang_install_hint")

    checks.append(
        {
            "name": "SGLang Source",
            "status": sglang_source_status,
            "value": sglang_source_value,
            "hint": sglang_source_hint,
        }
    )

    # 11. SGLang kt-kernel support check (only if SGLang is installed)
    kt_kernel_support = {"supported": True}  # Default to True if not checked
    if sglang_info["installed"]:
        # use_cache=False forces a re-check in doctor; silent=True because the result is shown in the table
        kt_kernel_support = check_sglang_kt_kernel_support(use_cache=False, silent=True)

        if kt_kernel_support["supported"]:
            kt_kernel_value = t("sglang_kt_kernel_supported")
            kt_kernel_status = "ok"
            kt_kernel_hint = None
        else:
            kt_kernel_value = t("sglang_kt_kernel_not_supported")
            kt_kernel_status = "error"
            kt_kernel_hint = 'Reinstall SGLang from: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"'
            issues_found = True

        checks.append(
            {
                "name": "SGLang kt-kernel",
                "status": kt_kernel_status,
                "value": kt_kernel_value,
                "hint": kt_kernel_hint,
            }
        )

    # 12. Environment managers
    env_managers = detect_env_managers()
    docker = check_docker()
    env_list = [f"{m.name} {m.version}" for m in env_managers]
    if docker:
        env_list.append(f"docker {docker.version}")

    checks.append(
        {
            "name": "Environment Managers",
            "status": "ok" if env_list else "warning",
            "value": ", ".join(env_list) if env_list else "None found",
            "hint": "conda or docker recommended for installation" if not env_list else None,
        }
    )

    # Display results
    _display_results(checks, verbose)

    # Show SGLang installation instructions if not installed
    if not sglang_info["installed"]:
        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions

        console.print()
        print_sglang_install_instructions()
    # Show kt-kernel installation instructions if SGLang is installed but doesn't support kt-kernel
    elif sglang_info["installed"] and not kt_kernel_support.get("supported", True):
        from kt_kernel.cli.utils.sglang_checker import print_sglang_kt_kernel_instructions

        console.print()
        print_sglang_kt_kernel_instructions()

    # Summary
    console.print()
    if issues_found:
        print_warning(t("doctor_has_issues"))
    else:
        print_success(t("doctor_all_ok"))
    console.print()


def _check_python_version(version: str) -> bool:
    """Check if the Python version meets requirements (3.10+)."""
    parts = version.split(".")
    try:
        major, minor = int(parts[0]), int(parts[1])
        # Tuple comparison handles a future major bump correctly,
        # unlike checking major and minor independently.
        return (major, minor) >= (3, 10)
    except (IndexError, ValueError):
        return False


def _display_results(checks: list[dict], verbose: bool) -> None:
    """Display diagnostic results."""
    table = Table(show_header=True, header_style="bold")
    table.add_column("Check", style="bold")
    table.add_column("Status", width=8)
    table.add_column("Value")
    if verbose:
        table.add_column("Notes", style="dim")

    for check in checks:
        status = check["status"]
        if status == "ok":
            status_str = f"[green]{t('doctor_status_ok')}[/green]"
        elif status == "warning":
            status_str = f"[yellow]{t('doctor_status_warning')}[/yellow]"
        else:
            status_str = f"[red]{t('doctor_status_error')}[/red]"

        if verbose:
            table.add_row(
                check["name"],
                status_str,
                check["value"],
                check.get("hint", ""),
            )
        else:
            table.add_row(
                check["name"],
                status_str,
                check["value"],
            )

        # Show package details if verbose
        if verbose and "packages" in check:
            for pkg_name, pkg_version, pkg_status in check["packages"]:
                if pkg_status == "ok":
                    pkg_status_str = "[green]✓[/green]"
                elif pkg_status == "warning":
                    pkg_status_str = "[yellow]○[/yellow]"
                else:
                    pkg_status_str = "[red]✗[/red]"

                table.add_row(
                    f"  └─ {pkg_name}",
                    pkg_status_str,
                    pkg_version,
                    "",
                )

    console.print(table)
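The minimum-version gate used by the Python check boils down to a tuple comparison; a standalone sketch under an assumed helper name (not part of the CLI):

```python
def meets_min_python(version: str, minimum: tuple[int, int] = (3, 10)) -> bool:
    """Return True if a 'major.minor[.patch]' string is >= minimum.

    Comparing (major, minor) as a tuple handles a hypothetical 4.0
    correctly, unlike testing major >= 3 and minor >= 10 separately.
    """
    try:
        major, minor = (int(p) for p in version.split(".")[:2])
    except ValueError:
        # Covers both non-numeric parts and a missing minor component.
        return False
    return (major, minor) >= minimum


print(meets_min_python("3.11.4"))  # True
print(meets_min_python("3.9.18"))  # False
```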
kt-kernel/python/cli/commands/model.py (new file, 409 lines)
@@ -0,0 +1,409 @@
"""
Model command for kt-cli.

Manages models: download, list, and storage paths.
"""

import os
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    confirm,
    console,
    print_error,
    print_info,
    print_success,
    print_warning,
    prompt_choice,
)

app = typer.Typer(
    help="Manage models and storage paths",
    invoke_without_command=True,
    no_args_is_help=False,
)


@app.callback()
def callback(ctx: typer.Context) -> None:
    """
    Model management commands.

    Run without arguments to see available models.
    """
    # If no subcommand is provided, show the model list
    if ctx.invoked_subcommand is None:
        show_model_list()


def show_model_list() -> None:
    """Display available models with their status and paths."""
    from rich.table import Table

    from kt_kernel.cli.utils.model_registry import get_registry
    from kt_kernel.cli.i18n import get_lang

    registry = get_registry()
    settings = get_settings()

    console.print()
    console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n")

    # Get local models mapping
    local_models = {m.name: p for m, p in registry.find_local_models()}

    # Create table
    table = Table(show_header=True, header_style="bold")
    table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
    table.add_column(t("model_column_status"), justify="center")

    all_models = registry.list_all()
    for model in all_models:
        if model.name in local_models:
            status = f"[green]✓ {t('model_status_local')}[/green]"
        else:
            status = "[dim]-[/dim]"

        table.add_row(model.name, status)

    console.print(table)
    console.print()

    # Usage instructions
    console.print(f"[bold]{t('model_usage_title')}:[/bold]")
    console.print(f"  • {t('model_usage_download')} [cyan]kt model download <model-name>[/cyan]")
    console.print(f"  • {t('model_usage_list_local')} [cyan]kt model list --local[/cyan]")
    console.print(f"  • {t('model_usage_search')} [cyan]kt model search <query>[/cyan]")
    console.print()

    # Show model storage paths
    model_paths = settings.get_model_paths()
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]")
    for path in model_paths:
        marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]"
        console.print(f"  {marker} {path}")
    console.print()


@app.command(name="download")
def download(
    model: Optional[str] = typer.Argument(
        None,
        help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)",
    ),
    path: Optional[Path] = typer.Option(
        None,
        "--path",
        "-p",
        help="Custom download path",
    ),
    list_models: bool = typer.Option(
        False,
        "--list",
        "-l",
        help="List available models",
    ),
    resume: bool = typer.Option(
        True,
        "--resume/--no-resume",
        help="Resume incomplete downloads",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompts",
    ),
) -> None:
    """Download model weights from HuggingFace."""
    import subprocess

    from kt_kernel.cli.i18n import get_lang
    from kt_kernel.cli.utils.console import print_model_table, print_step
    from kt_kernel.cli.utils.model_registry import get_registry

    settings = get_settings()
    registry = get_registry()

    console.print()

    # List mode
    if list_models or model is None:
        print_step(t("download_list_title"))
        console.print()

        models = registry.list_all()
        model_dicts = []
        for m in models:
            lang = get_lang()
            desc = m.description_zh if lang == "zh" and m.description_zh else m.description
            model_dicts.append(
                {
                    "name": m.name,
                    "hf_repo": m.hf_repo,
                    "type": m.type,
                    "gpu_vram_gb": m.gpu_vram_gb,
                    "cpu_ram_gb": m.cpu_ram_gb,
                }
            )

        print_model_table(model_dicts)
        console.print()

        if model is None:
            console.print(f"[dim]{t('model_download_usage_hint')}[/dim]")
            console.print()
            return

    # Search for model
    print_step(t("download_searching", name=model))

    # Check if it's a direct HuggingFace repo path
    if "/" in model:
        hf_repo = model
        model_info = None
        model_name = model.split("/")[-1]
    else:
        matches = registry.search(model)

        if not matches:
            print_error(t("run_model_not_found", name=model))
            console.print()
            console.print(t("model_download_list_hint"))
            console.print(t("model_download_hf_hint"))
            raise typer.Exit(1)

        if len(matches) == 1:
            model_info = matches[0]
        else:
            console.print()
            print_info(t("download_multiple_found"))
            choices = [f"{m.name} ({m.hf_repo})" for m in matches]
            selected = prompt_choice(t("download_select"), choices)
            idx = choices.index(selected)
            model_info = matches[idx]

        hf_repo = model_info.hf_repo
        model_name = model_info.name

    print_success(t("download_found", name=hf_repo))

    # Determine download path
    if path is None:
        download_path = settings.models_dir / model_name.replace(" ", "-")
    else:
        download_path = path

    console.print()
    print_info(t("download_destination", path=str(download_path)))

    # Check if it already exists
    if download_path.exists() and (download_path / "config.json").exists():
        print_warning(t("download_already_exists", path=str(download_path)))
        if not yes:
            if not confirm(t("download_overwrite_prompt"), default=False):
                raise typer.Abort()

    # Confirm download
    if not yes:
        console.print()
        if not confirm(t("prompt_continue")):
            raise typer.Abort()

    # Download using huggingface-cli
    console.print()
    print_step(t("download_starting"))

    cmd = [
        "huggingface-cli",
        "download",
        hf_repo,
        "--local-dir",
        str(download_path),
    ]

    if resume:
        cmd.append("--resume-download")

    # Add mirror if configured
    mirror = settings.get("download.mirror", "")
    if mirror:
        cmd.extend(["--endpoint", mirror])

    try:
        process = subprocess.run(cmd, check=True)

        console.print()
        print_success(t("download_complete"))
        console.print()
        console.print(f"  {t('model_saved_to', path=download_path)}")
        console.print()
        console.print(f"  {t('model_start_with', name=model_name)}")
        console.print()

    except subprocess.CalledProcessError as e:
        print_error(t("model_download_failed", error=str(e)))
        raise typer.Exit(1)
    except FileNotFoundError:
        print_error(t("model_hf_cli_not_found"))
        raise typer.Exit(1)


@app.command(name="list")
def list_models(
    local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"),
) -> None:
    """List available models."""
    from rich.table import Table

    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    console.print()

    if local_only:
        # Show only local models
        local_models = registry.find_local_models()

        if not local_models:
            print_warning(t("model_no_local_models"))
            console.print()
            console.print(f"  {t('model_download_hint')} [cyan]kt model download <model-name>[/cyan]")
            console.print()
            return

        table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for model_info, model_path in local_models:
            if verbose:
                table.add_row(model_info.name, str(model_path))
            else:
                table.add_row(model_info.name)

        console.print(table)
    else:
        # Show all registered models
        all_models = registry.list_all()
        local_models_dict = {m.name: p for m, p in registry.find_local_models()}

        table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold")
        table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
        table.add_column(t("model_column_status"), justify="center")
        if verbose:
            table.add_column(t("model_column_local_path"), style="dim")

        for model in all_models:
            if model.name in local_models_dict:
                status = f"[green]✓ {t('model_status_local')}[/green]"
                local_path = str(local_models_dict[model.name])
            else:
                status = "[dim]-[/dim]"
                local_path = f"[dim]{t('model_status_not_downloaded')}[/dim]"

            if verbose:
                table.add_row(model.name, status, local_path)
            else:
                table.add_row(model.name, status)

        console.print(table)

    console.print()


@app.command(name="path-list")
def path_list() -> None:
    """List all configured model storage paths."""
    settings = get_settings()
    model_paths = settings.get_model_paths()

    console.print()
    console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]\n")

    for i, path in enumerate(model_paths, 1):
        marker = "[green]✓[/green]" if path.exists() else "[red]✗[/red]"
        console.print(f"  {marker} [{i}] {path}")

    console.print()


@app.command(name="path-add")
def path_add(
    path: str = typer.Argument(..., help="Path to add"),
) -> None:
    """Add a new model storage path."""
    # Expand user home directory
    path = os.path.expanduser(path)

    # Check if the path exists or can be created
    path_obj = Path(path)
    if not path_obj.exists():
        console.print(f"[yellow]{t('model_path_not_exist', path=path)}[/yellow]")
        if confirm(t("model_create_directory", path=path), default=True):
            try:
                path_obj.mkdir(parents=True, exist_ok=True)
                console.print(f"[green]✓[/green] {t('model_created_directory', path=path)}")
            except (OSError, PermissionError) as e:
                print_error(t("model_create_dir_failed", error=str(e)))
                raise typer.Exit(1)
        else:
            raise typer.Abort()

    # Add to configuration
    settings = get_settings()
    settings.add_model_path(path)
    print_success(t("model_path_added", path=path))


@app.command(name="path-remove")
def path_remove(
    path: str = typer.Argument(..., help="Path to remove"),
) -> None:
    """Remove a model storage path from configuration."""
    # Expand user home directory
    path = os.path.expanduser(path)

    settings = get_settings()
    if settings.remove_model_path(path):
        print_success(t("model_path_removed", path=path))
    else:
        print_error(t("model_path_not_found", path=path))
        raise typer.Exit(1)


@app.command(name="search")
def search(
    query: str = typer.Argument(..., help="Search query (model name or keyword)"),
) -> None:
    """Search for models in the registry."""
    from rich.table import Table

    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    matches = registry.search(query)

    console.print()

    if not matches:
        print_warning(t("model_search_no_results", query=query))
        console.print()
        return

    table = Table(title=t("model_search_results_title", query=query), show_header=True)
    table.add_column(t("model_column_name"), style="cyan")
    table.add_column(t("model_column_hf_repo"), style="dim")
    table.add_column(t("model_column_aliases"), style="yellow")

    for model in matches:
        aliases = ", ".join(model.aliases[:3])
        if len(model.aliases) > 3:
            aliases += f" +{len(model.aliases) - 3} more"
        table.add_row(model.name, model.hf_repo, aliases)

    console.print(table)
    console.print()
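The repo-vs-name resolution and destination naming used by `download` can be sketched in isolation (hypothetical helper, not part of the CLI; the registry lookup for bare names is elided):

```python
from pathlib import Path


def resolve_download_target(model: str, models_dir: Path) -> tuple[str, Path]:
    """Sketch of the resolution in `kt model download`.

    A '/' in the argument is treated as a direct HuggingFace repo id;
    otherwise the name would normally go through the model registry
    (elided here, so a bare name passes through unchanged).
    """
    if "/" in model:
        hf_repo = model
        model_name = model.split("/")[-1]
    else:
        hf_repo = model  # registry lookup elided in this sketch
        model_name = model
    # Spaces in the display name are mapped to '-' in the directory name.
    return hf_repo, models_dir / model_name.replace(" ", "-")


repo, dest = resolve_download_target("Qwen/Qwen3-30B", Path("/models"))
print(repo, dest)
```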
kt-kernel/python/cli/commands/quant.py (new file, 239 lines)
@@ -0,0 +1,239 @@
"""
Quant command for kt-cli.

Quantizes model weights for CPU inference.
"""

import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import typer

from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
    confirm,
    console,
    create_progress,
    print_error,
    print_info,
    print_step,
    print_success,
    print_warning,
)
from kt_kernel.cli.utils.environment import detect_cpu_info


class QuantMethod(str, Enum):
    """Quantization method."""

    INT4 = "int4"
    INT8 = "int8"


def quant(
    model: str = typer.Argument(
        ...,
        help="Model name or path to quantize",
    ),
    method: QuantMethod = typer.Option(
        QuantMethod.INT4,
        "--method",
        "-m",
        help="Quantization method",
    ),
    output: Optional[Path] = typer.Option(
        None,
        "--output",
        "-o",
        help="Output path for quantized weights",
    ),
    input_type: str = typer.Option(
        "fp8",
        "--input-type",
        "-i",
        help="Input weight type (fp8, fp16, bf16)",
    ),
    cpu_threads: Optional[int] = typer.Option(
        None,
        "--cpu-threads",
        help="Number of CPU threads for quantization",
    ),
    numa_nodes: Optional[int] = typer.Option(
        None,
        "--numa-nodes",
        help="Number of NUMA nodes",
    ),
    no_merge: bool = typer.Option(
        False,
        "--no-merge",
        help="Don't merge safetensor files",
    ),
    yes: bool = typer.Option(
        False,
        "--yes",
        "-y",
        help="Skip confirmation prompts",
    ),
) -> None:
    """Quantize model weights for CPU inference."""
    settings = get_settings()
    console.print()

    # Resolve input path
    input_path = _resolve_input_path(model, settings)
    if input_path is None:
        print_error(t("quant_input_not_found", path=model))
        raise typer.Exit(1)

    print_info(t("quant_input_path", path=str(input_path)))

    # Resolve output path
    if output is None:
        output = input_path.parent / f"{input_path.name}-{method.value.upper()}"

    print_info(t("quant_output_path", path=str(output)))
    print_info(t("quant_method", method=method.value.upper()))

    # Detect CPU configuration
    cpu = detect_cpu_info()
    final_cpu_threads = cpu_threads or cpu.cores
    final_numa_nodes = numa_nodes or cpu.numa_nodes

    print_info(f"CPU threads: {final_cpu_threads}")
    print_info(f"NUMA nodes: {final_numa_nodes}")

    # Check if output exists
    if output.exists():
        print_warning(f"Output path already exists: {output}")
        if not yes:
            if not confirm("Overwrite?", default=False):
                raise typer.Abort()

    # Confirm
    if not yes:
        console.print()
        console.print("[bold]Quantization Settings:[/bold]")
        console.print(f"  Input: {input_path}")
        console.print(f"  Output: {output}")
        console.print(f"  Method: {method.value.upper()}")
        console.print(f"  Input type: {input_type}")
        console.print()
        print_warning("Quantization may take 30-60 minutes depending on model size.")
        console.print()

        if not confirm(t("prompt_continue")):
            raise typer.Abort()

    # Find conversion script
    kt_kernel_path = _find_kt_kernel_path()
    if kt_kernel_path is None:
        print_error("kt-kernel not found. Install with: kt install inference")
        raise typer.Exit(1)

    script_path = kt_kernel_path / "scripts" / "convert_cpu_weights.py"
    if not script_path.exists():
        print_error(f"Conversion script not found: {script_path}")
        raise typer.Exit(1)

    # Build command
    cmd = [
        sys.executable, str(script_path),
        "--input-path", str(input_path),
        "--input-type", input_type,
        "--output", str(output),
        "--quant-method", method.value,
        "--cpuinfer-threads", str(final_cpu_threads),
        "--threadpool-count", str(final_numa_nodes),
    ]

    if no_merge:
        cmd.append("--no-merge-safetensor")

    # Run quantization
    console.print()
    print_step(t("quant_starting"))
    console.print()
    console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
    console.print()

    try:
        process = subprocess.run(cmd)

        if process.returncode == 0:
            console.print()
            print_success(t("quant_complete"))
            console.print()
            console.print(f"  Quantized weights saved to: {output}")
            console.print()
            console.print("  Use with:")
            console.print(f"    kt run {model} --weights-path {output}")
            console.print()
        else:
            print_error(f"Quantization failed with exit code {process.returncode}")
            raise typer.Exit(process.returncode)

    except FileNotFoundError as e:
        print_error(f"Failed to run quantization: {e}")
        raise typer.Exit(1)
    except KeyboardInterrupt:
        console.print()
        print_warning("Quantization interrupted.")
        raise typer.Exit(130)


def _resolve_input_path(model: str, settings) -> Optional[Path]:
    """Resolve the input model path."""
    # Check if it's already a path
    path = Path(model)
    if path.exists() and (path / "config.json").exists():
        return path

    # Search in models directory
    from kt_kernel.cli.utils.model_registry import get_registry

    registry = get_registry()
    matches = registry.search(model)

    if matches:
        model_info = matches[0]
|
||||
# Try to find in all configured model directories
|
||||
model_paths = settings.get_model_paths()
|
||||
|
||||
for models_dir in model_paths:
|
||||
possible_paths = [
|
||||
models_dir / model_info.name,
|
||||
models_dir / model_info.name.lower(),
|
||||
models_dir / model_info.hf_repo.split("/")[-1],
|
||||
]
|
||||
|
||||
for p in possible_paths:
|
||||
if p.exists() and (p / "config.json").exists():
|
||||
return p
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_kt_kernel_path() -> Optional[Path]:
|
||||
"""Find the kt-kernel installation path."""
|
||||
try:
|
||||
import kt_kernel
|
||||
return Path(kt_kernel.__file__).parent.parent
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Check common locations
|
||||
possible_paths = [
|
||||
Path.home() / "Projects" / "ktransformers" / "kt-kernel",
|
||||
Path.cwd().parent / "kt-kernel",
|
||||
Path.cwd() / "kt-kernel",
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if path.exists() and (path / "scripts").exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
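The path-resolution helpers above reduce to one pattern: scan a list of candidate directories and return the first one that contains a `config.json`. A minimal standalone sketch of that pattern (the function name and temp-dir layout here are illustrative, not part of kt-kernel):

```python
import tempfile
from pathlib import Path
from typing import Optional


def resolve_model_dir(candidates: list[Path]) -> Optional[Path]:
    # Return the first candidate directory holding a config.json,
    # mirroring the search loops in _resolve_input_path / _find_model_path.
    for p in candidates:
        if p.exists() and (p / "config.json").exists():
            return p
    return None


# Usage: two candidate dirs, only one of which has a config.json.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "empty-model").mkdir()
    good = root / "real-model"
    good.mkdir()
    (good / "config.json").write_text("{}")

    found = resolve_model_dir([root / "empty-model", root / "missing", good])
    assert found == good
```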
831  kt-kernel/python/cli/commands/run.py  Normal file
@@ -0,0 +1,831 @@
|
||||
"""
|
||||
Run command for kt-cli.
|
||||
|
||||
Starts the model inference server using SGLang + kt-kernel.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
from kt_kernel.cli.i18n import t
|
||||
from kt_kernel.cli.utils.console import (
|
||||
confirm,
|
||||
console,
|
||||
print_api_info,
|
||||
print_error,
|
||||
print_info,
|
||||
print_server_info,
|
||||
print_step,
|
||||
print_success,
|
||||
print_warning,
|
||||
prompt_choice,
|
||||
)
|
||||
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
|
||||
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry
|
||||
|
||||
|
||||
def run(
|
||||
model: Optional[str] = typer.Argument(
|
||||
None,
|
||||
help="Model name or path (e.g., deepseek-v3, qwen3-30b). If not specified, shows interactive selection.",
|
||||
),
|
||||
host: str = typer.Option(
|
||||
None,
|
||||
"--host",
|
||||
"-H",
|
||||
help="Server host address",
|
||||
),
|
||||
port: int = typer.Option(
|
||||
None,
|
||||
"--port",
|
||||
"-p",
|
||||
help="Server port",
|
||||
),
|
||||
# CPU/GPU configuration
|
||||
gpu_experts: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--gpu-experts",
|
||||
help="Number of GPU experts per layer",
|
||||
),
|
||||
cpu_threads: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--cpu-threads",
|
||||
help="Number of CPU inference threads (kt-cpuinfer, defaults to 80% of CPU cores)",
|
||||
),
|
||||
numa_nodes: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--numa-nodes",
|
||||
help="Number of NUMA nodes",
|
||||
),
|
||||
tensor_parallel_size: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--tensor-parallel-size",
|
||||
"--tp",
|
||||
help="Tensor parallel size (number of GPUs)",
|
||||
),
|
||||
# Model paths
|
||||
model_path: Optional[Path] = typer.Option(
|
||||
None,
|
||||
"--model-path",
|
||||
help="Custom model path",
|
||||
),
|
||||
weights_path: Optional[Path] = typer.Option(
|
||||
None,
|
||||
"--weights-path",
|
||||
help="Custom quantized weights path",
|
||||
),
|
||||
# KT-kernel options
|
||||
kt_method: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--kt-method",
|
||||
help="KT quantization method (AMXINT4, RAWFP8, etc.)",
|
||||
),
|
||||
kt_gpu_prefill_token_threshold: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--kt-gpu-prefill-threshold",
|
||||
help="GPU prefill token threshold for kt-kernel",
|
||||
),
|
||||
# SGLang options
|
||||
attention_backend: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--attention-backend",
|
||||
help="Attention backend (triton, flashinfer)",
|
||||
),
|
||||
max_total_tokens: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--max-total-tokens",
|
||||
help="Maximum total tokens",
|
||||
),
|
||||
max_running_requests: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--max-running-requests",
|
||||
help="Maximum running requests",
|
||||
),
|
||||
chunked_prefill_size: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--chunked-prefill-size",
|
||||
help="Chunked prefill size",
|
||||
),
|
||||
mem_fraction_static: Optional[float] = typer.Option(
|
||||
None,
|
||||
"--mem-fraction-static",
|
||||
help="Memory fraction for static allocation",
|
||||
),
|
||||
watchdog_timeout: Optional[int] = typer.Option(
|
||||
None,
|
||||
"--watchdog-timeout",
|
||||
help="Watchdog timeout in seconds",
|
||||
),
|
||||
served_model_name: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--served-model-name",
|
||||
help="Custom model name for API responses",
|
||||
),
|
||||
# Performance flags
|
||||
disable_shared_experts_fusion: Optional[bool] = typer.Option(
|
||||
None,
|
||||
"--disable-shared-experts-fusion/--enable-shared-experts-fusion",
|
||||
help="Disable/enable shared experts fusion",
|
||||
),
|
||||
# Other options
|
||||
quantize: bool = typer.Option(
|
||||
False,
|
||||
"--quantize",
|
||||
"-q",
|
||||
help="Quantize model if weights not found",
|
||||
),
|
||||
advanced: bool = typer.Option(
|
||||
False,
|
||||
"--advanced",
|
||||
help="Show advanced options",
|
||||
),
|
||||
dry_run: bool = typer.Option(
|
||||
False,
|
||||
"--dry-run",
|
||||
help="Show command without executing",
|
||||
),
|
||||
) -> None:
|
||||
"""Start model inference server."""
|
||||
# Check if SGLang is installed before proceeding
|
||||
from kt_kernel.cli.utils.sglang_checker import (
|
||||
check_sglang_installation,
|
||||
check_sglang_kt_kernel_support,
|
||||
print_sglang_install_instructions,
|
||||
print_sglang_kt_kernel_instructions,
|
||||
)
|
||||
|
||||
sglang_info = check_sglang_installation()
|
||||
if not sglang_info["installed"]:
|
||||
console.print()
|
||||
print_error(t("sglang_not_found"))
|
||||
console.print()
|
||||
print_sglang_install_instructions()
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Check if SGLang supports kt-kernel (has --kt-gpu-prefill-token-threshold parameter)
|
||||
kt_kernel_support = check_sglang_kt_kernel_support()
|
||||
if not kt_kernel_support["supported"]:
|
||||
console.print()
|
||||
print_error(t("sglang_kt_kernel_not_supported"))
|
||||
console.print()
|
||||
print_sglang_kt_kernel_instructions()
|
||||
raise typer.Exit(1)
|
||||
|
||||
settings = get_settings()
|
||||
registry = get_registry()
|
||||
|
||||
console.print()
|
||||
|
||||
# If no model specified, show interactive selection
|
||||
if model is None:
|
||||
model = _interactive_model_selection(registry, settings)
|
||||
if model is None:
|
||||
raise typer.Exit(0)
|
||||
|
||||
# Step 1: Detect hardware
|
||||
print_step(t("run_detecting_hardware"))
|
||||
gpus = detect_gpus()
|
||||
cpu = detect_cpu_info()
|
||||
ram = detect_ram_gb()
|
||||
|
||||
if gpus:
|
||||
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
|
||||
if len(gpus) > 1:
|
||||
gpu_info += f" + {len(gpus) - 1} more"
|
||||
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
|
||||
else:
|
||||
print_warning(t("doctor_gpu_not_found"))
|
||||
gpu_info = "None"
|
||||
|
||||
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
|
||||
print_info(t("run_ram_info", total=int(ram)))
|
||||
|
||||
# Step 2: Resolve model
|
||||
console.print()
|
||||
print_step(t("run_checking_model"))
|
||||
|
||||
model_info = None
|
||||
resolved_model_path = model_path
|
||||
|
||||
# Check if model is a path
|
||||
if Path(model).exists():
|
||||
resolved_model_path = Path(model)
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
|
||||
# Try to infer model type from path to use default configurations
|
||||
# Check directory name against known models
|
||||
dir_name = resolved_model_path.name.lower()
|
||||
for registered_model in registry.list_all():
|
||||
# Check if directory name matches model name or aliases
|
||||
if dir_name == registered_model.name.lower():
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
for alias in registered_model.aliases:
|
||||
if dir_name == alias.lower() or alias.lower() in dir_name:
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
if model_info:
|
||||
break
|
||||
|
||||
# Also check HuggingFace repo format (org--model)
|
||||
if not model_info:
|
||||
for registered_model in registry.list_all():
|
||||
repo_slug = registered_model.hf_repo.replace("/", "--").lower()
|
||||
if repo_slug in dir_name or dir_name in repo_slug:
|
||||
model_info = registered_model
|
||||
print_info(f"Detected model type: {registered_model.name}")
|
||||
break
|
||||
|
||||
if not model_info:
|
||||
print_warning("Could not detect model type from path. Using default parameters.")
|
||||
console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]")
|
||||
else:
|
||||
# Search in registry
|
||||
matches = registry.search(model)
|
||||
|
||||
if not matches:
|
||||
print_error(t("run_model_not_found", name=model))
|
||||
console.print()
|
||||
console.print("Available models:")
|
||||
for m in registry.list_all()[:5]:
|
||||
console.print(f" - {m.name} ({', '.join(m.aliases[:2])})")
|
||||
raise typer.Exit(1)
|
||||
|
||||
if len(matches) == 1:
|
||||
model_info = matches[0]
|
||||
else:
|
||||
# Multiple matches - prompt user
|
||||
console.print()
|
||||
print_info(t("run_multiple_matches"))
|
||||
choices = [f"{m.name} ({m.hf_repo})" for m in matches]
|
||||
selected = prompt_choice(t("run_select_model"), choices)
|
||||
idx = choices.index(selected)
|
||||
model_info = matches[idx]
|
||||
|
||||
# Find model path
|
||||
if model_path is None:
|
||||
resolved_model_path = _find_model_path(model_info, settings)
|
||||
if resolved_model_path is None:
|
||||
print_error(t("run_model_not_found", name=model_info.name))
|
||||
console.print()
|
||||
console.print(
|
||||
f" Download with: kt download {model_info.aliases[0] if model_info.aliases else model_info.name}"
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
print_info(t("run_model_path", path=str(resolved_model_path)))
|
||||
|
||||
# Step 3: Check quantized weights (only if explicitly requested)
|
||||
resolved_weights_path = None
|
||||
|
||||
# Only use quantized weights if explicitly specified by user
|
||||
if weights_path is not None:
|
||||
# User explicitly specified weights path
|
||||
resolved_weights_path = weights_path
|
||||
if not resolved_weights_path.exists():
|
||||
print_error(t("run_weights_not_found"))
|
||||
console.print(f" Path: {resolved_weights_path}")
|
||||
raise typer.Exit(1)
|
||||
print_info(f"Using quantized weights: {resolved_weights_path}")
|
||||
elif quantize:
|
||||
# User requested quantization
|
||||
console.print()
|
||||
print_step(t("run_quantizing"))
|
||||
# TODO: Implement quantization
|
||||
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
# Default: use original precision model without quantization
|
||||
console.print()
|
||||
print_info("Using original precision model (no quantization)")
|
||||
|
||||
# Step 4: Build command
|
||||
# Resolve all parameters (CLI > model defaults > config > auto-detect)
|
||||
final_host = host or settings.get("server.host", "0.0.0.0")
|
||||
final_port = port or settings.get("server.port", 30000)
|
||||
|
||||
# Get defaults from model info if available
|
||||
model_defaults = model_info.default_params if model_info else {}
|
||||
|
||||
# Determine tensor parallel size first (needed for GPU expert calculation)
|
||||
# Priority: CLI > model defaults > config > auto-detect (with model constraints)
|
||||
|
||||
# Check if explicitly specified by user or configuration
|
||||
explicitly_specified = (
|
||||
tensor_parallel_size # CLI argument (highest priority)
|
||||
or model_defaults.get("tensor-parallel-size") # Model defaults
|
||||
or settings.get("inference.tensor_parallel_size") # Config file
|
||||
)
|
||||
|
||||
if explicitly_specified:
|
||||
# Use explicitly specified value
|
||||
requested_tensor_parallel_size = explicitly_specified
|
||||
else:
|
||||
# Auto-detect from GPUs, considering model's max constraint
|
||||
detected_gpu_count = len(gpus) if gpus else 1
|
||||
if model_info and model_info.max_tensor_parallel_size is not None:
|
||||
# Automatically limit to model's maximum to use as many GPUs as possible
|
||||
requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
|
||||
else:
|
||||
requested_tensor_parallel_size = detected_gpu_count
|
||||
|
||||
# Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it
|
||||
final_tensor_parallel_size = requested_tensor_parallel_size
|
||||
if model_info and model_info.max_tensor_parallel_size is not None:
|
||||
if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
|
||||
console.print()
|
||||
print_warning(
|
||||
f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
|
||||
f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
|
||||
f"Reducing to {model_info.max_tensor_parallel_size}."
|
||||
)
|
||||
final_tensor_parallel_size = model_info.max_tensor_parallel_size
|
||||
|
||||
# CPU/GPU configuration with smart defaults
|
||||
# kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes)
|
||||
total_threads = cpu.cores * cpu.numa_nodes
|
||||
final_cpu_threads = (
|
||||
cpu_threads
|
||||
or model_defaults.get("kt-cpuinfer")
|
||||
or settings.get("inference.cpu_threads")
|
||||
or int(total_threads * 0.8)
|
||||
)
|
||||
|
||||
# kt-threadpool-count: default to NUMA node count
|
||||
final_numa_nodes = (
|
||||
numa_nodes
|
||||
or model_defaults.get("kt-threadpool-count")
|
||||
or settings.get("inference.numa_nodes")
|
||||
or cpu.numa_nodes
|
||||
)
|
||||
|
||||
# kt-num-gpu-experts: use model-specific computation if available and not explicitly set
|
||||
if gpu_experts is not None:
|
||||
# User explicitly set it
|
||||
final_gpu_experts = gpu_experts
|
||||
elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
|
||||
# Use model-specific computation function (only if GPUs detected)
|
||||
vram_per_gpu = gpus[0].vram_gb
|
||||
compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
|
||||
final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
|
||||
console.print()
|
||||
print_info(
|
||||
f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
|
||||
)
|
||||
else:
|
||||
# Fall back to defaults
|
||||
final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1)
|
||||
|
||||
# KT-kernel options
|
||||
final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
|
||||
final_kt_gpu_prefill_threshold = (
|
||||
kt_gpu_prefill_token_threshold
|
||||
or model_defaults.get("kt-gpu-prefill-token-threshold")
|
||||
or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
|
||||
)
|
||||
|
||||
# SGLang options
|
||||
final_attention_backend = (
|
||||
attention_backend
|
||||
or model_defaults.get("attention-backend")
|
||||
or settings.get("inference.attention_backend", "triton")
|
||||
)
|
||||
final_max_total_tokens = (
|
||||
max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000)
|
||||
)
|
||||
final_max_running_requests = (
|
||||
max_running_requests
|
||||
or model_defaults.get("max-running-requests")
|
||||
or settings.get("inference.max_running_requests", 32)
|
||||
)
|
||||
final_chunked_prefill_size = (
|
||||
chunked_prefill_size
|
||||
or model_defaults.get("chunked-prefill-size")
|
||||
or settings.get("inference.chunked_prefill_size", 4096)
|
||||
)
|
||||
final_mem_fraction_static = (
|
||||
mem_fraction_static
|
||||
or model_defaults.get("mem-fraction-static")
|
||||
or settings.get("inference.mem_fraction_static", 0.98)
|
||||
)
|
||||
final_watchdog_timeout = (
|
||||
watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000)
|
||||
)
|
||||
final_served_model_name = (
|
||||
served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "")
|
||||
)
|
||||
|
||||
# Performance flags
|
||||
if disable_shared_experts_fusion is not None:
|
||||
final_disable_shared_experts_fusion = disable_shared_experts_fusion
|
||||
elif "disable-shared-experts-fusion" in model_defaults:
|
||||
final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"]
|
||||
else:
|
||||
final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False)
|
||||
|
||||
# Pass all model default params to handle any extra parameters
|
||||
extra_params = model_defaults if model_info else {}
|
||||
|
||||
cmd = _build_sglang_command(
|
||||
model_path=resolved_model_path,
|
||||
weights_path=resolved_weights_path,
|
||||
model_info=model_info,
|
||||
host=final_host,
|
||||
port=final_port,
|
||||
gpu_experts=final_gpu_experts,
|
||||
cpu_threads=final_cpu_threads,
|
||||
numa_nodes=final_numa_nodes,
|
||||
tensor_parallel_size=final_tensor_parallel_size,
|
||||
kt_method=final_kt_method,
|
||||
kt_gpu_prefill_threshold=final_kt_gpu_prefill_threshold,
|
||||
attention_backend=final_attention_backend,
|
||||
max_total_tokens=final_max_total_tokens,
|
||||
max_running_requests=final_max_running_requests,
|
||||
chunked_prefill_size=final_chunked_prefill_size,
|
||||
mem_fraction_static=final_mem_fraction_static,
|
||||
watchdog_timeout=final_watchdog_timeout,
|
||||
served_model_name=final_served_model_name,
|
||||
disable_shared_experts_fusion=final_disable_shared_experts_fusion,
|
||||
settings=settings,
|
||||
extra_model_params=extra_params,
|
||||
)
|
||||
|
||||
# Prepare environment variables
|
||||
env = os.environ.copy()
|
||||
# Add environment variables from advanced.env
|
||||
env.update(settings.get_env_vars())
|
||||
# Add environment variables from inference.env
|
||||
inference_env = settings.get("inference.env", {})
|
||||
if isinstance(inference_env, dict):
|
||||
env.update({k: str(v) for k, v in inference_env.items()})
|
||||
|
||||
# Step 5: Show configuration summary
|
||||
console.print()
|
||||
print_step("Configuration")
|
||||
|
||||
# Model info
|
||||
if model_info:
|
||||
console.print(f" Model: [bold]{model_info.name}[/bold]")
|
||||
else:
|
||||
console.print(f" Model: [bold]{resolved_model_path.name}[/bold]")
|
||||
|
||||
console.print(f" Path: [dim]{resolved_model_path}[/dim]")
|
||||
|
||||
# Key parameters
|
||||
console.print()
|
||||
console.print(f" GPU Experts: [cyan]{final_gpu_experts}[/cyan] per layer")
|
||||
console.print(f" CPU Threads (kt-cpuinfer): [cyan]{final_cpu_threads}[/cyan]")
|
||||
console.print(f" NUMA Nodes (kt-threadpool-count): [cyan]{final_numa_nodes}[/cyan]")
|
||||
console.print(f" Tensor Parallel: [cyan]{final_tensor_parallel_size}[/cyan]")
|
||||
console.print(f" Method: [cyan]{final_kt_method}[/cyan]")
|
||||
console.print(f" Attention: [cyan]{final_attention_backend}[/cyan]")
|
||||
|
||||
# Weights info
|
||||
if resolved_weights_path:
|
||||
console.print()
|
||||
console.print(f" Quantized weights: [yellow]{resolved_weights_path}[/yellow]")
|
||||
|
||||
console.print()
|
||||
console.print(f" Server: [green]http://{final_host}:{final_port}[/green]")
|
||||
console.print()
|
||||
|
||||
# Step 6: Show or execute
|
||||
if dry_run:
|
||||
console.print()
|
||||
console.print("[bold]Command:[/bold]")
|
||||
console.print()
|
||||
console.print(f" [dim]{' '.join(cmd)}[/dim]")
|
||||
console.print()
|
||||
return
|
||||
|
||||
# Execute with prepared environment variables
|
||||
# Don't print "Server started" or API info here - let sglang's logs speak for themselves
|
||||
# The actual startup takes time and these messages are misleading
|
||||
|
||||
# Print the command being executed
|
||||
console.print()
|
||||
console.print("[bold]Launching server with command:[/bold]")
|
||||
console.print()
|
||||
console.print(f" [dim]{' '.join(cmd)}[/dim]")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
# Execute directly without intercepting output or signals
|
||||
# This allows direct output to terminal and Ctrl+C to work naturally
|
||||
process = subprocess.run(cmd, env=env)
|
||||
sys.exit(process.returncode)
|
||||
|
||||
except FileNotFoundError:
|
||||
from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions
|
||||
|
||||
print_error(t("sglang_not_found"))
|
||||
console.print()
|
||||
print_sglang_install_instructions()
|
||||
raise typer.Exit(1)
|
||||
except Exception as e:
|
||||
print_error(f"Failed to start server: {e}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
def _find_model_path(model_info: ModelInfo, settings) -> Optional[Path]:
|
||||
"""Find the model path on disk by searching all configured model paths."""
|
||||
model_paths = settings.get_model_paths()
|
||||
|
||||
# Search in all configured model directories
|
||||
for models_dir in model_paths:
|
||||
# Check common path patterns
|
||||
possible_paths = [
|
||||
models_dir / model_info.name,
|
||||
models_dir / model_info.name.lower(),
|
||||
models_dir / model_info.name.replace(" ", "-"),
|
||||
models_dir / model_info.hf_repo.split("/")[-1],
|
||||
models_dir / model_info.hf_repo.replace("/", "--"),
|
||||
]
|
||||
|
||||
# Add alias-based paths
|
||||
for alias in model_info.aliases:
|
||||
possible_paths.append(models_dir / alias)
|
||||
possible_paths.append(models_dir / alias.lower())
|
||||
|
||||
for path in possible_paths:
|
||||
if path.exists() and (path / "config.json").exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]:
|
||||
"""Find the quantized weights path on disk by searching all configured paths."""
|
||||
model_paths = settings.get_model_paths()
|
||||
weights_dir = settings.weights_dir
|
||||
|
||||
# Check common patterns
|
||||
base_names = [
|
||||
model_info.name,
|
||||
model_info.name.lower(),
|
||||
model_info.hf_repo.split("/")[-1],
|
||||
]
|
||||
|
||||
suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"]
|
||||
|
||||
# Prepare search directories
|
||||
search_dirs = [weights_dir] if weights_dir else []
|
||||
search_dirs.extend(model_paths)
|
||||
|
||||
for base in base_names:
|
||||
for suffix in suffixes:
|
||||
for dir_path in search_dirs:
|
||||
if dir_path:
|
||||
path = dir_path / f"{base}{suffix}"
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _build_sglang_command(
|
||||
model_path: Path,
|
||||
weights_path: Optional[Path],
|
||||
model_info: Optional[ModelInfo],
|
||||
host: str,
|
||||
port: int,
|
||||
gpu_experts: int,
|
||||
cpu_threads: int,
|
||||
numa_nodes: int,
|
||||
tensor_parallel_size: int,
|
||||
kt_method: str,
|
||||
kt_gpu_prefill_threshold: int,
|
||||
attention_backend: str,
|
||||
max_total_tokens: int,
|
||||
max_running_requests: int,
|
||||
chunked_prefill_size: int,
|
||||
mem_fraction_static: float,
|
||||
watchdog_timeout: int,
|
||||
served_model_name: str,
|
||||
disable_shared_experts_fusion: bool,
|
||||
settings,
|
||||
extra_model_params: Optional[dict] = None, # New parameter for additional params
|
||||
) -> list[str]:
|
||||
"""Build the SGLang launch command."""
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
"--model",
|
||||
str(model_path),
|
||||
]
|
||||
|
||||
# Add kt-kernel options
|
||||
# kt-kernel is needed for:
|
||||
# 1. Quantized models (when weights_path is provided)
|
||||
# 2. MoE models with CPU offloading (when kt-cpuinfer > 0 or kt-num-gpu-experts is configured)
|
||||
use_kt_kernel = False
|
||||
|
||||
# Check if we should use kt-kernel
|
||||
if weights_path:
|
||||
# Quantized model - always use kt-kernel
|
||||
use_kt_kernel = True
|
||||
elif cpu_threads > 0 or gpu_experts > 1:
|
||||
# CPU offloading configured - use kt-kernel
|
||||
use_kt_kernel = True
|
||||
elif model_info and model_info.type == "moe":
|
||||
# MoE model - likely needs kt-kernel for expert offloading
|
||||
use_kt_kernel = True
|
||||
|
||||
if use_kt_kernel:
|
||||
# Add kt-weight-path: use quantized weights if available, otherwise use model path
|
||||
weight_path_to_use = weights_path if weights_path else model_path
|
||||
|
||||
# Add kt-kernel configuration
|
||||
cmd.extend(
|
||||
[
|
||||
"--kt-weight-path",
|
||||
str(weight_path_to_use),
|
||||
"--kt-cpuinfer",
|
||||
str(cpu_threads),
|
||||
"--kt-threadpool-count",
|
||||
str(numa_nodes),
|
||||
"--kt-num-gpu-experts",
|
||||
str(gpu_experts),
|
||||
"--kt-method",
|
||||
kt_method,
|
||||
"--kt-gpu-prefill-token-threshold",
|
||||
str(kt_gpu_prefill_threshold),
|
||||
]
|
||||
)
|
||||
|
||||
# Add SGLang options
|
||||
cmd.extend(
|
||||
[
|
||||
"--attention-backend",
|
||||
attention_backend,
|
||||
"--trust-remote-code",
|
||||
"--mem-fraction-static",
|
||||
str(mem_fraction_static),
|
||||
"--chunked-prefill-size",
|
||||
str(chunked_prefill_size),
|
||||
"--max-running-requests",
|
||||
str(max_running_requests),
|
||||
"--max-total-tokens",
|
||||
str(max_total_tokens),
|
||||
"--watchdog-timeout",
|
||||
str(watchdog_timeout),
|
||||
"--enable-mixed-chunk",
|
||||
"--tensor-parallel-size",
|
||||
str(tensor_parallel_size),
|
||||
"--enable-p2p-check",
|
||||
]
|
||||
)
|
||||
|
||||
# Add served model name if specified
|
||||
if served_model_name:
|
||||
cmd.extend(["--served-model-name", served_model_name])
|
||||
|
||||
# Add performance flags
|
||||
if disable_shared_experts_fusion:
|
||||
cmd.append("--disable-shared-experts-fusion")
|
||||
|
||||
# Add any extra parameters from model defaults that weren't explicitly handled
|
||||
if extra_model_params:
|
||||
# List of parameters already handled above
|
||||
handled_params = {
|
||||
"kt-num-gpu-experts",
|
||||
"kt-cpuinfer",
|
||||
"kt-threadpool-count",
|
||||
"kt-method",
|
||||
"kt-gpu-prefill-token-threshold",
|
||||
"attention-backend",
|
||||
"tensor-parallel-size",
|
||||
"max-total-tokens",
|
||||
"max-running-requests",
|
||||
"chunked-prefill-size",
|
||||
"mem-fraction-static",
|
||||
"watchdog-timeout",
|
||||
"served-model-name",
|
||||
"disable-shared-experts-fusion",
|
||||
}
|
||||
|
||||
for key, value in extra_model_params.items():
|
||||
if key not in handled_params:
|
||||
# Add unhandled parameters dynamically
|
||||
cmd.append(f"--{key}")
|
||||
if isinstance(value, bool):
|
||||
# Boolean flags don't need a value
|
||||
if not value:
|
||||
# For False boolean, skip the flag entirely
|
||||
cmd.pop() # Remove the flag we just added
|
||||
else:
|
||||
cmd.append(str(value))
|
||||
|
||||
# Add extra args from settings
|
||||
extra_args = settings.get("advanced.sglang_args", [])
|
||||
if extra_args:
|
||||
cmd.extend(extra_args)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def _interactive_model_selection(registry, settings) -> Optional[str]:
|
||||
"""Show interactive model selection interface.
|
||||
|
||||
Returns:
|
||||
Selected model name or None if cancelled.
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.prompt import Prompt
|
||||
|
||||
    from kt_kernel.cli.i18n import get_lang

    lang = get_lang()

    # Find local models first
    local_models = registry.find_local_models()

    # Get all registered models
    all_models = registry.list_all()

    console.print()
    console.print(
        Panel.fit(
            t("run_select_model_title"),
            border_style="cyan",
        )
    )
    console.print()

    # Build choices list
    choices = []
    choice_map = {}  # index -> model name

    # Section 1: Local models (downloaded)
    if local_models:
        console.print(f"[bold green]{t('run_local_models')}[/bold green]")
        console.print()

        for i, (model_info, path) in enumerate(local_models, 1):
            desc = model_info.description_zh if lang == "zh" else model_info.description
            short_desc = desc[:50] + "..." if len(desc) > 50 else desc
            console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
            console.print(f"     [dim]{short_desc}[/dim]")
            console.print(f"     [dim]{path}[/dim]")
            choices.append(str(i))
            choice_map[str(i)] = model_info.name

        console.print()

    # Section 2: All registered models (for reference)
    start_idx = len(local_models) + 1
    console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]")
    console.print()

    # Filter out already shown local models
    local_model_names = {m.name for m, _ in local_models}

    for i, model_info in enumerate(all_models, start_idx):
        if model_info.name in local_model_names:
            continue

        desc = model_info.description_zh if lang == "zh" else model_info.description
        short_desc = desc[:50] + "..." if len(desc) > 50 else desc
        console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
        console.print(f"     [dim]{short_desc}[/dim]")
        console.print(f"     [dim]{model_info.hf_repo}[/dim]")
        choices.append(str(i))
        choice_map[str(i)] = model_info.name

    console.print()

    # Add cancel option
    cancel_idx = str(len(choices) + 1)
    console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]")
    choices.append(cancel_idx)
    console.print()

    # Prompt for selection
    try:
        selection = Prompt.ask(
            t("run_select_model_prompt"),
            choices=choices,
            default="1" if choices else cancel_idx,
        )
    except KeyboardInterrupt:
        console.print()
        return None

    if selection == cancel_idx:
        return None

    return choice_map.get(selection)
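The numbered-menu pattern above keeps a parallel `choices` list of valid string indices and a `choice_map` from index back to model name, with a trailing cancel entry. A standalone sketch of that bookkeeping (using a continuous counter rather than `enumerate`, and plain model-name strings instead of the registry's model objects, both of which are simplifications for illustration):

```python
def build_menu(local, registered):
    """Build the two-section numbered menu: local models first, then
    registry entries not already shown. Returns (choices, choice_map,
    cancel_idx) where choices are the valid string indices."""
    choices, choice_map = [], {}
    shown = set()
    idx = 0
    for name in local:  # section 1: downloaded models
        idx += 1
        choices.append(str(idx))
        choice_map[str(idx)] = name
        shown.add(name)
    for name in registered:  # section 2: remaining registered models
        if name in shown:
            continue
        idx += 1
        choices.append(str(idx))
        choice_map[str(idx)] = name
    cancel_idx = str(idx + 1)  # last entry cancels the prompt
    choices.append(cancel_idx)
    return choices, choice_map, cancel_idx


choices, choice_map, cancel = build_menu(["Kimi-K2"], ["Kimi-K2", "MiniMax-M2.1"])
```

Selection handling then reduces to `choice_map.get(selection)`, which returns `None` for the cancel index since it is never added to the map.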
kt-kernel/python/cli/commands/sft.py (new file, 52 lines)
@@ -0,0 +1,52 @@
"""
SFT command for kt-cli.

Fine-tuning with LlamaFactory integration.
"""

import typer

from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console

app = typer.Typer(help="Fine-tuning with LlamaFactory (coming soon)")


@app.callback(invoke_without_command=True)
def callback(ctx: typer.Context) -> None:
    """Fine-tuning commands (coming soon)."""
    if ctx.invoked_subcommand is None:
        console.print()
        console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
        console.print()
        console.print("[dim]kt sft train - Train a model[/dim]")
        console.print("[dim]kt sft chat - Chat with a trained model[/dim]")
        console.print("[dim]kt sft export - Export a trained model[/dim]")
        console.print()


@app.command(name="train")
def train() -> None:
    """Train a model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)


@app.command(name="chat")
def chat() -> None:
    """Chat with a trained model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)


@app.command(name="export")
def export() -> None:
    """Export a trained model using LlamaFactory (coming soon)."""
    console.print()
    console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
    console.print()
    raise typer.Exit(0)
kt-kernel/python/cli/commands/version.py (new file, 118 lines)
@@ -0,0 +1,118 @@
"""
Version command for kt-cli.

Displays version information for kt-cli and related packages.
"""

import platform
from typing import Optional

import typer

from kt_kernel.cli import __version__
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_version_table
from kt_kernel.cli.utils.environment import detect_cuda_version, get_installed_package_version


def _get_sglang_info() -> str:
    """Get sglang version and installation source information."""
    try:
        import sglang

        version = getattr(sglang, "__version__", None)

        if not version:
            version = get_installed_package_version("sglang")

        if not version:
            return t("version_not_installed")

        # Try to detect installation source
        from pathlib import Path
        import subprocess

        if hasattr(sglang, "__file__") and sglang.__file__:
            location = Path(sglang.__file__).parent.parent
            git_dir = location / ".git"

            if git_dir.exists():
                # Installed from git (editable install)
                try:
                    # Get remote URL
                    result = subprocess.run(
                        ["git", "remote", "get-url", "origin"],
                        cwd=location,
                        capture_output=True,
                        text=True,
                        timeout=2,
                    )
                    if result.returncode == 0:
                        remote_url = result.stdout.strip()
                        # Simplify GitHub URLs
                        if "github.com" in remote_url:
                            repo_name = remote_url.split("/")[-1].replace(".git", "")
                            owner = remote_url.split("/")[-2]
                            return f"{version} [dim](GitHub: {owner}/{repo_name})[/dim]"
                        return f"{version} [dim](Git: {remote_url})[/dim]"
                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                    pass

        # Default: installed from PyPI
        return f"{version} [dim](PyPI)[/dim]"

    except ImportError:
        return t("version_not_installed")


def version(
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed version info"),
) -> None:
    """Show version information."""
    console.print(f"\n[bold]{t('version_info')}[/bold] v{__version__}\n")

    # Basic info
    versions = {
        t("version_python"): platform.python_version(),
        t("version_platform"): f"{platform.system()} {platform.release()}",
    }

    # CUDA version
    cuda_version = detect_cuda_version()
    versions[t("version_cuda")] = cuda_version or t("version_cuda_not_found")

    print_version_table(versions)

    # Always show key packages with installation source
    console.print("\n[bold]Packages:[/bold]\n")

    sglang_info = _get_sglang_info()
    key_packages = {
        t("version_kt_kernel"): get_installed_package_version("kt-kernel") or t("version_not_installed"),
        t("version_sglang"): sglang_info,
    }

    print_version_table(key_packages)

    # Show SGLang installation hint if not installed
    if sglang_info == t("version_not_installed"):
        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions

        console.print()
        print_sglang_install_instructions()

    if verbose:
        console.print("\n[bold]Additional Packages:[/bold]\n")

        package_versions = {
            t("version_ktransformers"): get_installed_package_version("ktransformers") or t("version_not_installed"),
            t("version_llamafactory"): get_installed_package_version("llamafactory") or t("version_not_installed"),
            "typer": get_installed_package_version("typer") or t("version_not_installed"),
            "rich": get_installed_package_version("rich") or t("version_not_installed"),
            "torch": get_installed_package_version("torch") or t("version_not_installed"),
            "transformers": get_installed_package_version("transformers") or t("version_not_installed"),
        }

        print_version_table(package_versions)

    console.print()
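The remote-URL simplification inside `_get_sglang_info` (collapsing a GitHub clone URL to `owner/repo`) can be exercised on its own. A minimal sketch that mirrors the split logic above, with the rich `[dim]` markup dropped for plain output; like the original, it only handles HTTPS-style URLs, so SSH remotes (`git@github.com:owner/repo.git`) would need separate parsing:

```python
def simplify_remote(remote_url: str, version: str) -> str:
    """Format a package version with its git origin, mirroring the
    version command's URL handling for HTTPS GitHub remotes."""
    if "github.com" in remote_url:
        repo_name = remote_url.split("/")[-1].replace(".git", "")
        owner = remote_url.split("/")[-2]
        return f"{version} (GitHub: {owner}/{repo_name})"
    # Non-GitHub remotes are shown verbatim
    return f"{version} (Git: {remote_url})"
```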
kt-kernel/python/cli/completions/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Shell completion scripts for kt-cli."""
kt-kernel/python/cli/completions/_kt (new file, 153 lines)
@@ -0,0 +1,153 @@
#compdef kt
# Zsh completion for kt command
# This is a static completion script that doesn't require Python startup

_kt() {
    local -a commands
    commands=(
        'version:Show version information'
        'run:Start model inference server'
        'chat:Interactive chat with running model'
        'quant:Quantize model weights'
        'bench:Run full benchmark'
        'microbench:Run micro-benchmark'
        'doctor:Diagnose environment issues'
        'model:Manage models and storage paths'
        'config:Manage configuration'
        'sft:Fine-tuning with LlamaFactory'
    )

    local -a run_opts
    run_opts=(
        '--host[Server host]:host:'
        '--port[Server port]:port:'
        '--gpu-experts[Number of GPU experts]:count:'
        '--cpu-threads[Number of CPU threads]:count:'
        '--tensor-parallel-size[Tensor parallel size]:size:'
        '--kt-method[KT method]:method:(AMXINT4 FP8 RAWINT4)'
        '--attention-backend[Attention backend]:backend:(triton flashinfer)'
        '--max-total-tokens[Maximum total tokens]:tokens:'
        '--dry-run[Show command without executing]'
        '--help[Show help message]'
    )

    local -a chat_opts
    chat_opts=(
        '--host[Server host]:host:'
        '--port[Server port]:port:'
        '--model[Model name]:model:'
        '--temperature[Sampling temperature]:temp:'
        '--max-tokens[Maximum tokens]:tokens:'
        '--system[System prompt]:prompt:'
        '--save-history[Save conversation history]'
        '--no-save-history[Do not save history]'
        '--history-file[History file path]:path:_files'
        '--stream[Enable streaming output]'
        '--no-stream[Disable streaming output]'
        '--help[Show help message]'
    )

    local -a model_cmds
    model_cmds=(
        'download:Download a model from HuggingFace'
        'list:List available models'
        'path-list:List all model storage paths'
        'path-add:Add a new model storage path'
        'path-remove:Remove a model storage path'
        'search:Search for models in the registry'
    )

    local -a config_cmds
    config_cmds=(
        'show:Show all configuration'
        'get:Get configuration value'
        'set:Set configuration value'
        'reset:Reset to defaults'
        'path:Show configuration file path'
        'init:Re-run first-time setup wizard'
    )

    local -a sft_cmds
    sft_cmds=(
        'train:Train model'
        'chat:Chat with model'
        'export:Export model'
    )

    _arguments -C \
        '1: :->command' \
        '*::arg:->args'

    case $state in
        command)
            _describe 'kt commands' commands
            _arguments \
                '--help[Show help message]' \
                '--version[Show version]'
            ;;
        args)
            case $words[1] in
                run)
                    _arguments $run_opts \
                        '1:model:'
                    ;;
                chat)
                    _arguments $chat_opts
                    ;;
                quant)
                    _arguments \
                        '--method[Quantization method]:method:' \
                        '--output[Output directory]:path:_files -/' \
                        '--help[Show help message]' \
                        '1:model:_files -/'
                    ;;
                bench|microbench)
                    _arguments \
                        '--model[Model name or path]:model:' \
                        '--config[Config file path]:path:_files' \
                        '--help[Show help message]'
                    ;;
                doctor)
                    _arguments \
                        '--verbose[Verbose output]' \
                        '--help[Show help message]'
                    ;;
                model)
                    _arguments \
                        '1: :->model_cmd' \
                        '*::arg:->model_args'

                    case $state in
                        model_cmd)
                            _describe 'model commands' model_cmds
                            ;;
                    esac
                    ;;
                config)
                    _arguments \
                        '1: :->config_cmd' \
                        '*::arg:->config_args'

                    case $state in
                        config_cmd)
                            _describe 'config commands' config_cmds
                            ;;
                    esac
                    ;;
                sft)
                    _arguments \
                        '1: :->sft_cmd' \
                        '*::arg:->sft_args'

                    case $state in
                        sft_cmd)
                            _describe 'sft commands' sft_cmds
                            ;;
                    esac
                    ;;
            esac
            ;;
    esac
}

_kt "$@"
kt-kernel/python/cli/completions/kt-completion.bash (new file, 73 lines)
@@ -0,0 +1,73 @@
#!/bin/bash
# Bash completion for kt command
# This is a static completion script that doesn't require Python startup

_kt_completion() {
    local cur prev opts
    COMPREPLY=()
    cur="${COMP_WORDS[COMP_CWORD]}"
    prev="${COMP_WORDS[COMP_CWORD-1]}"

    # Main commands
    local commands="version run chat quant bench microbench doctor model config sft"

    # Global options
    local global_opts="--help --version"

    # Handle subcommands
    case "${COMP_CWORD}" in
        1)
            # First argument: suggest commands and global options
            COMPREPLY=( $(compgen -W "${commands} ${global_opts}" -- ${cur}) )
            return 0
            ;;
        *)
            # Handle specific command options
            case "${COMP_WORDS[1]}" in
                run)
                    local run_opts="--host --port --gpu-experts --cpu-threads --tensor-parallel-size --kt-method --attention-backend --max-total-tokens --dry-run --help"
                    COMPREPLY=( $(compgen -W "${run_opts}" -- ${cur}) )
                    ;;
                chat)
                    local chat_opts="--host --port --model --temperature --max-tokens --system --save-history --no-save-history --history-file --stream --no-stream --help"
                    COMPREPLY=( $(compgen -W "${chat_opts}" -- ${cur}) )
                    ;;
                quant)
                    local quant_opts="--method --output --help"
                    COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) )
                    ;;
                bench|microbench)
                    local bench_opts="--model --config --help"
                    COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) )
                    ;;
                doctor)
                    local doctor_opts="--verbose --help"
                    COMPREPLY=( $(compgen -W "${doctor_opts}" -- ${cur}) )
                    ;;
                model)
                    local model_cmds="download list path-list path-add path-remove search"
                    local model_opts="--help"
                    COMPREPLY=( $(compgen -W "${model_cmds} ${model_opts}" -- ${cur}) )
                    ;;
                config)
                    local config_cmds="show get set reset path init model-path-list model-path-add model-path-remove"
                    local config_opts="--help"
                    COMPREPLY=( $(compgen -W "${config_cmds} ${config_opts}" -- ${cur}) )
                    ;;
                sft)
                    local sft_cmds="train chat export"
                    local sft_opts="--help"
                    COMPREPLY=( $(compgen -W "${sft_cmds} ${sft_opts}" -- ${cur}) )
                    ;;
                version)
                    COMPREPLY=( $(compgen -W "--help" -- ${cur}) )
                    ;;
                *)
                    COMPREPLY=()
                    ;;
            esac
            ;;
    esac
}

complete -F _kt_completion kt
kt-kernel/python/cli/completions/kt.fish (new file, 74 lines)
@@ -0,0 +1,74 @@
# Fish completion for kt command
# This is a static completion script that doesn't require Python startup

# Main commands
complete -c kt -f -n "__fish_use_subcommand" -a "version" -d "Show version information"
complete -c kt -f -n "__fish_use_subcommand" -a "run" -d "Start model inference server"
complete -c kt -f -n "__fish_use_subcommand" -a "chat" -d "Interactive chat with running model"
complete -c kt -f -n "__fish_use_subcommand" -a "quant" -d "Quantize model weights"
complete -c kt -f -n "__fish_use_subcommand" -a "bench" -d "Run full benchmark"
complete -c kt -f -n "__fish_use_subcommand" -a "microbench" -d "Run micro-benchmark"
complete -c kt -f -n "__fish_use_subcommand" -a "doctor" -d "Diagnose environment issues"
complete -c kt -f -n "__fish_use_subcommand" -a "model" -d "Manage models and storage paths"
complete -c kt -f -n "__fish_use_subcommand" -a "config" -d "Manage configuration"
complete -c kt -f -n "__fish_use_subcommand" -a "sft" -d "Fine-tuning with LlamaFactory"

# Global options
complete -c kt -l help -d "Show help message"
complete -c kt -l version -d "Show version"

# Run command options
complete -c kt -f -n "__fish_seen_subcommand_from run" -l host -d "Server host"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l port -d "Server port"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l gpu-experts -d "Number of GPU experts"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l cpu-threads -d "Number of CPU threads"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l tensor-parallel-size -d "Tensor parallel size"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l kt-method -d "KT method"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l attention-backend -d "Attention backend"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l max-total-tokens -d "Maximum total tokens"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l dry-run -d "Show command without executing"

# Chat command options
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l host -d "Server host"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l port -d "Server port"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l model -d "Model name"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l temperature -d "Sampling temperature"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l max-tokens -d "Maximum tokens"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l system -d "System prompt"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l save-history -d "Save conversation history"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-save-history -d "Do not save history"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l history-file -d "History file path"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l stream -d "Enable streaming output"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-stream -d "Disable streaming output"

# Quant command options
complete -c kt -f -n "__fish_seen_subcommand_from quant" -l method -d "Quantization method"
complete -c kt -f -n "__fish_seen_subcommand_from quant" -l output -d "Output directory"

# Bench command options
complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l model -d "Model name or path"
complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l config -d "Config file path"

# Doctor command options
complete -c kt -f -n "__fish_seen_subcommand_from doctor" -l verbose -d "Verbose output"

# Model subcommands
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "download" -d "Download a model from HuggingFace"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "list" -d "List available models"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-list" -d "List all model storage paths"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-add" -d "Add a new model storage path"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-remove" -d "Remove a model storage path"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "search" -d "Search for models in the registry"

# Config subcommands
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "show" -d "Show all configuration"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "get" -d "Get configuration value"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "set" -d "Set configuration value"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "reset" -d "Reset to defaults"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "path" -d "Show configuration file path"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "init" -d "Re-run first-time setup wizard"

# SFT subcommands
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "train" -d "Train model"
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "chat" -d "Chat with model"
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "export" -d "Export model"
kt-kernel/python/cli/config/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""
Configuration management for kt-cli.
"""

from kt_kernel.cli.config.settings import Settings, get_settings

__all__ = ["Settings", "get_settings"]
kt-kernel/python/cli/config/settings.py (new file, 311 lines)
@@ -0,0 +1,311 @@
"""
Configuration management for kt-cli.

Handles reading and writing configuration from ~/.ktransformers/config.yaml
"""

import os
from pathlib import Path
from typing import Any, Optional

import yaml

# Default configuration directory
DEFAULT_CONFIG_DIR = Path.home() / ".ktransformers"
DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml"
DEFAULT_MODELS_DIR = DEFAULT_CONFIG_DIR / "models"
DEFAULT_CACHE_DIR = DEFAULT_CONFIG_DIR / "cache"

# Default configuration values
DEFAULT_CONFIG = {
    "general": {
        "language": "auto",  # auto, en, zh
        "color": True,
        "verbose": False,
    },
    "paths": {
        "models": str(DEFAULT_MODELS_DIR),
        "cache": str(DEFAULT_CACHE_DIR),
        "weights": "",  # Custom quantized weights path
    },
    "server": {
        "host": "0.0.0.0",
        "port": 30000,
    },
    "inference": {
        # Inference parameters are model-specific and should not have defaults
        # They will be auto-detected or use model-specific optimizations
        # Environment variables (general optimizations)
        "env": {
            "PYTORCH_ALLOC_CONF": "expandable_segments:True",
            "SGLANG_ENABLE_JIT_DEEPGEMM": "0",
        },
    },
    "download": {
        "mirror": "",  # HuggingFace mirror URL
        "resume": True,
        "verify": True,
    },
    "advanced": {
        # Environment variables to set when running
        "env": {},
        # Extra arguments to pass to sglang
        "sglang_args": [],
        # Extra arguments to pass to llamafactory
        "llamafactory_args": [],
    },
    "dependencies": {
        # SGLang installation source configuration
        "sglang": {
            "source": "github",  # "pypi" or "github"
            "repo": "https://github.com/kvcache-ai/sglang",
            "branch": "main",
        },
    },
}


class Settings:
    """Configuration manager for kt-cli."""

    def __init__(self, config_path: Optional[Path] = None):
        """Initialize settings manager.

        Args:
            config_path: Path to config file. Defaults to ~/.ktransformers/config.yaml
        """
        self.config_path = config_path or DEFAULT_CONFIG_FILE
        self.config_dir = self.config_path.parent
        self._config: dict[str, Any] = {}
        self._load()

    def _ensure_dirs(self) -> None:
        """Ensure configuration directories exist."""
        self.config_dir.mkdir(parents=True, exist_ok=True)

        # Ensure all model paths exist
        model_paths = self.get_model_paths()
        for path in model_paths:
            path.mkdir(parents=True, exist_ok=True)

        Path(self.get("paths.cache", DEFAULT_CACHE_DIR)).mkdir(parents=True, exist_ok=True)

    def _load(self) -> None:
        """Load configuration from file."""
        self._config = self._deep_copy(DEFAULT_CONFIG)

        if self.config_path.exists():
            try:
                with open(self.config_path, "r", encoding="utf-8") as f:
                    user_config = yaml.safe_load(f) or {}
                self._deep_merge(self._config, user_config)
            except (yaml.YAMLError, OSError) as e:
                # Log warning but continue with defaults
                print(f"Warning: Failed to load config: {e}")

        self._ensure_dirs()

    def _save(self) -> None:
        """Save configuration to file."""
        self._ensure_dirs()
        try:
            with open(self.config_path, "w", encoding="utf-8") as f:
                yaml.dump(self._config, f, default_flow_style=False, allow_unicode=True)
        except OSError as e:
            raise RuntimeError(f"Failed to save config: {e}")

    def _deep_copy(self, obj: Any) -> Any:
        """Create a deep copy of a nested dict."""
        if isinstance(obj, dict):
            return {k: self._deep_copy(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._deep_copy(item) for item in obj]
        return obj

    def _deep_merge(self, base: dict, override: dict) -> None:
        """Deep merge override into base."""
        for key, value in override.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._deep_merge(base[key], value)
            else:
                base[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value by dot-separated key.

        Args:
            key: Dot-separated key path (e.g., "server.port")
            default: Default value if key not found

        Returns:
            Configuration value or default
        """
        parts = key.split(".")
        value = self._config

        for part in parts:
            if isinstance(value, dict) and part in value:
                value = value[part]
            else:
                return default

        return value

    def set(self, key: str, value: Any) -> None:
        """Set a configuration value by dot-separated key.

        Args:
            key: Dot-separated key path (e.g., "server.port")
            value: Value to set
        """
        parts = key.split(".")
        config = self._config

        # Navigate to parent
        for part in parts[:-1]:
            if part not in config:
                config[part] = {}
            config = config[part]

        # Set value
        config[parts[-1]] = value
        self._save()

    def delete(self, key: str) -> bool:
        """Delete a configuration value.

        Args:
            key: Dot-separated key path

        Returns:
            True if key was deleted, False if not found
        """
        parts = key.split(".")
        config = self._config

        # Navigate to parent
        for part in parts[:-1]:
            if part not in config:
                return False
            config = config[part]

        # Delete key
        if parts[-1] in config:
            del config[parts[-1]]
            self._save()
            return True
        return False

    def reset(self) -> None:
        """Reset configuration to defaults."""
        self._config = self._deep_copy(DEFAULT_CONFIG)
        self._save()

    def get_all(self) -> dict[str, Any]:
        """Get all configuration values."""
        return self._deep_copy(self._config)

    def get_env_vars(self) -> dict[str, str]:
        """Get environment variables to set."""
        env_vars = {}

        # Get from advanced.env
        advanced_env = self.get("advanced.env", {})
        if isinstance(advanced_env, dict):
            env_vars.update({k: str(v) for k, v in advanced_env.items()})

        return env_vars

    @property
    def models_dir(self) -> Path:
        """Get the primary models directory path (for backward compatibility)."""
        paths = self.get_model_paths()
        return paths[0] if paths else Path(DEFAULT_MODELS_DIR)

    def get_model_paths(self) -> list[Path]:
        """Get all model directory paths.

        Returns a list of Path objects. Supports both:
        - Single path: paths.models = "/path/to/models"
        - Multiple paths: paths.models = ["/path/1", "/path/2"]
        """
        models_config = self.get("paths.models", DEFAULT_MODELS_DIR)

        # Handle both string and list
        if isinstance(models_config, str):
            return [Path(models_config)]
        elif isinstance(models_config, list):
            return [Path(p) for p in models_config]
        else:
            return [Path(DEFAULT_MODELS_DIR)]

    def add_model_path(self, path: str) -> None:
        """Add a new model path to the configuration."""
        models_config = self.get("paths.models", DEFAULT_MODELS_DIR)

        # Convert to list if it's a string
        if isinstance(models_config, str):
            paths = [models_config]
        elif isinstance(models_config, list):
            paths = list(models_config)
        else:
            paths = []

        # Add new path if not already present
        if path not in paths:
            paths.append(path)
            self.set("paths.models", paths)

    def remove_model_path(self, path: str) -> bool:
        """Remove a model path from the configuration.

        Returns True if path was removed, False if not found.
        """
        models_config = self.get("paths.models", DEFAULT_MODELS_DIR)

        if isinstance(models_config, str):
            # A single string is the only configured path; never remove the last path
            return False
        elif isinstance(models_config, list):
            if path in models_config:
                paths = list(models_config)
                paths.remove(path)
                # Don't allow removing all paths
                if not paths:
                    return False
                self.set("paths.models", paths if len(paths) > 1 else paths[0])
                return True

        return False

    @property
    def cache_dir(self) -> Path:
        """Get the cache directory path."""
        return Path(self.get("paths.cache", DEFAULT_CACHE_DIR))

    @property
    def weights_dir(self) -> Optional[Path]:
        """Get the custom weights directory path."""
        weights = self.get("paths.weights", "")
        return Path(weights) if weights else None


# Global settings instance
_settings: Optional[Settings] = None


def get_settings() -> Settings:
    """Get the global settings instance."""
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def reset_settings() -> None:
    """Reset the global settings instance."""
    global _settings
    _settings = None
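The core behavior of `Settings` is the dot-separated key handling in `get`/`set` plus the per-leaf `_deep_merge` of user config over `DEFAULT_CONFIG`. This standalone sketch mirrors those three methods as free functions (dropping the YAML persistence and directory setup) so the merge and path-walk semantics can be checked in isolation:

```python
from typing import Any


def deep_merge(base: dict, override: dict) -> None:
    """Recursively merge override into base; nested dicts merge per key,
    everything else is replaced (same policy as Settings._deep_merge)."""
    for key, value in override.items():
        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
            deep_merge(base[key], value)
        else:
            base[key] = value


def get_key(config: dict, key: str, default: Any = None) -> Any:
    """Walk a dot-separated path such as "server.port"."""
    value = config
    for part in key.split("."):
        if isinstance(value, dict) and part in value:
            value = value[part]
        else:
            return default
    return value


def set_key(config: dict, key: str, value: Any) -> None:
    """Create intermediate dicts as needed, then assign the leaf."""
    parts = key.split(".")
    for part in parts[:-1]:
        config = config.setdefault(part, {})
    config[parts[-1]] = value


defaults = {"server": {"host": "0.0.0.0", "port": 30000}}
deep_merge(defaults, {"server": {"port": 8080}})  # user override wins per leaf
set_key(defaults, "paths.cache", "/tmp/kt")       # creates the "paths" subtree
```

Because the merge recurses per key, a user config that sets only `server.port` keeps the default `server.host`, which is why partial config files are safe.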
655
kt-kernel/python/cli/i18n.py
Normal file
655
kt-kernel/python/cli/i18n.py
Normal file
@@ -0,0 +1,655 @@
"""
Internationalization (i18n) module for kt-cli.

Supports English and Chinese, with automatic detection based on
the system locale or the KT_LANG environment variable.
"""

import os
from typing import Any

# Message definitions for all supported languages
MESSAGES: dict[str, dict[str, str]] = {
    "en": {
        # General
        "welcome": "Welcome to KTransformers!",
        "goodbye": "Goodbye!",
        "error": "Error",
        "warning": "Warning",
        "success": "Success",
        "info": "Info",
        "yes": "Yes",
        "no": "No",
        "cancel": "Cancel",
        "confirm": "Confirm",
        "done": "Done",
        "failed": "Failed",
        "skip": "Skip",
        "back": "Back",
        "next": "Next",
        "retry": "Retry",
        "abort": "Abort",
        # Version command
        "version_info": "KTransformers CLI",
        "version_python": "Python",
        "version_platform": "Platform",
        "version_cuda": "CUDA",
        "version_cuda_not_found": "Not found",
        "version_kt_kernel": "kt-kernel",
        "version_ktransformers": "ktransformers",
        "version_sglang": "sglang",
        "version_llamafactory": "llamafactory",
        "version_not_installed": "Not installed",
        # Install command
        "install_detecting_env": "Detecting environment managers...",
        "install_found": "Found {name} (version {version})",
        "install_not_found": "Not found: {name}",
        "install_checking_env": "Checking existing environments...",
        "install_env_exists": "Found existing 'kt' environment",
        "install_env_not_exists": "No 'kt' environment found",
        "install_no_env_manager": "No virtual environment manager detected",
        "install_select_method": "Please select installation method:",
        "install_method_conda": "Create new conda environment 'kt' (Recommended)",
        "install_method_venv": "Create new venv environment",
        "install_method_uv": "Create new uv environment (Fast)",
        "install_method_docker": "Use Docker container",
        "install_method_system": "Install to system Python (Not recommended)",
        "install_select_mode": "Please select installation mode:",
        "install_mode_inference": "Inference - Install kt-kernel + SGLang",
        "install_mode_sft": "Training - Install kt-sft + LlamaFactory",
        "install_mode_full": "Full - Install all components",
        "install_creating_env": "Creating {type} environment '{name}'...",
        "install_env_created": "Environment created successfully",
        "install_installing_deps": "Installing dependencies...",
        "install_checking_deps": "Checking dependency versions...",
        "install_dep_ok": "OK",
        "install_dep_outdated": "Needs update",
        "install_dep_missing": "Missing",
        "install_installing_pytorch": "Installing PyTorch...",
        "install_installing_from_requirements": "Installing from requirements file...",
        "install_deps_outdated": "Found {count} package(s) that need updating. Continue?",
        "install_updating": "Updating packages...",
        "install_complete": "Installation complete!",
        "install_activate_hint": "Activate environment: {command}",
        "install_start_hint": "Get started: kt run --help",
        "install_docker_pulling": "Pulling Docker image...",
        "install_docker_complete": "Docker image ready!",
        "install_docker_run_hint": "Run with: docker run --gpus all -p 30000:30000 {image} kt run {model}",
        "install_in_venv": "Running in virtual environment: {name}",
        "install_continue_without_venv": "Continue installing to system Python?",
        "install_already_installed": "All dependencies are already installed!",
        "install_confirm": "Install {count} package(s)?",
        # Install - System dependencies
        "install_checking_system_deps": "Checking system dependencies...",
        "install_dep_name": "Dependency",
        "install_dep_status": "Status",
        "install_deps_all_installed": "All system dependencies are installed",
        "install_deps_install_prompt": "Install missing dependencies?",
        "install_installing_system_deps": "Installing system dependencies...",
        "install_installing_dep": "Installing {name}",
        "install_dep_no_install_cmd": "No install command available for {name} on {os}",
        "install_dep_install_failed": "Failed to install {name}",
        "install_deps_skipped": "Skipping dependency installation",
        "install_deps_failed": "Failed to install system dependencies",
        # Install - CPU detection
        "install_auto_detect_cpu": "Auto-detecting CPU capabilities...",
        "install_cpu_features": "Detected CPU features: {features}",
        "install_cpu_no_features": "No advanced CPU features detected",
        # Install - Build configuration
        "install_build_config": "Build Configuration:",
        "install_native_warning": "Note: Binary optimized for THIS CPU only (not portable)",
        "install_building_from_source": "Building kt-kernel from source...",
        "install_build_failed": "Build failed",
        "install_build_success": "Build completed successfully",
        # Install - Verification
        "install_verifying": "Verifying installation...",
        "install_verify_success": "kt-kernel {version} ({variant} variant) installed successfully",
        "install_verify_failed": "Verification failed: {error}",
        # Install - Docker
        "install_docker_guide_title": "Docker Installation",
        "install_docker_guide_desc": "For Docker installation, please refer to the official guide:",
        # Config command
        "config_show_title": "Current Configuration",
        "config_set_success": "Configuration updated: {key} = {value}",
        "config_get_value": "{key} = {value}",
        "config_get_not_found": "Configuration key '{key}' not found",
        "config_reset_confirm": "This will reset all configurations to default. Continue?",
        "config_reset_success": "Configuration reset to default",
        "config_file_location": "Configuration file: {path}",
        # Doctor command
        "doctor_title": "KTransformers Environment Diagnostics",
        "doctor_checking": "Running diagnostics...",
        "doctor_check_python": "Python version",
        "doctor_check_cuda": "CUDA availability",
        "doctor_check_gpu": "GPU detection",
        "doctor_check_cpu": "CPU",
        "doctor_check_cpu_isa": "CPU Instructions",
        "doctor_check_numa": "NUMA Topology",
        "doctor_check_memory": "System memory",
        "doctor_check_disk": "Disk space",
        "doctor_check_packages": "Required packages",
        "doctor_check_env": "Environment variables",
        "doctor_status_ok": "OK",
        "doctor_status_warning": "Warning",
        "doctor_status_error": "Error",
        "doctor_gpu_found": "Found {count} GPU(s): {names}",
        "doctor_gpu_not_found": "No GPU detected",
        "doctor_cpu_info": "{name} ({cores} cores / {threads} threads)",
        "doctor_cpu_isa_info": "{isa_list}",
        "doctor_cpu_isa_missing": "Missing recommended: {missing}",
        "doctor_numa_info": "{nodes} node(s)",
        "doctor_numa_detail": "{node}: CPUs {cpus}",
        "doctor_memory_info": "{available} available / {total} total",
        "doctor_memory_freq": "{available} available / {total} total ({freq}MHz {type})",
        "doctor_disk_info": "{available} available at {path}",
        "doctor_all_ok": "All checks passed! Your environment is ready.",
        "doctor_has_issues": "Some issues were found. Please review the warnings/errors above.",
        # Run command
        "run_detecting_hardware": "Detecting hardware configuration...",
        "run_gpu_info": "GPU: {name} ({vram}GB VRAM)",
        "run_cpu_info": "CPU: {name} ({cores} cores, {numa} NUMA nodes)",
        "run_ram_info": "RAM: {total}GB",
        "run_checking_model": "Checking model status...",
        "run_model_path": "Model path: {path}",
        "run_weights_not_found": "Quantized weights not found",
        "run_quant_prompt": "Quantize model now? (This may take a while)",
        "run_quantizing": "Quantizing model...",
        "run_starting_server": "Starting server...",
        "run_server_mode": "Mode: SGLang + kt-kernel",
        "run_server_port": "Port: {port}",
        "run_gpu_experts": "GPU experts: {count}/layer",
        "run_cpu_threads": "CPU threads: {count}",
        "run_server_started": "Server started!",
        "run_api_url": "API URL: http://{host}:{port}",
        "run_docs_url": "Docs URL: http://{host}:{port}/docs",
        "run_stop_hint": "Press Ctrl+C to stop the server",
        "run_model_not_found": "Model '{name}' not found. Run 'kt download' first.",
        "run_multiple_matches": "Multiple models found. Please select:",
        "run_select_model": "Select model",
        "run_select_model_title": "Select a model to run",
        "run_select_model_prompt": "Enter number",
        "run_local_models": "Local Models (Downloaded)",
        "run_registered_models": "Registered Models",
        # Download command
        "download_list_title": "Available Models",
        "download_searching": "Searching for model '{name}'...",
        "download_found": "Found: {name}",
        "download_multiple_found": "Multiple matches found:",
        "download_select": "Select model to download:",
        "download_destination": "Destination: {path}",
        "download_starting": "Starting download...",
        "download_progress": "Downloading {name}...",
        "download_complete": "Download complete!",
        "download_already_exists": "Model already exists at {path}",
        "download_overwrite_prompt": "Overwrite existing files?",
        # Quant command
        "quant_input_path": "Input path: {path}",
        "quant_output_path": "Output path: {path}",
        "quant_method": "Quantization method: {method}",
        "quant_starting": "Starting quantization...",
        "quant_progress": "Quantizing...",
        "quant_complete": "Quantization complete!",
        "quant_input_not_found": "Input model not found at {path}",
        # SFT command
        "sft_mode_train": "Training mode",
        "sft_mode_chat": "Chat mode",
        "sft_mode_export": "Export mode",
        "sft_config_path": "Config file: {path}",
        "sft_starting": "Starting {mode}...",
        "sft_complete": "{mode} complete!",
        "sft_config_not_found": "Config file not found: {path}",
        # Bench command
        "bench_starting": "Starting benchmark...",
        "bench_type": "Benchmark type: {type}",
        "bench_complete": "Benchmark complete!",
        "bench_results_title": "Benchmark Results",
        # Common prompts
        "prompt_continue": "Continue?",
        "prompt_select": "Please select:",
        "prompt_enter_value": "Enter value:",
        "prompt_confirm_action": "Confirm this action?",
        # First-run setup - Model path selection
        "setup_model_path_title": "Model Storage Location",
        "setup_model_path_desc": "LLM models are large (50-200GB+). Please select a storage location with sufficient space:",
        "setup_scanning_disks": "Scanning available storage locations...",
        "setup_disk_option": "{path} ({available} available / {total} total)",
        "setup_disk_option_recommended": "{path} ({available} available / {total} total) [Recommended]",
        "setup_custom_path": "Enter custom path",
        "setup_enter_custom_path": "Enter the path for model storage",
        "setup_path_not_exist": "Path does not exist. Create it?",
        "setup_path_no_write": "No write permission for this path. Please choose another.",
        "setup_path_low_space": "Warning: Less than 100GB available. Large models may not fit.",
        "setup_model_path_set": "Model storage path set to: {path}",
        "setup_no_large_disk": "No large storage locations found. Using default path.",
        "setup_scanning_models": "Scanning for existing models...",
        "setup_found_models": "Found {count} model(s):",
        "setup_model_info": "{name} ({size}, {type})",
        "setup_no_models_found": "No existing models found in this location.",
        "setup_location_has_models": "{count} model(s) found",
        "setup_installing_completion": "Installing shell completion for {shell}...",
        "setup_completion_installed": "Shell completion installed! Restart terminal to enable.",
        "setup_completion_failed": "Failed to install shell completion. Run 'kt --install-completion' manually.",
        # Auto completion
        "completion_installed_title": "Tab Completion",
        "completion_installed_for": "Shell completion installed for {shell}",
        "completion_activate_now": "To enable completion in this terminal session, run:",
        "completion_next_session": "Completion will be automatically enabled in new terminal sessions.",
        # SGLang
        "sglang_not_found": "SGLang not found",
        "sglang_pypi_warning": "SGLang from PyPI may not be compatible with kt-kernel",
        "sglang_pypi_hint": 'SGLang from PyPI may not be compatible. Install from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_install_hint": 'Install SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_recommend_source": 'Recommend reinstalling from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_kt_kernel_not_supported": "SGLang does not support kt-kernel (missing --kt-gpu-prefill-token-threshold parameter)",
        "sglang_checking_kt_kernel_support": "Checking SGLang kt-kernel support...",
        "sglang_kt_kernel_supported": "SGLang kt-kernel support verified",
        # Chat
        "chat_proxy_detected": "Proxy detected in environment",
        "chat_proxy_confirm": "Use proxy for connection?",
        "chat_proxy_disabled": "Proxy disabled for this session",
        # Model command
        "model_supported_title": "KTransformers Supported Models",
        "model_column_model": "Model",
        "model_column_status": "Status",
        "model_column_local_path": "Local Path",
        "model_status_local": "Local",
        "model_status_not_downloaded": "Not downloaded",
        "model_usage_title": "Usage",
        "model_usage_download": "Download a model:",
        "model_usage_list_local": "List local models:",
        "model_usage_search": "Search models:",
        "model_storage_paths_title": "Model Storage Paths",
        "model_local_models_title": "Locally Downloaded Models",
        "model_available_models_title": "Available Models",
        "model_no_local_models": "No locally downloaded models found",
        "model_download_hint": "Download a model with:",
        "model_download_usage_hint": "Usage: kt model download <model-name>",
        "model_download_list_hint": "Use 'kt model download --list' to see available models.",
        "model_download_hf_hint": "Or specify a HuggingFace repo directly: kt model download org/model-name",
        "model_saved_to": "Model saved to: {path}",
        "model_start_with": "Start with: kt run {name}",
        "model_download_failed": "Download failed: {error}",
        "model_hf_cli_not_found": "huggingface-cli not found. Install with: pip install huggingface-hub",
        "model_path_not_exist": "Path does not exist: {path}",
        "model_create_directory": "Create directory {path}?",
        "model_created_directory": "Created directory: {path}",
        "model_create_dir_failed": "Failed to create directory: {error}",
        "model_path_added": "Added model path: {path}",
        "model_path_removed": "Removed model path: {path}",
        "model_path_not_found": "Path not found in configuration or cannot remove last path: {path}",
        "model_search_no_results": "No models found matching '{query}'",
        "model_search_results_title": "Search Results for '{query}'",
        "model_column_name": "Name",
        "model_column_hf_repo": "HuggingFace Repo",
        "model_column_aliases": "Aliases",
        # Coming soon
        "feature_coming_soon": "This feature is coming soon...",
    },
    "zh": {
        # General
        "welcome": "欢迎使用 KTransformers!",
        "goodbye": "再见!",
        "error": "错误",
        "warning": "警告",
        "success": "成功",
        "info": "信息",
        "yes": "是",
        "no": "否",
        "cancel": "取消",
        "confirm": "确认",
        "done": "完成",
        "failed": "失败",
        "skip": "跳过",
        "back": "返回",
        "next": "下一步",
        "retry": "重试",
        "abort": "中止",
        # Version command
        "version_info": "KTransformers CLI",
        "version_python": "Python",
        "version_platform": "平台",
        "version_cuda": "CUDA",
        "version_cuda_not_found": "未找到",
        "version_kt_kernel": "kt-kernel",
        "version_ktransformers": "ktransformers",
        "version_sglang": "sglang",
        "version_llamafactory": "llamafactory",
        "version_not_installed": "未安装",
        # Install command
        "install_detecting_env": "检测环境管理工具...",
        "install_found": "发现 {name} (版本 {version})",
        "install_not_found": "未找到: {name}",
        "install_checking_env": "检查现有环境...",
        "install_env_exists": "发现现有 'kt' 环境",
        "install_env_not_exists": "未发现 'kt' 环境",
        "install_no_env_manager": "未检测到虚拟环境管理工具",
        "install_select_method": "请选择安装方式:",
        "install_method_conda": "创建新的 conda 环境 'kt' (推荐)",
        "install_method_venv": "创建新的 venv 环境",
        "install_method_uv": "创建新的 uv 环境 (快速)",
        "install_method_docker": "使用 Docker 容器",
        "install_method_system": "安装到系统 Python (不推荐)",
        "install_select_mode": "请选择安装模式:",
        "install_mode_inference": "推理模式 - 安装 kt-kernel + SGLang",
        "install_mode_sft": "训练模式 - 安装 kt-sft + LlamaFactory",
        "install_mode_full": "完整安装 - 安装所有组件",
        "install_creating_env": "正在创建 {type} 环境 '{name}'...",
        "install_env_created": "环境创建成功",
        "install_installing_deps": "正在安装依赖...",
        "install_checking_deps": "检查依赖版本...",
        "install_dep_ok": "正常",
        "install_dep_outdated": "需更新",
        "install_dep_missing": "缺失",
        "install_installing_pytorch": "正在安装 PyTorch...",
        "install_installing_from_requirements": "从依赖文件安装...",
        "install_deps_outdated": "发现 {count} 个包需要更新,是否继续?",
        "install_updating": "正在更新包...",
        "install_complete": "安装完成!",
        "install_activate_hint": "激活环境: {command}",
        "install_start_hint": "开始使用: kt run --help",
        "install_docker_pulling": "正在拉取 Docker 镜像...",
        "install_docker_complete": "Docker 镜像已就绪!",
        "install_docker_run_hint": "运行: docker run --gpus all -p 30000:30000 {image} kt run {model}",
        "install_in_venv": "当前在虚拟环境中: {name}",
        "install_continue_without_venv": "继续安装到系统 Python?",
        "install_already_installed": "所有依赖已安装!",
        "install_confirm": "安装 {count} 个包?",
        # Install - System dependencies
        "install_checking_system_deps": "检查系统依赖...",
        "install_dep_name": "依赖项",
        "install_dep_status": "状态",
        "install_deps_all_installed": "所有系统依赖已安装",
        "install_deps_install_prompt": "是否安装缺失的依赖?",
        "install_installing_system_deps": "正在安装系统依赖...",
        "install_installing_dep": "正在安装 {name}",
        "install_dep_no_install_cmd": "{os} 系统上没有 {name} 的安装命令",
        "install_dep_install_failed": "安装 {name} 失败",
        "install_deps_skipped": "跳过依赖安装",
        "install_deps_failed": "系统依赖安装失败",
        # Install - CPU detection
        "install_auto_detect_cpu": "正在自动检测 CPU 能力...",
        "install_cpu_features": "检测到的 CPU 特性: {features}",
        "install_cpu_no_features": "未检测到高级 CPU 特性",
        # Install - Build configuration
        "install_build_config": "构建配置:",
        "install_native_warning": "注意: 二进制文件仅针对当前 CPU 优化(不可移植)",
        "install_building_from_source": "正在从源码构建 kt-kernel...",
        "install_build_failed": "构建失败",
        "install_build_success": "构建成功",
        # Install - Verification
        "install_verifying": "正在验证安装...",
        "install_verify_success": "kt-kernel {version} ({variant} 变体) 安装成功",
        "install_verify_failed": "验证失败: {error}",
        # Install - Docker
        "install_docker_guide_title": "Docker 安装",
        "install_docker_guide_desc": "有关 Docker 安装,请参阅官方指南:",
        # Config command
        "config_show_title": "当前配置",
        "config_set_success": "配置已更新: {key} = {value}",
        "config_get_value": "{key} = {value}",
        "config_get_not_found": "未找到配置项 '{key}'",
        "config_reset_confirm": "这将重置所有配置为默认值。是否继续?",
        "config_reset_success": "配置已重置为默认值",
        "config_file_location": "配置文件: {path}",
        # Doctor command
        "doctor_title": "KTransformers 环境诊断",
        "doctor_checking": "正在运行诊断...",
        "doctor_check_python": "Python 版本",
        "doctor_check_cuda": "CUDA 可用性",
        "doctor_check_gpu": "GPU 检测",
        "doctor_check_cpu": "CPU",
        "doctor_check_cpu_isa": "CPU 指令集",
        "doctor_check_numa": "NUMA 拓扑",
        "doctor_check_memory": "系统内存",
        "doctor_check_disk": "磁盘空间",
        "doctor_check_packages": "必需的包",
        "doctor_check_env": "环境变量",
        "doctor_status_ok": "正常",
        "doctor_status_warning": "警告",
        "doctor_status_error": "错误",
        "doctor_gpu_found": "发现 {count} 个 GPU: {names}",
        "doctor_gpu_not_found": "未检测到 GPU",
        "doctor_cpu_info": "{name} ({cores} 核心 / {threads} 线程)",
        "doctor_cpu_isa_info": "{isa_list}",
        "doctor_cpu_isa_missing": "缺少推荐指令集: {missing}",
        "doctor_numa_info": "{nodes} 个节点",
        "doctor_numa_detail": "{node}: CPU {cpus}",
        "doctor_memory_info": "{available} 可用 / {total} 总计",
        "doctor_memory_freq": "{available} 可用 / {total} 总计 ({freq}MHz {type})",
        "doctor_disk_info": "{path} 有 {available} 可用空间",
        "doctor_all_ok": "所有检查通过!您的环境已就绪。",
        "doctor_has_issues": "发现一些问题,请查看上方的警告/错误信息。",
        # Run command
        "run_detecting_hardware": "检测硬件配置...",
        "run_gpu_info": "GPU: {name} ({vram}GB 显存)",
        "run_cpu_info": "CPU: {name} ({cores} 核心, {numa} NUMA 节点)",
        "run_ram_info": "内存: {total}GB",
        "run_checking_model": "检查模型状态...",
        "run_model_path": "模型路径: {path}",
        "run_weights_not_found": "未找到量化权重",
        "run_quant_prompt": "是否现在量化模型?(这可能需要一些时间)",
        "run_quantizing": "正在量化模型...",
        "run_starting_server": "正在启动服务器...",
        "run_server_mode": "模式: SGLang + kt-kernel",
        "run_server_port": "端口: {port}",
        "run_gpu_experts": "GPU 专家: {count}/层",
        "run_cpu_threads": "CPU 线程: {count}",
        "run_server_started": "服务器已启动!",
        "run_api_url": "API 地址: http://{host}:{port}",
        "run_docs_url": "文档地址: http://{host}:{port}/docs",
        "run_stop_hint": "按 Ctrl+C 停止服务器",
        "run_model_not_found": "未找到模型 '{name}'。请先运行 'kt download'。",
        "run_multiple_matches": "找到多个匹配的模型,请选择:",
        "run_select_model": "选择模型",
        "run_select_model_title": "选择要运行的模型",
        "run_select_model_prompt": "输入编号",
        "run_local_models": "本地模型 (已下载)",
        "run_registered_models": "注册模型",
        # Download command
        "download_list_title": "可用模型",
        "download_searching": "正在搜索模型 '{name}'...",
        "download_found": "找到: {name}",
        "download_multiple_found": "找到多个匹配:",
        "download_select": "选择要下载的模型:",
        "download_destination": "目标路径: {path}",
        "download_starting": "开始下载...",
        "download_progress": "正在下载 {name}...",
        "download_complete": "下载完成!",
        "download_already_exists": "模型已存在于 {path}",
        "download_overwrite_prompt": "是否覆盖现有文件?",
        # Quant command
        "quant_input_path": "输入路径: {path}",
        "quant_output_path": "输出路径: {path}",
        "quant_method": "量化方法: {method}",
        "quant_starting": "开始量化...",
        "quant_progress": "正在量化...",
        "quant_complete": "量化完成!",
        "quant_input_not_found": "未找到输入模型: {path}",
        # SFT command
        "sft_mode_train": "训练模式",
        "sft_mode_chat": "聊天模式",
        "sft_mode_export": "导出模式",
        "sft_config_path": "配置文件: {path}",
        "sft_starting": "正在启动 {mode}...",
        "sft_complete": "{mode} 完成!",
        "sft_config_not_found": "未找到配置文件: {path}",
        # Bench command
        "bench_starting": "开始基准测试...",
        "bench_type": "测试类型: {type}",
        "bench_complete": "基准测试完成!",
        "bench_results_title": "基准测试结果",
        # Common prompts
        "prompt_continue": "是否继续?",
        "prompt_select": "请选择:",
        "prompt_enter_value": "请输入:",
        "prompt_confirm_action": "确认此操作?",
        # First-run setup - Model path selection
        "setup_model_path_title": "模型存储位置",
        "setup_model_path_desc": "大语言模型体积较大(50-200GB+)。请选择一个有足够空间的存储位置:",
        "setup_scanning_disks": "正在扫描可用存储位置...",
        "setup_disk_option": "{path} (可用 {available} / 总共 {total})",
        "setup_disk_option_recommended": "{path} (可用 {available} / 总共 {total}) [推荐]",
        "setup_custom_path": "输入自定义路径",
        "setup_enter_custom_path": "请输入模型存储路径",
        "setup_path_not_exist": "路径不存在,是否创建?",
        "setup_path_no_write": "没有该路径的写入权限,请选择其他路径。",
        "setup_path_low_space": "警告:可用空间不足 100GB,可能无法存储大型模型。",
        "setup_model_path_set": "模型存储路径已设置为: {path}",
        "setup_no_large_disk": "未发现大容量存储位置,使用默认路径。",
        "setup_scanning_models": "正在扫描已有模型...",
        "setup_found_models": "发现 {count} 个模型:",
        "setup_model_info": "{name} ({size}, {type})",
        "setup_no_models_found": "该位置未发现已有模型。",
        "setup_location_has_models": "发现 {count} 个模型",
        "setup_installing_completion": "正在为 {shell} 安装命令补全...",
        "setup_completion_installed": "命令补全已安装!重启终端后生效。",
        "setup_completion_failed": "命令补全安装失败。请手动运行 'kt --install-completion'。",
        # Auto completion
        "completion_installed_title": "命令补全",
        "completion_installed_for": "已为 {shell} 安装命令补全",
        "completion_activate_now": "在当前终端会话中启用补全,请运行:",
        "completion_next_session": "新的终端会话将自动启用补全。",
        # SGLang
        "sglang_not_found": "未找到 SGLang",
        "sglang_pypi_warning": "PyPI 版本的 SGLang 可能与 kt-kernel 不兼容",
        "sglang_pypi_hint": 'PyPI 版本可能不兼容。从源码安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_install_hint": '安装 SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_recommend_source": '建议从源码重新安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
        "sglang_kt_kernel_not_supported": "SGLang 不支持 kt-kernel (缺少 --kt-gpu-prefill-token-threshold 参数)",
        "sglang_checking_kt_kernel_support": "正在检查 SGLang kt-kernel 支持...",
        "sglang_kt_kernel_supported": "SGLang kt-kernel 支持已验证",
        # Chat
        "chat_proxy_detected": "检测到环境中存在代理设置",
        "chat_proxy_confirm": "是否使用代理连接?",
        "chat_proxy_disabled": "已在本次会话中禁用代理",
        # Model command
        "model_supported_title": "KTransformers 支持的模型",
        "model_column_model": "模型",
        "model_column_status": "状态",
        "model_column_local_path": "本地路径",
        "model_status_local": "本地",
        "model_status_not_downloaded": "未下载",
        "model_usage_title": "使用方法",
        "model_usage_download": "下载模型:",
        "model_usage_list_local": "列出本地模型:",
        "model_usage_search": "搜索模型:",
        "model_storage_paths_title": "模型存储路径",
        "model_local_models_title": "本地已下载的模型",
        "model_available_models_title": "可用模型",
        "model_no_local_models": "未找到本地已下载的模型",
        "model_download_hint": "下载模型:",
        "model_download_usage_hint": "用法: kt model download <模型名称>",
        "model_download_list_hint": "使用 'kt model download --list' 查看可用模型。",
        "model_download_hf_hint": "或直接指定 HuggingFace 仓库: kt model download org/model-name",
        "model_saved_to": "模型已保存到: {path}",
        "model_start_with": "启动命令: kt run {name}",
        "model_download_failed": "下载失败: {error}",
        "model_hf_cli_not_found": "未找到 huggingface-cli。请安装: pip install huggingface-hub",
        "model_path_not_exist": "路径不存在: {path}",
        "model_create_directory": "创建目录 {path}?",
        "model_created_directory": "已创建目录: {path}",
        "model_create_dir_failed": "创建目录失败: {error}",
        "model_path_added": "已添加模型路径: {path}",
        "model_path_removed": "已移除模型路径: {path}",
        "model_path_not_found": "路径未找到或无法移除最后一个路径: {path}",
        "model_search_no_results": "未找到匹配 '{query}' 的模型",
        "model_search_results_title": "'{query}' 的搜索结果",
        "model_column_name": "名称",
        "model_column_hf_repo": "HuggingFace 仓库",
        "model_column_aliases": "别名",
        # Coming soon
        "feature_coming_soon": "此功能即将推出...",
    },
}


# Cache for language detection to avoid repeated I/O
_lang_cache: str | None = None


def get_lang() -> str:
    """
    Detect the current language setting.

    Priority:
        1. KT_LANG environment variable
        2. Config file general.language setting
        3. LANG environment variable (if config is "auto")
        4. Default to English

    Returns:
        Language code: "zh" for Chinese, "en" for English
    """
    global _lang_cache

    # 1. Check KT_LANG environment variable (highest priority)
    kt_lang = os.environ.get("KT_LANG", "").lower()
    if kt_lang:
        return "zh" if kt_lang.startswith("zh") else "en"

    # 2. Return cached value if available (avoids I/O on every call)
    if _lang_cache is not None:
        return _lang_cache

    # 3. Check config file setting (with caching)
    # Import here to avoid circular imports
    from kt_kernel.cli.config.settings import get_settings

    try:
        settings = get_settings()
        config_lang = settings.get("general.language", "auto")
        if config_lang and config_lang != "auto":
            lang = "zh" if config_lang.lower().startswith("zh") else "en"
            _lang_cache = lang
            return lang
    except Exception:
        # If settings fail to load, continue with system detection
        pass

    # 4. Check system LANG environment variable
    system_lang = os.environ.get("LANG", "").lower()
    lang = "zh" if system_lang.startswith("zh") else "en"
    _lang_cache = lang
    return lang


def t(msg_key: str, **kwargs: Any) -> str:
    """
    Translate a message key to the current language.

    Args:
        msg_key: Message key to translate
        **kwargs: Format arguments for the message

    Returns:
        Translated and formatted message string

    Example:
        >>> t("welcome")
        "Welcome to KTransformers!"  # or "欢迎使用 KTransformers!" in Chinese

        >>> t("install_found", name="conda", version="24.1.0")
        "Found conda (version 24.1.0)"
    """
    lang = get_lang()
    messages = MESSAGES.get(lang, MESSAGES["en"])
    message = messages.get(msg_key, MESSAGES["en"].get(msg_key, msg_key))

    if kwargs:
        try:
            return message.format(**kwargs)
        except KeyError:
            return message
    return message


def set_lang(lang: str) -> None:
    """
    Set the language for the current session.

    Args:
        lang: Language code ("en" or "zh")
    """
    global _lang_cache
    os.environ["KT_LANG"] = lang
    _lang_cache = lang  # Update cache when language is explicitly set
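The lookup-and-fallback chain in `t()` can be exercised standalone. This is a hedged sketch with a toy message table, not the real `MESSAGES` dict: unknown keys fall back to English, then to the key itself, and a `KeyError` during `.format()` returns the unformatted template rather than raising.

```python
# Toy message table standing in for the real MESSAGES dict.
MESSAGES = {
    "en": {"hello": "Hello, {name}!"},
    "zh": {"hello": "你好,{name}!"},
}


def t(msg_key: str, lang: str = "en", **kwargs) -> str:
    """Look up msg_key in lang, falling back to English, then the key itself."""
    messages = MESSAGES.get(lang, MESSAGES["en"])
    message = messages.get(msg_key, MESSAGES["en"].get(msg_key, msg_key))
    if kwargs:
        try:
            return message.format(**kwargs)
        except KeyError:
            # A placeholder was not supplied: return the raw template
            return message
    return message


assert t("hello", lang="zh", name="KT") == "你好,KT!"
assert t("missing_key") == "missing_key"           # key itself as last resort
assert t("hello", wrong="x") == "Hello, {name}!"   # KeyError -> raw template
```

Returning the raw template on a format error keeps a missing placeholder from crashing the CLI; the user sees `{name}` literally instead of a traceback.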
436
kt-kernel/python/cli/main.py
Normal file
@@ -0,0 +1,436 @@
"""
Main entry point for kt-cli.

KTransformers CLI - A unified command-line interface for KTransformers.
"""

import sys

import typer

from kt_kernel.cli import __version__
from kt_kernel.cli.commands import bench, chat, config, doctor, model, quant, run, sft, version
from kt_kernel.cli.i18n import t, set_lang, get_lang


def _get_app_help() -> str:
    """Get app help text based on current language."""
    lang = get_lang()
    if lang == "zh":
        return "KTransformers CLI - KTransformers 统一命令行界面"
    return "KTransformers CLI - A unified command-line interface for KTransformers."


def _get_help(key: str) -> str:
    """Get help text based on current language."""
    help_texts = {
        "version": {"en": "Show version information", "zh": "显示版本信息"},
        "run": {"en": "Start model inference server", "zh": "启动模型推理服务器"},
        "chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"},
        "quant": {"en": "Quantize model weights", "zh": "量化模型权重"},
        "bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"},
        "microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"},
        "doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"},
        "model": {"en": "Manage models and storage paths", "zh": "管理模型和存储路径"},
        "config": {"en": "Manage configuration", "zh": "管理配置"},
        "sft": {"en": "Fine-tuning with LlamaFactory", "zh": "使用 LlamaFactory 进行微调"},
    }
    lang = get_lang()
    return help_texts.get(key, {}).get(lang, help_texts.get(key, {}).get("en", key))
# Create main app with dynamic help
|
||||
app = typer.Typer(
|
||||
name="kt",
|
||||
help="KTransformers CLI - A unified command-line interface for KTransformers.",
|
||||
no_args_is_help=True,
|
||||
add_completion=False, # Use static completion scripts instead of dynamic completion
|
||||
rich_markup_mode="rich",
|
||||
)
|
||||
|
||||
|
||||
def _update_help_texts() -> None:
|
||||
"""Update all help texts based on current language setting."""
|
||||
# Update main app help
|
||||
app.info.help = _get_app_help()
|
||||
|
||||
# Update command help texts
|
||||
for cmd_info in app.registered_commands:
|
||||
# cmd_info is a CommandInfo object
|
||||
if hasattr(cmd_info, "name") and cmd_info.name:
|
||||
cmd_info.help = _get_help(cmd_info.name)
|
||||
|
||||
# Update sub-app help texts
|
||||
for group_info in app.registered_groups:
|
||||
if hasattr(group_info, "name") and group_info.name:
|
||||
group_info.help = _get_help(group_info.name)
|
||||
|
||||
|
||||
# Register commands
|
||||
app.command(name="version", help="Show version information")(version.version)
|
||||
app.command(name="run", help="Start model inference server")(run.run)
|
||||
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
|
||||
app.command(name="quant", help="Quantize model weights")(quant.quant)
|
||||
app.command(name="bench", help="Run full benchmark")(bench.bench)
|
||||
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
|
||||
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)
|
||||
|
||||
# Register sub-apps
|
||||
app.add_typer(model.app, name="model", help="Manage models and storage paths")
|
||||
app.add_typer(config.app, name="config", help="Manage configuration")
|
||||
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
|
||||
|
||||
|
||||
def check_first_run() -> None:
|
||||
"""Check if this is the first run and prompt for language setup."""
|
||||
import os
|
||||
|
||||
# Skip if not running in interactive terminal
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
|
||||
from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE
|
||||
|
||||
# Only check if config file exists - don't create it yet
|
||||
if not DEFAULT_CONFIG_FILE.exists():
|
||||
# First run - show welcome and language selection
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
_show_first_run_setup(settings)
|
||||
else:
|
||||
# Config exists - check if initialized
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
if not settings.get("general._initialized"):
|
||||
_show_first_run_setup(settings)
|
||||
|
||||
|
||||
def _show_first_run_setup(settings) -> None:
|
||||
"""Show first-run setup wizard."""
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich.spinner import Spinner
|
||||
from rich.live import Live
|
||||
|
||||
from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location
|
||||
|
||||
console = Console()
|
||||
|
||||
# Welcome message
|
||||
console.print()
|
||||
console.print(
|
||||
Panel.fit(
|
||||
"[bold cyan]Welcome to KTransformers CLI! / 欢迎使用 KTransformers CLI![/bold cyan]\n\n"
|
||||
"Let's set up your preferences.\n"
|
||||
"让我们设置您的偏好。",
|
||||
title="kt-cli",
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
|
||||
# Language selection
|
||||
console.print("[bold]Select your preferred language / 选择您的首选语言:[/bold]")
|
||||
console.print()
|
||||
console.print(" [cyan][1][/cyan] English")
|
||||
console.print(" [cyan][2][/cyan] 中文 (Chinese)")
|
||||
console.print()
|
||||
|
||||
while True:
|
||||
choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")
|
||||
|
||||
if choice == "1":
|
||||
lang = "en"
|
||||
break
|
||||
elif choice == "2":
|
||||
lang = "zh"
|
||||
break
|
||||
|
||||
# Save language setting
|
||||
settings.set("general.language", lang)
|
||||
set_lang(lang)
|
||||
|
||||
# Confirmation message
|
||||
console.print()
|
||||
if lang == "zh":
|
||||
console.print("[green]✓[/green] 语言已设置为中文")
|
||||
else:
|
||||
console.print("[green]✓[/green] Language set to English")
|
||||
|
||||
# Model storage path selection
|
||||
console.print()
|
||||
console.print(f"[bold]{t('setup_model_path_title')}[/bold]")
|
||||
console.print()
|
||||
console.print(f"[dim]{t('setup_model_path_desc')}[/dim]")
|
||||
console.print()
|
||||
|
||||
# Scan for storage locations
|
||||
console.print(f"[dim]{t('setup_scanning_disks')}[/dim]")
|
||||
locations = scan_storage_locations(min_size_gb=50.0)
|
||||
console.print()
|
||||
|
||||
if locations:
|
||||
# Scan for models in each location
|
||||
console.print(f"[dim]{t('setup_scanning_models')}[/dim]")
|
||||
location_models: dict[str, list] = {}
|
||||
for loc in locations[:5]:
|
||||
models = scan_models_in_location(loc, max_depth=2)
|
||||
if models:
|
||||
location_models[loc.path] = models
|
||||
console.print()
|
||||
|
||||
# Show options
|
||||
for i, loc in enumerate(locations[:5], 1): # Show top 5 options
|
||||
available = format_size_gb(loc.available_gb)
|
||||
total = format_size_gb(loc.total_gb)
|
||||
|
||||
# Build the option string
|
||||
if i == 1:
|
||||
option_str = t("setup_disk_option_recommended", path=loc.path, available=available, total=total)
|
||||
else:
|
||||
option_str = t("setup_disk_option", path=loc.path, available=available, total=total)
|
||||
|
||||
# Add model count if any
|
||||
if loc.path in location_models:
|
||||
model_count = len(location_models[loc.path])
|
||||
option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]"
|
||||
|
||||
console.print(f" [cyan][{i}][/cyan] {option_str}")
|
||||
|
||||
# Show first few models found in this location
|
||||
if loc.path in location_models:
|
||||
for model in location_models[loc.path][:3]: # Show up to 3 models
|
||||
size_str = format_size_gb(model.size_gb)
|
||||
console.print(f" [dim]• {model.name} ({size_str})[/dim]")
|
||||
if len(location_models[loc.path]) > 3:
|
||||
remaining = len(location_models[loc.path]) - 3
|
||||
console.print(f" [dim] ... +{remaining} more[/dim]")
|
||||
|
||||
# Custom path option
|
||||
custom_idx = min(len(locations), 5) + 1
|
||||
console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}")
|
||||
console.print()
|
||||
|
||||
valid_choices = [str(i) for i in range(1, custom_idx + 1)]
|
||||
path_choice = Prompt.ask(t("prompt_select"), choices=valid_choices, default="1")
|
||||
|
||||
if path_choice == str(custom_idx):
|
||||
# Custom path
|
||||
selected_path = _prompt_custom_path(console, settings)
|
||||
else:
|
||||
selected_path = locations[int(path_choice) - 1].path
|
||||
else:
|
||||
# No large storage found, ask for custom path
|
||||
console.print(f"[yellow]{t('setup_no_large_disk')}[/yellow]")
|
||||
console.print()
|
||||
selected_path = _prompt_custom_path(console, settings)
|
||||
|
||||
# Ensure the path exists
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
if not os.path.exists(selected_path):
|
||||
if Confirm.ask(t("setup_path_not_exist"), default=True):
|
||||
try:
|
||||
Path(selected_path).mkdir(parents=True, exist_ok=True)
|
||||
except (OSError, PermissionError) as e:
|
||||
console.print(f"[red]{t('error')}: {e}[/red]")
|
||||
# Fall back to default
|
||||
selected_path = str(Path.home() / ".ktransformers" / "models")
|
||||
Path(selected_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check available space and warn if low
|
||||
from kt_kernel.cli.utils.environment import detect_disk_space_gb
|
||||
|
||||
available_gb, _ = detect_disk_space_gb(
|
||||
selected_path if os.path.exists(selected_path) else str(Path(selected_path).parent)
|
||||
)
|
||||
if available_gb < 100:
|
||||
console.print(f"[yellow]{t('setup_path_low_space')}[/yellow]")
|
||||
|
||||
# Save the path
|
||||
settings.set("paths.models", selected_path)
|
||||
settings.set("general._initialized", True)
|
||||
|
||||
console.print()
|
||||
console.print(f"[green]✓[/green] {t('setup_model_path_set', path=selected_path)}")
|
||||
console.print()
|
||||
|
||||
# Tips
|
||||
if lang == "zh":
|
||||
console.print("[dim]提示: 运行 'kt config show' 查看所有配置[/dim]")
|
||||
else:
|
||||
console.print("[dim]Tip: Run 'kt config show' to view all settings[/dim]")
|
||||
|
||||
console.print()
|
||||
|
||||
|
||||
def _prompt_custom_path(console, settings) -> str:
|
||||
"""Prompt user to enter a custom path."""
|
||||
from rich.prompt import Prompt
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
default_path = str(Path.home() / ".ktransformers" / "models")
|
||||
|
||||
while True:
|
||||
custom_path = Prompt.ask(t("setup_enter_custom_path"), default=default_path)
|
||||
|
||||
# Expand user home
|
||||
custom_path = os.path.expanduser(custom_path)
|
||||
|
||||
# Check if path exists or parent is writable
|
||||
if os.path.exists(custom_path):
|
||||
if os.access(custom_path, os.W_OK):
|
||||
return custom_path
|
||||
else:
|
||||
console.print(f"[red]{t('setup_path_no_write')}[/red]")
|
||||
else:
|
||||
# Check if we can create it (parent writable)
|
||||
parent = str(Path(custom_path).parent)
|
||||
while not os.path.exists(parent) and parent != "/":
|
||||
parent = str(Path(parent).parent)
|
||||
|
||||
if os.access(parent, os.W_OK):
|
||||
return custom_path
|
||||
else:
|
||||
console.print(f"[red]{t('setup_path_no_write')}[/red]")
|
||||
|
||||
|
||||
def _install_shell_completion() -> None:
|
||||
"""Install shell completion scripts to user directories.
|
||||
|
||||
Uses standard locations that are auto-loaded by shell completion systems:
|
||||
- Bash: ~/.local/share/bash-completion/completions/kt (auto-loaded by bash-completion 2.0+)
|
||||
- Zsh: ~/.zfunc/_kt (requires fpath setup, but commonly used)
|
||||
- Fish: ~/.config/fish/completions/kt.fish (auto-loaded)
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Check if already installed
|
||||
if settings.get("general._completion_installed", False):
|
||||
return
|
||||
|
||||
# Detect current shell
|
||||
shell = os.environ.get("SHELL", "")
|
||||
if "zsh" in shell:
|
||||
shell_name = "zsh"
|
||||
elif "fish" in shell:
|
||||
shell_name = "fish"
|
||||
else:
|
||||
shell_name = "bash"
|
||||
|
||||
try:
|
||||
cli_dir = Path(__file__).parent
|
||||
completions_dir = cli_dir / "completions"
|
||||
home = Path.home()
|
||||
|
||||
installed = False
|
||||
|
||||
if shell_name == "bash":
|
||||
# Use XDG standard location for bash-completion (auto-loaded)
|
||||
src_file = completions_dir / "kt-completion.bash"
|
||||
dest_dir = home / ".local" / "share" / "bash-completion" / "completions"
|
||||
dest_file = dest_dir / "kt"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
|
||||
elif shell_name == "zsh":
|
||||
src_file = completions_dir / "_kt"
|
||||
dest_dir = home / ".zfunc"
|
||||
dest_file = dest_dir / "_kt"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
|
||||
elif shell_name == "fish":
|
||||
# Fish auto-loads from this directory
|
||||
src_file = completions_dir / "kt.fish"
|
||||
dest_dir = home / ".config" / "fish" / "completions"
|
||||
dest_file = dest_dir / "kt.fish"
|
||||
|
||||
if src_file.exists():
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_file, dest_file)
|
||||
installed = True
|
||||
|
||||
# Mark as installed
|
||||
settings.set("general._completion_installed", True)
|
||||
|
||||
# For bash/zsh, completion will work in new terminals automatically
|
||||
# (bash-completion 2.0+ auto-loads from ~/.local/share/bash-completion/completions/)
|
||||
|
||||
except (OSError, IOError):
|
||||
# Silently ignore errors - completion is not critical
|
||||
pass
|
||||
|
||||
|
||||
def _apply_saved_language() -> None:
|
||||
"""Apply the saved language setting.
|
||||
|
||||
Priority:
|
||||
1. KT_LANG environment variable (if already set, don't override)
|
||||
2. Config file setting
|
||||
3. System locale (auto)
|
||||
"""
|
||||
import os
|
||||
|
||||
# Don't override if KT_LANG is already set by user
|
||||
if os.environ.get("KT_LANG"):
|
||||
return
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
lang = settings.get("general.language", "auto")
|
||||
|
||||
if lang != "auto":
|
||||
set_lang(lang)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
# Apply saved language setting first (before anything else for correct help display)
|
||||
_apply_saved_language()
|
||||
|
||||
# Update help texts based on language
|
||||
_update_help_texts()
|
||||
|
||||
# Check for first run (but not for certain commands)
|
||||
# Skip first-run check for: --help, config commands, version
|
||||
args = sys.argv[1:] if len(sys.argv) > 1 else []
|
||||
skip_commands = ["--help", "-h", "config", "version", "--version"]
|
||||
|
||||
should_check_first_run = True
|
||||
for arg in args:
|
||||
if arg in skip_commands:
|
||||
should_check_first_run = False
|
||||
break
|
||||
|
||||
# Auto-install shell completion on first run
|
||||
if should_check_first_run:
|
||||
_install_shell_completion()
|
||||
|
||||
# Check first run before running commands
|
||||
if should_check_first_run and args:
|
||||
check_first_run()
|
||||
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
kt-kernel/python/cli/requirements/inference.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
# Inference dependencies for KTransformers
# NOTE: sglang is installed separately from source (see install.py)

transformers>=4.45.0
safetensors>=0.4.0
huggingface-hub>=0.20.0
kt-kernel/python/cli/requirements/sft.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
# SFT (Supervised Fine-Tuning) dependencies for KTransformers

llamafactory>=0.9.0
peft>=0.12.0
transformers>=4.45.0
datasets>=2.14.0
accelerate>=0.30.0
kt-kernel/python/cli/utils/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Utility modules for kt-cli.
"""
kt-kernel/python/cli/utils/console.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""
Console utilities for kt-cli.

Provides Rich-based console output helpers for consistent formatting.
"""

from typing import Optional

from rich.console import Console
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
from rich.prompt import Confirm, Prompt
from rich.table import Table
from rich.theme import Theme

from kt_kernel.cli.i18n import t

# Custom theme for kt-cli
KT_THEME = Theme(
    {
        "info": "cyan",
        "warning": "yellow",
        "error": "bold red",
        "success": "bold green",
        "highlight": "bold magenta",
        "muted": "dim",
    }
)

# Global console instance
console = Console(theme=KT_THEME)


def print_info(message: str, **kwargs) -> None:
    """Print an info message."""
    console.print(f"[info]ℹ[/info] {message}", **kwargs)


def print_success(message: str, **kwargs) -> None:
    """Print a success message."""
    console.print(f"[success]✓[/success] {message}", **kwargs)


def print_warning(message: str, **kwargs) -> None:
    """Print a warning message."""
    console.print(f"[warning]⚠[/warning] {message}", **kwargs)


def print_error(message: str, **kwargs) -> None:
    """Print an error message."""
    console.print(f"[error]✗[/error] {message}", **kwargs)


def print_step(message: str, **kwargs) -> None:
    """Print a step indicator."""
    console.print(f"[highlight]→[/highlight] {message}", **kwargs)


def print_header(title: str, subtitle: Optional[str] = None) -> None:
    """Print a header panel."""
    content = f"[bold]{title}[/bold]"
    if subtitle:
        content += f"\n[muted]{subtitle}[/muted]"
    console.print(Panel(content, expand=False))


def print_version_table(versions: dict[str, Optional[str]]) -> None:
    """Print a version information table."""
    table = Table(show_header=False, box=None, padding=(0, 2))
    table.add_column("Component", style="bold")
    table.add_column("Version")

    for name, version in versions.items():
        if version:
            table.add_row(name, f"[success]{version}[/success]")
        else:
            table.add_row(name, f"[muted]{t('version_not_installed')}[/muted]")

    console.print(table)


def print_dependency_table(deps: list[dict]) -> None:
    """Print a dependency status table."""
    table = Table(title=t("install_checking_deps"))
    table.add_column(t("version_info"), style="bold")
    table.add_column("Current")
    table.add_column("Required")
    table.add_column("Status")

    for dep in deps:
        status = dep.get("status", "ok")
        if status == "ok":
            status_str = f"[success]{t('install_dep_ok')}[/success]"
        elif status == "outdated":
            status_str = f"[warning]{t('install_dep_outdated')}[/warning]"
        else:
            status_str = f"[error]{t('install_dep_missing')}[/error]"

        table.add_row(
            dep["name"],
            dep.get("installed", "-"),
            dep.get("required", "-"),
            status_str,
        )

    console.print(table)


def confirm(message: str, default: bool = True) -> bool:
    """Ask for confirmation."""
    return Confirm.ask(message, default=default, console=console)


def prompt_choice(message: str, choices: list[str], default: Optional[str] = None) -> str:
    """Prompt for a choice from a list."""
    # Display numbered choices
    console.print(f"\n[bold]{message}[/bold]")
    for i, choice in enumerate(choices, 1):
        console.print(f" [highlight][{i}][/highlight] {choice}")

    while True:
        response = Prompt.ask(
            "\n" + t("prompt_select"),
            console=console,
            default=str(choices.index(default) + 1) if default else None,
        )
        try:
            idx = int(response) - 1
            if 0 <= idx < len(choices):
                return choices[idx]
        except ValueError:
            # Check if response matches a choice directly
            if response in choices:
                return response

        print_error(f"Please enter a number between 1 and {len(choices)}")


def prompt_text(message: str, default: Optional[str] = None) -> str:
    """Prompt for text input."""
    return Prompt.ask(message, console=console, default=default)


def create_progress() -> Progress:
    """Create a progress bar for general tasks."""
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        console=console,
    )


def create_download_progress() -> Progress:
    """Create a progress bar for downloads."""
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        console=console,
    )


def print_model_table(models: list[dict]) -> None:
    """Print a table of models."""
    table = Table(title=t("download_list_title"))
    table.add_column("Name", style="bold")
    table.add_column("Repository")
    table.add_column("Type")
    table.add_column("Requirements")

    for model in models:
        reqs = []
        if model.get("gpu_vram_gb"):
            reqs.append(f"GPU: {model['gpu_vram_gb']}GB")
        if model.get("cpu_ram_gb"):
            reqs.append(f"RAM: {model['cpu_ram_gb']}GB")

        table.add_row(
            model.get("name", ""),
            model.get("hf_repo", ""),
            model.get("type", ""),
            ", ".join(reqs) if reqs else "-",
        )

    console.print(table)


def print_hardware_info(gpu_info: str, cpu_info: str, ram_info: str) -> None:
    """Print hardware information."""
    table = Table(show_header=False, box=None)
    table.add_column("Icon", width=3)
    table.add_column("Info")

    table.add_row("🖥️", gpu_info)
    table.add_row("💻", cpu_info)
    table.add_row("🧠", ram_info)

    console.print(Panel(table, title="Hardware", expand=False))


def print_server_info(
    mode: str, host: str, port: int, gpu_experts: int, cpu_threads: int
) -> None:
    """Print server startup information."""
    table = Table(show_header=False, box=None)
    table.add_column("Key", style="bold")
    table.add_column("Value")

    table.add_row(t("run_server_mode").split(":")[0], mode)
    table.add_row("Host", host)
    table.add_row("Port", str(port))
    table.add_row(t("run_gpu_experts").split(":")[0], f"{gpu_experts}/layer")
    table.add_row(t("run_cpu_threads").split(":")[0], str(cpu_threads))

    console.print(Panel(table, title=t("run_server_started"), expand=False, border_style="green"))


def print_api_info(host: str, port: int) -> None:
    """Print API endpoint information."""
    api_url = f"http://{host}:{port}"
    docs_url = f"http://{host}:{port}/docs"

    console.print()
    console.print(f" {t('run_api_url', host=host, port=port)}")
    console.print(f" {t('run_docs_url', host=host, port=port)}")
    console.print()
    console.print(f" [muted]Test command:[/muted]")
    console.print(
        f" [dim]curl {api_url}/v1/chat/completions -H 'Content-Type: application/json' "
        f"-d '{{\"model\": \"default\", \"messages\": [{{\"role\": \"user\", \"content\": \"Hello\"}}]}}'[/dim]"
    )
    console.print()
    console.print(f" [muted]{t('run_stop_hint')}[/muted]")
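The response handling in `prompt_choice` above accepts either a 1-based index or the literal choice string. That resolution step can be sketched standalone, without Rich (`resolve_choice` is a hypothetical helper for illustration, not part of kt-cli):

```python
from typing import Optional


def resolve_choice(response: str, choices: list[str]) -> Optional[str]:
    """Resolve a reply as a 1-based index or a literal choice; None if invalid."""
    try:
        idx = int(response) - 1
        if 0 <= idx < len(choices):
            return choices[idx]
    except ValueError:
        # Not a number: accept an exact choice string instead
        if response in choices:
            return response
    return None  # out-of-range index or unknown string


print(resolve_choice("2", ["bash", "zsh", "fish"]))     # -> zsh
print(resolve_choice("fish", ["bash", "zsh", "fish"]))  # -> fish
print(resolve_choice("9", ["bash", "zsh", "fish"]))     # -> None
```

In the real helper a `None`-equivalent result loops back to `print_error` and re-prompts instead of returning.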
kt-kernel/python/cli/utils/environment.py (new file, 1108 lines)
File diff suppressed because it is too large
kt-kernel/python/cli/utils/model_registry.py (new file, 374 lines)
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Model registry for kt-cli.
|
||||
|
||||
Provides a registry of supported models with fuzzy matching capabilities.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from kt_kernel.cli.config.settings import get_settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
"""Information about a supported model."""
|
||||
|
||||
name: str
|
||||
hf_repo: str
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
type: str = "moe" # moe, dense
|
||||
gpu_vram_gb: float = 0
|
||||
cpu_ram_gb: float = 0
|
||||
default_params: dict = field(default_factory=dict)
|
||||
description: str = ""
|
||||
description_zh: str = ""
|
||||
max_tensor_parallel_size: Optional[int] = None # Maximum tensor parallel size for this model
|
||||
|
||||
|
||||
# Built-in model registry
|
||||
BUILTIN_MODELS: list[ModelInfo] = [
|
||||
ModelInfo(
|
||||
name="DeepSeek-V3-0324",
|
||||
hf_repo="deepseek-ai/DeepSeek-V3-0324",
|
||||
aliases=["deepseek-v3-0324", "deepseek-v3", "dsv3", "deepseek3", "v3-0324"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-num-gpu-experts": 1,
|
||||
"attention-backend": "triton",
|
||||
"disable-shared-experts-fusion": True,
|
||||
"kt-method": "AMXINT4",
|
||||
},
|
||||
description="DeepSeek V3-0324 685B MoE model (March 2025, improved benchmarks)",
|
||||
description_zh="DeepSeek V3-0324 685B MoE 模型(2025年3月,改进的基准测试)",
|
||||
),
|
||||
ModelInfo(
|
||||
name="DeepSeek-V3.2",
|
||||
hf_repo="deepseek-ai/DeepSeek-V3.2",
|
||||
aliases=["deepseek-v3.2", "dsv3.2", "deepseek3.2", "v3.2"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-method": "FP8",
|
||||
"kt-gpu-prefill-token-threshold": 4096,
|
||||
"attention-backend": "flashinfer",
|
||||
"fp8-gemm-backend": "triton",
|
||||
"max-total-tokens": 100000,
|
||||
"max-running-requests": 16,
|
||||
"chunked-prefill-size": 32768,
|
||||
"mem-fraction-static": 0.80,
|
||||
"watchdog-timeout": 3000,
|
||||
"served-model-name": "DeepSeek-V3.2",
|
||||
"disable-shared-experts-fusion": True,
|
||||
},
|
||||
description="DeepSeek V3.2 671B MoE model (latest)",
|
||||
description_zh="DeepSeek V3.2 671B MoE 模型(最新)",
|
||||
),
|
||||
ModelInfo(
|
||||
name="DeepSeek-R1-0528",
|
||||
hf_repo="deepseek-ai/DeepSeek-R1-0528",
|
||||
aliases=["deepseek-r1-0528", "deepseek-r1", "dsr1", "r1", "r1-0528"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-num-gpu-experts": 1,
|
||||
"attention-backend": "triton",
|
||||
"disable-shared-experts-fusion": True,
|
||||
"kt-method": "AMXINT4",
|
||||
},
|
||||
description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)",
|
||||
description_zh="DeepSeek R1-0528 推理模型(2025年5月,改进的推理深度)",
|
||||
),
|
||||
ModelInfo(
|
||||
name="Kimi-K2-Thinking",
|
||||
hf_repo="moonshotai/Kimi-K2-Thinking",
|
||||
aliases=["kimi-k2-thinking", "kimi-thinking", "k2-thinking", "kimi", "k2"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-method": "RAWINT4",
|
||||
"kt-gpu-prefill-token-threshold": 400,
|
||||
"attention-backend": "flashinfer",
|
||||
"max-total-tokens": 100000,
|
||||
"max-running-requests": 16,
|
||||
"chunked-prefill-size": 32768,
|
||||
"mem-fraction-static": 0.80,
|
||||
"watchdog-timeout": 3000,
|
||||
"served-model-name": "Kimi-K2-Thinking",
|
||||
"disable-shared-experts-fusion": True,
|
||||
},
|
||||
description="Moonshot Kimi K2 Thinking MoE model",
|
||||
description_zh="月之暗面 Kimi K2 Thinking MoE 模型",
|
||||
),
|
||||
ModelInfo(
|
||||
name="MiniMax-M2",
|
||||
hf_repo="MiniMaxAI/MiniMax-M2",
|
||||
aliases=["minimax-m2", "m2"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-method": "FP8",
|
||||
"kt-gpu-prefill-token-threshold": 4096,
|
||||
"attention-backend": "flashinfer",
|
||||
"fp8-gemm-backend": "triton",
|
||||
"max-total-tokens": 100000,
|
||||
"max-running-requests": 16,
|
||||
"chunked-prefill-size": 32768,
|
||||
"mem-fraction-static": 0.80,
|
||||
"watchdog-timeout": 3000,
|
||||
"served-model-name": "MiniMax-M2",
|
||||
"disable-shared-experts-fusion": True,
|
||||
"tool-call-parser": "minimax-m2",
|
||||
"reasoning-parser": "minimax-append-think",
|
||||
},
|
||||
description="MiniMax M2 MoE model",
|
||||
description_zh="MiniMax M2 MoE 模型",
|
||||
max_tensor_parallel_size=4, # M2 only supports up to 4-way tensor parallelism
|
||||
),
|
||||
ModelInfo(
|
||||
name="MiniMax-M2.1",
|
||||
hf_repo="MiniMaxAI/MiniMax-M2.1",
|
||||
aliases=["minimax-m2.1", "m2.1"],
|
||||
type="moe",
|
||||
default_params={
|
||||
"kt-method": "FP8",
|
||||
"kt-gpu-prefill-token-threshold": 4096,
|
||||
"attention-backend": "flashinfer",
|
||||
"fp8-gemm-backend": "triton",
|
||||
"max-total-tokens": 100000,
|
||||
"max-running-requests": 16,
|
||||
"chunked-prefill-size": 32768,
|
||||
"mem-fraction-static": 0.80,
|
||||
"watchdog-timeout": 3000,
|
||||
"served-model-name": "MiniMax-M2.1",
|
||||
"disable-shared-experts-fusion": True,
|
||||
"tool-call-parser": "minimax-m2",
|
||||
"reasoning-parser": "minimax-append-think",
|
||||
},
|
||||
description="MiniMax M2.1 MoE model (enhanced multi-language programming)",
|
||||
description_zh="MiniMax M2.1 MoE 模型(增强多语言编程能力)",
|
||||
max_tensor_parallel_size=4, # M2.1 only supports up to 4-way tensor parallelism
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Registry of supported models with fuzzy matching."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the model registry."""
|
||||
self._models: dict[str, ModelInfo] = {}
|
||||
self._aliases: dict[str, str] = {}
|
||||
self._load_builtin_models()
|
||||
self._load_user_models()
|
||||
|
||||
def _load_builtin_models(self) -> None:
|
||||
"""Load built-in models."""
|
||||
for model in BUILTIN_MODELS:
|
||||
self._register(model)
|
||||
|
||||
def _load_user_models(self) -> None:
|
||||
"""Load user-defined models from config."""
|
||||
settings = get_settings()
|
||||
        registry_file = settings.config_dir / "registry.yaml"

        if registry_file.exists():
            try:
                with open(registry_file, "r", encoding="utf-8") as f:
                    data = yaml.safe_load(f) or {}

                for name, info in data.get("models", {}).items():
                    model = ModelInfo(
                        name=name,
                        hf_repo=info.get("hf_repo", ""),
                        aliases=info.get("aliases", []),
                        type=info.get("type", "moe"),
                        gpu_vram_gb=info.get("gpu_vram_gb", 0),
                        cpu_ram_gb=info.get("cpu_ram_gb", 0),
                        default_params=info.get("default_params", {}),
                        description=info.get("description", ""),
                        description_zh=info.get("description_zh", ""),
                        max_tensor_parallel_size=info.get("max_tensor_parallel_size"),
                    )
                    self._register(model)
            except (yaml.YAMLError, OSError):
                pass

    def _register(self, model: ModelInfo) -> None:
        """Register a model."""
        self._models[model.name.lower()] = model

        # Register aliases
        for alias in model.aliases:
            self._aliases[alias.lower()] = model.name.lower()

    def get(self, name: str) -> Optional[ModelInfo]:
        """Get a model by exact name or alias."""
        name_lower = name.lower()

        # Check direct match
        if name_lower in self._models:
            return self._models[name_lower]

        # Check aliases
        if name_lower in self._aliases:
            return self._models[self._aliases[name_lower]]

        return None

    def search(self, query: str, limit: int = 10) -> list[ModelInfo]:
        """Search for models using fuzzy matching.

        Args:
            query: Search query
            limit: Maximum number of results

        Returns:
            List of matching models, sorted by relevance
        """
        query_lower = query.lower()
        results: list[tuple[float, ModelInfo]] = []

        for model in self._models.values():
            score = self._match_score(query_lower, model)
            if score > 0:
                results.append((score, model))

        # Sort by score descending
        results.sort(key=lambda x: x[0], reverse=True)

        return [model for _, model in results[:limit]]

    def _match_score(self, query: str, model: ModelInfo) -> float:
        """Calculate match score for a model.

        Returns a score between 0 and 1, where 1 is an exact match.
        """
        # Check exact match
        if query == model.name.lower():
            return 1.0

        # Check alias exact match
        for alias in model.aliases:
            if query == alias.lower():
                return 0.95

        # Check if query is contained in name
        if query in model.name.lower():
            return 0.8

        # Check if query is contained in aliases
        for alias in model.aliases:
            if query in alias.lower():
                return 0.7

        # Check if query is contained in hf_repo
        if query in model.hf_repo.lower():
            return 0.6

        # Fuzzy matching - check if all query parts are present
        query_parts = re.split(r"[-_.\s]", query)
        name_lower = model.name.lower()

        matches = sum(1 for part in query_parts if part and part in name_lower)
        if matches > 0:
            return 0.5 * (matches / len(query_parts))

        return 0.0
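The tiered scoring ladder (exact name 1.0, exact alias 0.95, substring of name 0.8, of alias 0.7, of repo 0.6, then token overlap up to 0.5) can be exercised standalone; a minimal sketch, where the `match_score` helper and the sample model data are illustrative rather than part of the module:

```python
import re

def match_score(query: str, name: str, aliases: list[str], hf_repo: str) -> float:
    """Mirror of the registry's tiered fuzzy-match scoring."""
    query = query.lower()
    if query == name.lower():
        return 1.0                      # exact name
    if any(query == a.lower() for a in aliases):
        return 0.95                     # exact alias
    if query in name.lower():
        return 0.8                      # substring of name
    if any(query in a.lower() for a in aliases):
        return 0.7                      # substring of an alias
    if query in hf_repo.lower():
        return 0.6                      # substring of the HF repo id
    # Token overlap: split on -, _, ., whitespace; score by fraction matched
    parts = [p for p in re.split(r"[-_.\s]", query) if p]
    hits = sum(1 for p in parts if p in name.lower())
    return 0.5 * hits / len(parts) if parts and hits else 0.0

assert match_score("deepseek-v3-0324", "DeepSeek-V3-0324", ["dsv3"], "deepseek-ai/DeepSeek-V3-0324") == 1.0
assert match_score("dsv3", "DeepSeek-V3-0324", ["dsv3"], "deepseek-ai/DeepSeek-V3-0324") == 0.95
# "deepseek v3" is no substring anywhere, but both tokens appear in the name
assert match_score("deepseek v3", "DeepSeek-V3-0324", [], "deepseek-ai/DeepSeek-V3-0324") == 0.5
```

Sorting candidates by this score descending and truncating to `limit` reproduces `search()`.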
    def list_all(self) -> list[ModelInfo]:
        """List all registered models."""
        return list(self._models.values())

    def find_local_models(self) -> list[tuple[ModelInfo, Path]]:
        """Find models that are downloaded locally in any configured model path.

        Returns:
            List of (ModelInfo, path) tuples for local models
        """
        settings = get_settings()
        model_paths = settings.get_model_paths()
        results = []

        for model in self._models.values():
            found = False
            # Search in all configured model directories
            for models_dir in model_paths:
                if not models_dir.exists():
                    continue

                # Check common path patterns
                possible_paths = [
                    models_dir / model.name,
                    models_dir / model.name.lower(),
                    models_dir / model.hf_repo.split("/")[-1],
                    models_dir / model.hf_repo.replace("/", "--"),
                ]

                for path in possible_paths:
                    if path.exists() and (path / "config.json").exists():
                        results.append((model, path))
                        found = True
                        break

                if found:
                    break

        return results


# Global registry instance
_registry: Optional[ModelRegistry] = None


def get_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry
    if _registry is None:
        _registry = ModelRegistry()
    return _registry
# ============================================================================
# Model-specific parameter computation functions
# ============================================================================


def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for DeepSeek V3/R1."""
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram // 3


def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for Kimi K2 Thinking."""
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram * 2 // 3


def compute_minimax_m2_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for MiniMax M2/M2.1."""
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))

    return total_vram


# Model name to computation function mapping
MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
    "DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
    "DeepSeek-V3.2": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts,  # Same as V3-0324
    "Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
    "MiniMax-M2": compute_minimax_m2_gpu_experts,
    "MiniMax-M2.1": compute_minimax_m2_gpu_experts,  # Same as M2
}
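All three heuristics share one shape: reserve 16 GB per GPU for non-expert weights and KV cache, pool the remainder across the tensor-parallel group, then divide by a per-family GB-per-expert ratio. Plugging in numbers makes this concrete; a standalone sketch where the `gpu_experts` helper and its `gb_per_expert_ratio` parameter are illustrative generalizations, not part of the module:

```python
def gpu_experts(tp: int, vram_per_gpu_gb: float, gb_per_expert_ratio: int) -> int:
    # Reserve ~16 GB per GPU for attention weights and KV cache,
    # then spend the pooled remainder on experts pinned to the GPU.
    reserved = 16
    if vram_per_gpu_gb < reserved:
        return 0
    spare = int(tp * (vram_per_gpu_gb - reserved))
    return spare // gb_per_expert_ratio  # divisor differs per model family

# 8 x 80 GB GPUs: 8 * (80 - 16) = 512 GB of spare VRAM
assert gpu_experts(8, 80, 3) == 170   # DeepSeek-V3 heuristic: spare // 3
assert gpu_experts(8, 80, 1) == 512   # MiniMax-M2 heuristic: spare // 1
assert gpu_experts(1, 8, 3) == 0      # below the 16 GB floor -> no GPU experts
```

The Kimi K2 variant uses `spare * 2 // 3` instead of a plain integer divisor, so it does not fit this single-ratio sketch.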
kt-kernel/python/cli/utils/sglang_checker.py (new file, 407 lines)
@@ -0,0 +1,407 @@
"""
SGLang installation checker and installation instructions provider.

This module provides utilities to:
- Check if SGLang is installed and get its metadata
- Provide installation instructions when SGLang is not found
"""

import subprocess
import sys
from pathlib import Path
from typing import Optional

from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console


def check_sglang_installation() -> dict:
    """Check if SGLang is installed and get its metadata.

    Returns:
        dict with keys:
        - installed: bool
        - version: str or None
        - location: str or None (installation path)
        - editable: bool (whether installed in editable mode)
        - git_info: dict or None (git remote and branch if available)
        - from_source: bool (whether installed from source repository)
    """
    try:
        # Try to import sglang
        import sglang

        version = getattr(sglang, "__version__", None)

        # Use pip show to get detailed package information
        location = None
        editable = False
        git_info = None
        from_source = False

        try:
            # Get pip show output
            result = subprocess.run(
                [sys.executable, "-m", "pip", "show", "sglang"],
                capture_output=True,
                text=True,
                timeout=10,
            )

            if result.returncode == 0:
                pip_info = {}
                for line in result.stdout.split("\n"):
                    if ":" in line:
                        key, value = line.split(":", 1)
                        pip_info[key.strip()] = value.strip()

                location = pip_info.get("Location")
                editable_location = pip_info.get("Editable project location")

                if editable_location:
                    editable = True
                    location = editable_location
        except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
            # Fallback to module location
            if hasattr(sglang, "__file__") and sglang.__file__:
                location = str(Path(sglang.__file__).parent.parent)

        # Check if it's installed from source (has .git directory)
        if location:
            git_root = None
            check_path = Path(location)

            # Check current directory and up to 2 parent directories
            for _ in range(3):
                git_dir = check_path / ".git"
                if git_dir.exists():
                    git_root = check_path
                    from_source = True
                    break
                if check_path.parent == check_path:  # Reached root
                    break
                check_path = check_path.parent

            if from_source and git_root:
                # Try to get git remote and branch info
                try:
                    # Get remote URL
                    result = subprocess.run(
                        ["git", "remote", "get-url", "origin"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    remote_url = result.stdout.strip() if result.returncode == 0 else None

                    # Extract org/repo from URL
                    remote_short = None
                    if remote_url:
                        # Handle both https and git@ URLs
                        if "github.com" in remote_url:
                            parts = remote_url.rstrip("/").replace(".git", "").split("github.com")[-1]
                            remote_short = parts.lstrip("/").lstrip(":")

                    # Get current branch
                    result = subprocess.run(
                        ["git", "branch", "--show-current"],
                        cwd=git_root,
                        capture_output=True,
                        text=True,
                        timeout=5,
                    )
                    branch = result.stdout.strip() if result.returncode == 0 else None

                    if remote_url or branch:
                        git_info = {
                            "remote": remote_short or remote_url,
                            "branch": branch,
                        }
                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                    pass

        return {
            "installed": True,
            "version": version,
            "location": location,
            "editable": editable,
            "git_info": git_info,
            "from_source": from_source,
        }
    except ImportError:
        return {
            "installed": False,
            "version": None,
            "location": None,
            "editable": False,
            "git_info": None,
            "from_source": False,
        }
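The `pip show` parsing step can be reproduced in isolation: output is a block of `Key: value` lines, and an editable install exposes an `Editable project location` field that takes precedence over `Location`. A sketch with made-up sample output (the paths and version are illustrative):

```python
# Simulated `pip show sglang` output for an editable install
sample = """Name: sglang
Version: 0.4.0
Location: /opt/venv/lib/python3.11/site-packages
Editable project location: /home/user/sglang/python"""

# Split each line on the first ":" only, so values may contain colons
pip_info = {}
for line in sample.split("\n"):
    if ":" in line:
        key, value = line.split(":", 1)
        pip_info[key.strip()] = value.strip()

editable = "Editable project location" in pip_info
location = pip_info.get("Editable project location") or pip_info.get("Location")

assert editable
assert location == "/home/user/sglang/python"
assert pip_info["Version"] == "0.4.0"
```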
def get_sglang_install_instructions(lang: Optional[str] = None) -> str:
    """Get SGLang installation instructions.

    Args:
        lang: Language code ('en' or 'zh'). If None, uses current language setting.

    Returns:
        Formatted installation instructions string.
    """
    from kt_kernel.cli.i18n import get_lang

    if lang is None:
        lang = get_lang()

    if lang == "zh":
        return """
[bold yellow]SGLang 未安装[/bold yellow]

请按照以下步骤安装 SGLang:

[bold]1. 克隆仓库:[/bold]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[bold]2. 安装 (二选一):[/bold]

   [cyan]方式 A - pip 安装 (推荐):[/cyan]
   pip install -e "python[all]"

   [cyan]方式 B - uv 安装 (更快):[/cyan]
   pip install uv
   uv pip install -e "python[all]"

[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]
"""
    else:
        return """
[bold yellow]SGLang is not installed[/bold yellow]

Please follow these steps to install SGLang:

[bold]1. Clone the repository:[/bold]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[bold]2. Install (choose one):[/bold]

   [cyan]Option A - pip install (recommended):[/cyan]
   pip install -e "python[all]"

   [cyan]Option B - uv install (faster):[/cyan]
   pip install uv
   uv pip install -e "python[all]"

[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""


def print_sglang_install_instructions() -> None:
    """Print SGLang installation instructions to console."""
    instructions = get_sglang_install_instructions()
    console.print(instructions)


def check_sglang_and_warn() -> bool:
    """Check if SGLang is installed, print warning if not.

    Returns:
        True if SGLang is installed, False otherwise.
    """
    info = check_sglang_installation()

    if not info["installed"]:
        print_sglang_install_instructions()
        return False

    # Check if installed from PyPI (not recommended)
    if info["installed"] and not info["from_source"]:
        from kt_kernel.cli.utils.console import print_warning

        print_warning(t("sglang_pypi_warning"))
        console.print()
        console.print("[dim]" + t("sglang_recommend_source") + "[/dim]")
        console.print()

    return True


def _get_sglang_kt_kernel_cache_path() -> Path:
    """Get the path to the sglang kt-kernel support cache file."""
    cache_dir = Path.home() / ".ktransformers" / "cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir / "sglang_kt_kernel_supported"


def _is_sglang_kt_kernel_cache_valid() -> bool:
    """Check if the sglang kt-kernel support cache is valid.

    The cache is considered valid if:
    1. The cache file exists
    2. The cache file contains 'true' (indicating previous check passed)

    Returns:
        True if cache is valid and indicates support, False otherwise.
    """
    cache_path = _get_sglang_kt_kernel_cache_path()
    if cache_path.exists():
        try:
            content = cache_path.read_text().strip().lower()
            return content == "true"
        except (OSError, IOError):
            pass
    return False


def _save_sglang_kt_kernel_cache(supported: bool) -> None:
    """Save the sglang kt-kernel support check result to cache."""
    cache_path = _get_sglang_kt_kernel_cache_path()
    try:
        cache_path.write_text("true" if supported else "false")
    except (OSError, IOError):
        pass  # Ignore cache write errors


def clear_sglang_kt_kernel_cache() -> None:
    """Clear the sglang kt-kernel support cache, forcing a re-check on next run."""
    cache_path = _get_sglang_kt_kernel_cache_path()
    try:
        if cache_path.exists():
            cache_path.unlink()
    except (OSError, IOError):
        pass
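The cache helpers boil down to a tiny read/write-token pattern: a file holding literally "true" or "false", so a slow probe runs at most once. A self-contained sketch using a temp directory instead of `~/.ktransformers/cache`:

```python
import tempfile
from pathlib import Path

# Stand-in for _get_sglang_kt_kernel_cache_path(); real code uses a fixed path
cache = Path(tempfile.mkdtemp()) / "sglang_kt_kernel_supported"

def cache_valid() -> bool:
    """Valid only if the file exists and contains the token 'true'."""
    try:
        return cache.exists() and cache.read_text().strip().lower() == "true"
    except OSError:
        return False

assert not cache_valid()     # no cache yet -> the probe must run
cache.write_text("true")
assert cache_valid()         # subsequent runs skip the probe
cache.write_text("false")
assert not cache_valid()     # a negative result never short-circuits the check
cache.unlink()               # equivalent of clear_sglang_kt_kernel_cache()
assert not cache_valid()
```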
def check_sglang_kt_kernel_support(use_cache: bool = True, silent: bool = False) -> dict:
    """Check if SGLang supports kt-kernel parameters (--kt-gpu-prefill-token-threshold).

    This function runs `python -m sglang.launch_server --help` and checks if the
    output contains the `--kt-gpu-prefill-token-threshold` parameter. This parameter
    is only available in the kvcache-ai/sglang fork, not in the official sglang.

    The result is cached after the first successful check to avoid repeated checks.

    Args:
        use_cache: If True, use cached result if available. Default is True.
        silent: If True, don't print checking message. Default is False.

    Returns:
        dict with keys:
        - supported: bool - True if kt-kernel parameters are supported
        - help_output: str or None - The help output from sglang.launch_server
        - error: str or None - Error message if check failed
        - from_cache: bool - True if result was from cache
    """
    from kt_kernel.cli.utils.console import print_step

    # Check cache first
    if use_cache and _is_sglang_kt_kernel_cache_valid():
        return {
            "supported": True,
            "help_output": None,
            "error": None,
            "from_cache": True,
        }

    # Print checking message
    if not silent:
        print_step(t("sglang_checking_kt_kernel_support"))

    try:
        result = subprocess.run(
            [sys.executable, "-m", "sglang.launch_server", "--help"],
            capture_output=True,
            text=True,
            timeout=30,
        )

        help_output = result.stdout + result.stderr

        # Check if --kt-gpu-prefill-token-threshold is in the help output
        supported = "--kt-gpu-prefill-token-threshold" in help_output

        # Save to cache if supported
        if supported:
            _save_sglang_kt_kernel_cache(True)

        return {
            "supported": supported,
            "help_output": help_output,
            "error": None,
            "from_cache": False,
        }

    except subprocess.TimeoutExpired:
        return {
            "supported": False,
            "help_output": None,
            "error": "Timeout while checking sglang.launch_server --help",
            "from_cache": False,
        }
    except FileNotFoundError:
        return {
            "supported": False,
            "help_output": None,
            "error": "Python interpreter not found",
            "from_cache": False,
        }
    except Exception as e:
        return {
            "supported": False,
            "help_output": None,
            "error": str(e),
            "from_cache": False,
        }


def print_sglang_kt_kernel_instructions() -> None:
    """Print instructions for installing the kvcache-ai fork of SGLang with kt-kernel support."""
    from kt_kernel.cli.i18n import get_lang

    lang = get_lang()

    if lang == "zh":
        instructions = """
[bold red]SGLang 不支持 kt-kernel[/bold red]

您当前安装的 SGLang 不包含 kt-kernel 支持。
kt-kernel 需要使用 kvcache-ai 维护的 SGLang 分支。

[bold]请按以下步骤重新安装 SGLang:[/bold]

[cyan]1. 卸载当前的 SGLang:[/cyan]
   pip uninstall sglang -y

[cyan]2. 克隆 kvcache-ai 的 SGLang 仓库:[/cyan]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[cyan]3. 安装 SGLang:[/cyan]
   pip install -e "python[all]"

[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]
"""
    else:
        instructions = """
[bold red]SGLang does not support kt-kernel[/bold red]

Your current SGLang installation does not include kt-kernel support.
kt-kernel requires the kvcache-ai maintained fork of SGLang.

[bold]Please reinstall SGLang with the following steps:[/bold]

[cyan]1. Uninstall current SGLang:[/cyan]
   pip uninstall sglang -y

[cyan]2. Clone the kvcache-ai SGLang repository:[/cyan]
   git clone https://github.com/kvcache-ai/sglang.git
   cd sglang

[cyan]3. Install SGLang:[/cyan]
   pip install -e "python[all]"

[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""
    console.print(instructions)
@@ -17,7 +17,7 @@ from typing import List, Optional
 from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
 
 # Import backend implementations
-from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
+from .utils.amx import AMXMoEWrapper, NativeMoEWrapper
 from .utils.llamafile import LlamafileMoEWrapper
 from .utils.moe_kernel import GeneralMoEWrapper
 
@@ -77,7 +77,7 @@ class KTMoEWrapper:
            chunked_prefill_size: Maximum prefill chunk size
            cpu_save: Whether to save weights to CPU memory
            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
-           method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
+           method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
 
         Returns:
            An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@@ -85,8 +85,8 @@ class KTMoEWrapper:
         # Select backend based on method
         if method in ["AMXINT4", "AMXINT8"]:
             backend_cls = AMXMoEWrapper
-        elif method == "RAWINT4":
-            backend_cls = RAWAMXMoEWrapper
+        elif method in ["RAWINT4", "FP8"]:
+            backend_cls = NativeMoEWrapper
         elif method == "LLAMAFILE":
             backend_cls = LlamafileMoEWrapper
         elif method in ["MOE_INT4", "MOE_INT8"]:
@@ -4,13 +4,13 @@
 Utilities for kt_kernel package.
 """
 
-from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
+from .amx import AMXMoEWrapper, NativeMoEWrapper
 from .llamafile import LlamafileMoEWrapper
 from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
 
 __all__ = [
     "AMXMoEWrapper",
-    "RAWAMXMoEWrapper",
+    "NativeMoEWrapper",
     "LlamafileMoEWrapper",
     "SafeTensorLoader",
     "CompressedSafeTensorLoader",
 
@@ -4,16 +4,16 @@ import ctypes
 
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
-from .loader import SafeTensorLoader, CompressedSafeTensorLoader
+from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
 from kt_kernel_ext.moe import MOEConfig
 
 try:
-    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
+    from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
 
     _HAS_AMX_SUPPORT = True
 except (ImportError, AttributeError):
     _HAS_AMX_SUPPORT = False
-    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
+    AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
 
 from typing import Optional
 
@@ -303,10 +303,10 @@ class AMXMoEWrapper(BaseMoEWrapper):
         del self.down_scales
 
 
-class RAWAMXMoEWrapper(BaseMoEWrapper):
-    """Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
+class NativeMoEWrapper(BaseMoEWrapper):
+    """Wrapper for RAWINT4/FP8 experts stored in compressed SafeTensor format."""
 
-    _compressed_loader_instance = None
+    _native_loader_instance = None
 
     def __init__(
         self,
@@ -324,8 +324,12 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "RAWINT4",
     ):
-        if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
+        if not _HAS_AMX_SUPPORT:
             raise RuntimeError("AMX backend is not available.")
+        if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
+            raise RuntimeError("AMX backend with RAWINT4 support is not available.")
+        if method == "FP8" and AMXFP8_MOE is None:
+            raise RuntimeError("AMX backend with FP8 support is not available.")
 
         super().__init__(
             layer_idx=layer_idx,
@@ -343,9 +347,14 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
             method=method,
         )
 
-        if RAWAMXMoEWrapper._compressed_loader_instance is None:
-            RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
-        self.loader = RAWAMXMoEWrapper._compressed_loader_instance
+        if NativeMoEWrapper._native_loader_instance is None:
+            if method == "RAWINT4":
+                NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
+            elif method == "FP8":
+                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
+            else:
+                raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
+        self.loader = NativeMoEWrapper._native_loader_instance
 
         self.gate_weights = None
         self.up_weights = None
@@ -378,9 +387,17 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         self.down_weights = weights["down"]
 
-        # Convert scales to bf16 individually
-        self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
-        self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
-        self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+        # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
+        # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
+        # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
+        self.gate_scales = weights["gate_scale"]
+        self.up_scales = weights["up_scale"]
+        self.down_scales = weights["down_scale"]
+        if self.method == "RAWINT4":
+            assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
+        elif self.method == "FP8":
+            assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
 
         t2 = time.time()
 
         # Build pointer lists: [numa_id][expert_id] -> pointer
@@ -404,18 +421,6 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         moe_config.pool = self.cpu_infer.backend_
         moe_config.max_len = self.chunked_prefill_size
 
-        # Infer group_size from scale shape (column-major layout)
-        # For gate/up projection: in_features = hidden_size
-        # So: group_size = hidden_size / scale.shape[1]
-        scale_shape = self.gate_scales[0].shape
-        group_size = self.hidden_size // scale_shape[1]
-        print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}")
-
-        moe_config.quant_config.bits = 4
-        moe_config.quant_config.group_size = group_size
-
-        moe_config.quant_config.zero_point = False
-
         # Use gate_projs instead of gate_proj for per-expert pointers
         moe_config.gate_projs = gate_ptrs
         moe_config.up_projs = up_ptrs
@@ -424,7 +429,21 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         moe_config.up_scales = up_scale_ptrs
         moe_config.down_scales = down_scale_ptrs
 
-        self.moe = AMXInt4_KGroup_MOE(moe_config)
+        # Infer group_size from scale shape (column-major layout)
+        # For gate/up projection: in_features = hidden_size
+        # So: group_size = hidden_size / scale.shape[1]
+
+        if self.method == "RAWINT4":
+            group_size = self.hidden_size // self.gate_scales[0].shape[1]
+            moe_config.quant_config.bits = 4
+            moe_config.quant_config.group_size = group_size
+            moe_config.quant_config.zero_point = False
+            self.moe = AMXInt4_KGroup_MOE(moe_config)
+        elif self.method == "FP8":
+            moe_config.quant_config.bits = 8
+            moe_config.quant_config.group_size = 128
+            moe_config.quant_config.zero_point = False
+            self.moe = AMXFP8_MOE(moe_config)
         t4 = time.time()
 
         self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -440,7 +459,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         t6 = time.time()
 
         print(
-            f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
+            f"[NativeMoEWrapper Layer {self.layer_idx}] "
             f"load_experts: {(t1-t0)*1000:.1f}ms, "
            f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
            f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
@@ -453,7 +472,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
     def submit_write_weight_scale_to_buffer(
         self,
         gpu_tp_count: int,
         gpu_experts_num: int,
         expert_id: int,
         w13_weight_ptrs,
         w13_scale_ptrs,
         w2_weight_ptrs,
@@ -477,7 +496,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
         self.cpu_infer.submit(
             self.moe.write_weight_scale_to_buffer_task(
                 gpu_tp_count,
                 gpu_experts_num,
                 expert_id,
                 w13_weight_ptrs,
                 w13_scale_ptrs,
                 w2_weight_ptrs,
@@ -219,4 +219,4 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
         self.cpu_infer.sync()
 
         # Drop original weights after loading
-        self.weights_to_keep = None
\ No newline at end of file
+        self.weights_to_keep = None
@@ -237,6 +237,117 @@ class SafeTensorLoader:
|
||||
return name in self.tensor_file_map
|
||||
|
||||
|
||||
class FP8SafeTensorLoader(SafeTensorLoader):
|
||||
"""Loader for FP8 expert weights with auto-detection of naming formats.
|
||||
|
||||
Supported formats:
|
||||
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
|
||||
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
|
||||
|
||||
The format is auto-detected during initialization.
|
||||
"""
|
||||
|
||||
# Known MoE naming formats: (experts_path_template, gate_name, up_name, down_name)
|
||||
MOE_FORMATS = {
|
||||
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
|
||||
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
|
||||
}
|
||||
|
||||
def __init__(self, file_path: str):
|
||||
super().__init__(file_path)
|
||||
self._detected_format = None
|
||||
self._detect_format()
|
||||
|
||||
def _detect_format(self):
|
||||
"""Auto-detect the MoE naming format by checking tensor keys."""
|
||||
# Sample some tensor names to detect format
|
||||
sample_keys = list(self.tensor_file_map.keys())[:1000]
|
||||
|
||||
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
|
||||
# Check if any key matches this format pattern
|
||||
# Look for pattern like: model.layers.0.{experts_path}.0.{gate_name}.weight
|
||||
for key in sample_keys:
|
||||
if ".experts." in key and f".{gate}.weight" in key:
|
||||
# Verify the path template matches
|
||||
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
|
||||
self._detected_format = fmt_name
|
||||
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
|
||||
return
|
||||
|
||||
# Default to deepseek if no format detected
|
||||
self._detected_format = "deepseek"
|
||||
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
|
||||
|
||||
def _get_experts_prefix(self, base_key: str) -> str:
|
||||
"""Get the experts prefix based on detected format."""
|
||||
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
|
||||
return path_tpl.format(base=base_key)
|
||||
|
||||
def _get_proj_names(self):
|
||||
"""Get projection names (gate, up, down) based on detected format."""
|
||||
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
|
||||
return gate, up, down
|
||||
|
||||
def load_tensor(self, key: str, device: str = "cpu"):
|
||||
if key not in self.tensor_file_map:
|
||||
raise KeyError(f"Key {key} not found in Safetensor files")
|
||||
file = self.tensor_file_map[key]
|
||||
f = self.file_handle_map.get(file)
|
||||
if f is None:
|
||||
raise FileNotFoundError(f"File {file} not found in Safetensor files")
|
||||
tensor = f.get_tensor(key)
|
||||
if device == "cpu":
|
||||
return tensor
|
||||
return tensor.to(device)
|
||||
|
||||
def load_experts(self, base_key: str, device: str = "cpu"):
    """Load FP8 expert weights and their block-wise scale_inv tensors."""
    experts_prefix = self._get_experts_prefix(base_key)
    gate_name, up_name, down_name = self._get_proj_names()

    expert_count = 0
    while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
        expert_count += 1

    if expert_count == 0:
        raise ValueError(f"No experts found for key {experts_prefix}")

    gate_weights = [None] * expert_count
    up_weights = [None] * expert_count
    down_weights = [None] * expert_count
    gate_scales = [None] * expert_count
    up_scales = [None] * expert_count
    down_scales = [None] * expert_count

    for exp_id in range(expert_count):
        gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
        up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
        down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
        gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
        up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
        down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"

        gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
        up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
        down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
        gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
        up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
        down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()

    return {
        "gate": gate_weights,
        "up": up_weights,
        "down": down_weights,
        "gate_scale": gate_scales,
        "up_scale": up_scales,
        "down_scale": down_scales,
    }


class CompressedSafeTensorLoader(SafeTensorLoader):
    """Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
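For reference, the tensor naming that `load_experts` walks can be sketched in isolation. Only the `.{expert}.{proj}.weight` / `.weight_scale_inv` suffix pattern is taken from the loader above; the prefix string in the usage line is purely illustrative (the real one comes from `_get_experts_prefix`):

```python
def expert_tensor_keys(experts_prefix: str, exp_id: int, proj_name: str):
    """Build the weight / scale_inv key pair for one expert projection."""
    weight_key = f"{experts_prefix}.{exp_id}.{proj_name}.weight"
    scale_key = f"{weight_key}_scale_inv"
    return weight_key, scale_key

# Hypothetical prefix, for illustration only:
w, s = expert_tensor_keys("model.layers.0.mlp.experts", 3, "gate_proj")
```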
@@ -285,9 +285,9 @@ class CMakeBuild(build_ext):

    # Variant configurations: (name, CPUINFER_CPU_INSTRUCT, CPUINFER_ENABLE_AMX)
    variants = [
        ("amx", "AVX512", "ON"),  # AVX512 + AMX
        ("avx512", "AVX512", "OFF"),  # AVX512 only
        ("avx2", "AVX2", "OFF"),  # AVX2 only
    ]

    for variant_name, cpu_instruct, enable_amx in variants:
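Each variant tuple maps directly onto CMake cache flags. A minimal sketch of that mapping (the flag names come from the tuple comment above; the helper function itself is illustrative, not the build script's actual code):

```python
def cmake_flags(variant):
    """Translate a (name, CPUINFER_CPU_INSTRUCT, CPUINFER_ENABLE_AMX)
    tuple into CMake -D cache arguments."""
    _name, cpu_instruct, enable_amx = variant
    return [
        f"-DCPUINFER_CPU_INSTRUCT={cpu_instruct}",
        f"-DCPUINFER_ENABLE_AMX={enable_amx}",
    ]

flags = cmake_flags(("amx", "AVX512", "ON"))
```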
@@ -384,6 +384,7 @@ class CMakeBuild(build_ext):
        build_temp: Temporary build directory for CMake
        cfg: Build type (Release/Debug/etc.)
    """

    # Auto-detect CUDA toolkit if user did not explicitly set CPUINFER_USE_CUDA
    def detect_cuda_toolkit() -> bool:
        # Respect CUDA_HOME
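The hunk above shows only the detector's signature and the "Respect CUDA_HOME" comment. A plausible standalone sketch of such a detector follows; the search order and the `CUDA_PATH` fallback are assumptions, not the repository's actual implementation:

```python
import os
import shutil

def detect_cuda_toolkit() -> bool:
    """Best-effort CUDA toolkit detection: honor CUDA_HOME first,
    then fall back to looking for nvcc on PATH."""
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home and os.path.exists(os.path.join(cuda_home, "bin", "nvcc")):
        return True
    return shutil.which("nvcc") is not None
```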
@@ -614,10 +615,26 @@ setup(
    author="kvcache-ai",
    license="Apache-2.0",
    python_requires=">=3.8",
    packages=["kt_kernel", "kt_kernel.utils"],
    packages=[
        "kt_kernel",
        "kt_kernel.utils",
        "kt_kernel.cli",
        "kt_kernel.cli.commands",
        "kt_kernel.cli.config",
        "kt_kernel.cli.utils",
    ],
    package_dir={
        "kt_kernel": "python",
        "kt_kernel.utils": "python/utils",
        "kt_kernel.cli": "python/cli",
        "kt_kernel.cli.commands": "python/cli/commands",
        "kt_kernel.cli.config": "python/cli/config",
        "kt_kernel.cli.utils": "python/cli/utils",
    },
    entry_points={
        "console_scripts": [
            "kt=kt_kernel.cli.main:main",
        ],
    },
    ext_modules=[CMakeExtension("kt_kernel.kt_kernel_ext", str(REPO_ROOT))],
    cmdclass={"build_ext": CMakeBuild},
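The `entry_points` block wires the `kt` command to `kt_kernel.cli.main:main`. As a shape reference only (an illustrative stub, not the actual CLI dispatch), a console-script entry function takes argv and returns an exit code:

```python
import sys

def main(argv=None):
    """Console-script entry point: parse argv, return an exit code."""
    argv = sys.argv[1:] if argv is None else argv
    if argv and argv[0] == "version":
        print("kt CLI (illustrative stub)")
        return 0
    # Real dispatch (run/chat/model/doctor/config) lives in kt_kernel.cli
    return 1

code = main(["version"])
```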
@@ -17,6 +17,7 @@ register_cpu_ci(est_time=30, suite="default")
# Check if kt_kernel_ext is available
try:
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    HAS_KT_KERNEL = True
except ImportError:

@@ -51,7 +52,7 @@ def test_basic_module_attributes():
        pytest.skip("kt_kernel_ext not built or available")

    # Check for key attributes/functions
    assert hasattr(kt_kernel_ext, 'CPUInfer'), "kt_kernel_ext should have CPUInfer class"
    assert hasattr(kt_kernel_ext, "CPUInfer"), "kt_kernel_ext should have CPUInfer class"


def run_all_tests():
@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    HAS_DEPS = True
except ImportError as e:

@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(
            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
        )
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

@@ -96,9 +95,7 @@ def test_moe_amx_int4_accuracy():
        pytest.skip(f"Dependencies not available: {import_error}")

    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(
        data=range(expert_num), device="cpu", dtype=torch.int64
    ).contiguous()
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

    CPUInfer = kt_kernel_ext.CPUInfer(60)

@@ -133,9 +130,7 @@ def test_moe_amx_int4_accuracy():
    )

    # Create MOE config
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
    )
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
    config.max_len = max_len
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()

@@ -176,14 +171,10 @@ def test_moe_amx_int4_accuracy():
    CPUInfer.sync()

    # Run torch reference
    t_output = moe_torch(
        input_data, expert_ids, weights, gate_proj, up_proj, down_proj
    )
    t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)

    # Calculate relative difference
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
        torch.abs(t_output)
    )
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
    print(f"Iteration {i}, diff = {diff:.6f}")

    # INT4 should have diff < 0.35

@@ -205,6 +196,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
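The accuracy check above uses mean absolute error normalized by the mean magnitude of the reference output. A dependency-free restatement of that metric (plain Python over flat equal-length sequences rather than torch tensors, as a sketch of what the threshold is measuring):

```python
def relative_diff(output, reference):
    """mean(|output - reference|) / mean(|reference|), as in the tests above."""
    n = len(reference)
    mae = sum(abs(o - r) for o, r in zip(output, reference)) / n
    ref_mag = sum(abs(r) for r in reference) / n
    return mae / ref_mag

d = relative_diff([1.0, 2.0, 3.0], [1.0, 2.0, 4.0])
```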
@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    HAS_DEPS = True
except ImportError as e:

@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(
            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
        )
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

@@ -96,9 +95,7 @@ def test_moe_amx_int4_1_accuracy():
        pytest.skip(f"Dependencies not available: {import_error}")

    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(
        data=range(expert_num), device="cpu", dtype=torch.int64
    ).contiguous()
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

    CPUInfer = kt_kernel_ext.CPUInfer(60)

@@ -133,9 +130,7 @@ def test_moe_amx_int4_1_accuracy():
    )

    # Create MOE config
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
    )
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
    config.max_len = max_len
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()

@@ -176,14 +171,10 @@ def test_moe_amx_int4_1_accuracy():
    CPUInfer.sync()

    # Run torch reference
    t_output = moe_torch(
        input_data, expert_ids, weights, gate_proj, up_proj, down_proj
    )
    t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)

    # Calculate relative difference
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
        torch.abs(t_output)
    )
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
    print(f"Iteration {i}, diff = {diff:.6f}")

    # INT4_1 should have diff < 0.35

@@ -205,6 +196,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    HAS_DEPS = True
except ImportError as e:

@@ -69,9 +70,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(
            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
        )
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

@@ -97,9 +96,7 @@ def test_moe_amx_int4_1k_accuracy():
        pytest.skip(f"Dependencies not available: {import_error}")

    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(
        data=range(expert_num), device="cpu", dtype=torch.int64
    ).contiguous()
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

    CPUInfer = kt_kernel_ext.CPUInfer(60)

@@ -134,9 +131,7 @@ def test_moe_amx_int4_1k_accuracy():
    )

    # Create MOE config
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
    )
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
    config.max_len = max_len
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()

@@ -180,14 +175,10 @@ def test_moe_amx_int4_1k_accuracy():
    CPUInfer.sync()

    # Run torch reference
    t_output = moe_torch(
        input_data, expert_ids, weights, gate_proj, up_proj, down_proj
    )
    t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)

    # Calculate relative difference
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
        torch.abs(t_output)
    )
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
    print(f"Iteration {i}, diff = {diff:.6f}")

    # INT4_1K should have diff < 0.35

@@ -209,6 +200,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    HAS_DEPS = True
except ImportError as e:

@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
        if num_tokens == 0:
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
        expert_out = mlp_torch(
            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
        )
        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
        outputs.append(expert_out)
        start_idx = end_idx

@@ -96,9 +95,7 @@ def test_moe_amx_int8_accuracy():
        pytest.skip(f"Dependencies not available: {import_error}")

    global physical_to_logical_map
    physical_to_logical_map = torch.tensor(
        data=range(expert_num), device="cpu", dtype=torch.int64
    ).contiguous()
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

    CPUInfer = kt_kernel_ext.CPUInfer(60)

@@ -133,9 +130,7 @@ def test_moe_amx_int8_accuracy():
    )

    # Create MOE config
    config = kt_kernel_ext.moe.MOEConfig(
        expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
    )
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
    config.max_len = max_len
    config.gate_proj = gate_proj.data_ptr()
    config.up_proj = up_proj.data_ptr()

@@ -174,14 +169,10 @@ def test_moe_amx_int8_accuracy():
    CPUInfer.sync()

    # Run torch reference
    t_output = moe_torch(
        input_data, expert_ids, weights, gate_proj, up_proj, down_proj
    )
    t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)

    # Calculate relative difference
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
        torch.abs(t_output)
    )
    diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
    print(f"Iteration {i}, diff = {diff:.6f}")

    # INT8 should have diff < 0.05

@@ -203,6 +194,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
@@ -24,8 +24,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    from tqdm import tqdm

    HAS_DEPS = True
except ImportError as e:
    HAS_DEPS = False

@@ -306,6 +308,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
@@ -24,6 +24,7 @@ register_cpu_ci(est_time=300, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    from tqdm import tqdm
@@ -25,8 +25,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    from tqdm import tqdm

    HAS_DEPS = True
except ImportError as e:
    HAS_DEPS = False

@@ -156,11 +158,7 @@ def test_moe_amx_int4_1k_benchmark():
    CPUInfer = kt_kernel_ext.CPUInfer(worker_config)

    # Physical to logical map for weight loading
    physical_to_logical_map = torch.tensor(
        data=range(expert_num),
        device="cpu",
        dtype=torch.int64
    ).contiguous()
    physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()

    # Initialize MOE layers
    moes = []

@@ -322,6 +320,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\nTest failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
@@ -24,8 +24,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
    import torch
    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext

    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module
    from tqdm import tqdm

    HAS_DEPS = True
except ImportError as e:
    HAS_DEPS = False

@@ -51,7 +53,6 @@ worker_config_dict = {
CPUINFER_PARAM = 60


def get_git_commit():
    """Get current git commit information."""
    result = {}

@@ -307,6 +308,7 @@ def run_all_tests():
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
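The benchmark records provenance via `get_git_commit()`; the hunk above shows only its docstring and `result = {}`, so the body below is an assumption about how such a helper typically works (field names and git invocation are illustrative):

```python
import subprocess

def get_git_commit():
    """Get current git commit information (sketch; real fields may differ)."""
    result = {}
    try:
        result["commit"] = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except (subprocess.CalledProcessError, FileNotFoundError, OSError):
        result["commit"] = "unknown"  # not inside a git checkout
    return result

info = get_git_commit()
```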