Kt minimax (#1742)

[feat]: fp8 kernel and kt-cli support
This commit is contained in:
ErvinXie
2025-12-24 15:39:44 +08:00
committed by GitHub
parent e7d277d163
commit d8046e1bb4
65 changed files with 12111 additions and 2502 deletions


@@ -16,6 +16,8 @@
KTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](./kt-kernel/) and [kt-sft](./kt-sft/).
## 🔥 Updates
+* **Dec 24, 2025**: Support Native MiniMax-M2.1 inference. ([Tutorial](./doc/en/MiniMax-M2.1-Tutorial.md))
* **Dec 22, 2025**: Support RL-DPO fine-tuning with LLaMA-Factory. ([Tutorial](./doc/en/SFT/DPO_tutorial.md))
* **Dec 5, 2025**: Support Native Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking-Native.md))
* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md))

Binary file not shown (new image, 73 KiB)


@@ -0,0 +1,198 @@
# Running MiniMax-M2.1 with Native Precision using SGLang and KT-Kernel
This tutorial demonstrates how to run MiniMax-M2.1 model inference using SGLang integrated with KT-Kernel. MiniMax-M2.1 provides native FP8 weights, enabling efficient GPU inference with reduced memory footprint while maintaining high accuracy.
## Table of Contents
- [Hardware Requirements](#hardware-requirements)
- [Prerequisites](#prerequisites)
- [Step 1: Download Model Weights](#step-1-download-model-weights)
- [Step 2: Launch Server with KT CLI](#step-2-launch-server-with-kt-cli)
- [Step 3: Send Inference Requests](#step-3-send-inference-requests)
- [Performance](#performance)
- [Troubleshooting](#troubleshooting)
## Hardware Requirements
**Minimum Configuration:**
- **GPU**: NVIDIA RTX 5090 (32 GB), or an equivalent GPU with at least 32 GB of VRAM
- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC)
- **RAM**: At least 256GB system memory
- **Storage**: >220 GB for model weights (GPU and CPU share the same weight directory)
**Tested Configuration:**
- **GPU**: 1 or 2 x NVIDIA GeForce RTX 5090 (32 GB)
- **CPU**: 2 x AMD EPYC 9355 32-Core Processor (128 threads)
- **RAM**: 1TB DDR5 5600MT/s ECC
- **OS**: Linux (Ubuntu 20.04+ recommended)
## Prerequisites
Before starting, ensure you have:
1. **SGLang installed**
Note: Currently, please clone our custom SGLang repository:
```bash
git clone https://github.com/kvcache-ai/sglang.git
cd sglang
pip install -e "python[all]"
```
For general installation guidance, see the [SGLang installation docs](https://docs.sglang.io/get_started/install.html).
2. **KT-Kernel installed**
Please follow [kt-kernel](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md)
After installation, verify the CLI is working:
```bash
kt version
```
3. **CUDA toolkit** - CUDA 12.0+ recommended for FP8 support
4. **Hugging Face CLI** - For downloading models:
```bash
pip install -U huggingface-hub
```
## Step 1: Download Model Weights
Download the official MiniMax-M2.1 weights.
* huggingface: https://huggingface.co/MiniMaxAI/MiniMax-M2.1
```bash
hf download MiniMaxAI/MiniMax-M2.1 --local-dir /path/to/minimax-m2.1
```
## Step 2: Launch Server with KT CLI
The simplest way to start the MiniMax-M2.1 server is using the `kt` CLI:
```bash
kt run m2.1
```
The CLI will automatically detect your hardware configuration and apply optimal parameters for your system.
### Advanced Options
For custom configurations, you can specify additional parameters:
```bash
# Use specific number of GPUs (tensor parallel)
kt run m2.1 --tensor-parallel-size 2
# Custom CPU threads and NUMA configuration
kt run m2.1 --cpu-threads 64 --numa-nodes 2
```
### Dry Run
To preview the command without executing:
```bash
kt run m2.1 --dry-run
```
See [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.
### Key Parameters
| Parameter | Description |
|-----------|-------------|
| `--kt-method FP8` | Enable FP8 inference mode for MiniMax-M2.1 native FP8 weights. |
| `--kt-cpuinfer` | Number of CPU inference threads. Set to physical CPU cores (not hyperthreads). |
| `--kt-threadpool-count` | Number of thread pools. Set to NUMA node count. |
| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. |
| `--chunked-prefill-size` | Maximum tokens per prefill batch. |
| `--max-total-tokens` | Maximum total tokens in KV cache. |
| `--kt-gpu-prefill-token-threshold` | Token threshold for layerwise prefill strategy. |
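Following the guidance in the table, launch flags can be derived mechanically from hardware facts: physical core count for `--kt-cpuinfer` and NUMA node count for `--kt-threadpool-count`. A hypothetical sketch (the helper and its defaults are illustrative, not part of the kt CLI):

```python
# Hypothetical helper: derives a kt launch command from hardware facts,
# following the parameter guidance above. Not part of the kt CLI itself.
def build_launch_args(physical_cores: int, numa_nodes: int, gpu_experts: int = 32) -> list:
    return [
        "kt", "run", "m2.1",
        "--kt-method", "FP8",                      # MiniMax-M2.1 ships native FP8 weights
        "--kt-cpuinfer", str(physical_cores),      # physical cores, not hyperthreads
        "--kt-threadpool-count", str(numa_nodes),  # one pool per NUMA node
        "--kt-num-gpu-experts", str(gpu_experts),
    ]

args = build_launch_args(physical_cores=64, numa_nodes=2)
print(" ".join(args))
```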
## Step 3: Send Inference Requests
Once the server is running (default: `http://localhost:30000`), you can interact with the model in several ways:
### Option A: Interactive Chat with KT CLI
The easiest way to chat with the model:
```bash
kt chat
```
This opens an interactive terminal chat session. Type your messages and press Enter to send. Use `Ctrl+C` to exit.
### Option B: OpenAI-Compatible API
The server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.
**curl example (streaming):**
```bash
curl http://localhost:30000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "MiniMax-M2.1",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true
}'
```
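The same endpoint can be called from Python using only the standard library. A minimal non-streaming sketch (model name and port taken from the curl example above):

```python
import json
from urllib import request

def build_chat_request(prompt: str, base_url: str = "http://localhost:30000/v1") -> request.Request:
    # OpenAI-compatible chat completion payload (non-streaming for simplicity)
    payload = {
        "model": "MiniMax-M2.1",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
    }
    return request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

req = build_chat_request("Hello!")
# To actually send it (requires the server from Step 2 to be running):
#   with request.urlopen(req) as resp:
#       print(json.load(resp)["choices"][0]["message"]["content"])
```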
## Performance
### Throughput (tokens/s)
The following benchmarks were measured with single concurrency (Prefill tps / Decode tps):
| GPU | CPU | PCIe | 2048 tokens | 8192 tokens | 32768 tokens |
|------------|-------------|-------------|-------------|-------------|--------------|
| 1 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 129 / 21.8 | 669 / 20.9 | 1385 / 18.5 |
| 2 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 139 / 23.6 | 1013 / 23.3 | 2269 / 21.6 |
| 1 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 408 / 32.1 | 1196 / 31.4 | 2540 / 27.6 |
| 2 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 414 / 35.9 | 1847 / 35.5 | 4007 / 33.1 |
### Comparison with llama.cpp
We benchmarked KT-Kernel + SGLang against llama.cpp to demonstrate the performance advantages of our CPU-GPU heterogeneous inference approach.
- **Weight formats**: KT-Kernel uses native unquantized FP8 weights from MiniMax-M2, while llama.cpp only supports quantized weights, so we used Q8_0 quantization for the llama.cpp benchmarks.
- **Test environment**: 2 x RTX 5090 (32 GB) with AMD EPYC 9355 CPUs, input tokens=32768, output tokens=512. We made our best effort to optimize llama.cpp performance, but we could not achieve optimal prefill and decode with a single command, so we used separate configurations for prefill and decode measurements.
![Performance Comparison with llama.cpp](../assets/MiniMax-M2_comparison.png)
As shown in the chart, KT-Kernel achieves **more than 4.5x faster prefill** and **up to 30% faster decode** than llama.cpp on the same hardware.
## Troubleshooting
### OOM (Out of Memory) Issues
Layerwise prefill requires extra VRAM (~3.6 GB, plus an incremental cost that grows with prefill length). If you encounter OOM, adjust these parameters when launching the server:
| Parameter | VRAM Impact |
|-----------|-------------|
| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage |
| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |
| `--max-total-tokens` | Reduces KV cache VRAM usage |
**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill.
## Advanced Use Case: Running Claude Code with MiniMax-M2.1 Local Backend
```bash
kt run m2.1 --tool-call-parser minimax-m2 --reasoning-parser minimax-append-think
```
With the above command, you can use [claude-code-router](https://github.com/musistudio/claude-code-router) to connect MiniMax-M2.1 as a local backend for [Claude Code](https://github.com/anthropics/claude-code).
## Additional Resources
- [KT-Kernel Documentation](../../kt-kernel/README.md)
- [SGLang GitHub](https://github.com/sgl-project/sglang)
- [KT-Kernel Parameters Reference](../../kt-kernel/README.md#kt-kernel-parameters)


@@ -38,6 +38,7 @@ High-performance kernel operations for KTransformers, featuring CPU-optimized Mo
- **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights)
- **AMD CPUs with BLIS**: Supported (for int8 prefill & decode)
- **Kimi-K2 Native INT4 (RAWINT4)**: Supported on AVX512 CPUs (CPU-GPU shared INT4 weights) - [Guide](../doc/en/Kimi-K2-Thinking-Native.md)
+- **FP8 weights (e.g., MiniMax-M2.1)**: Supported on AVX512 CPUs (CPU-GPU shared FP8 weights) - [Guide](../doc/en/MiniMax-M2.1-Tutorial.md)
## Features
@@ -167,10 +168,57 @@ Simply run the install script - it will auto-detect your CPU and optimize for be
## Verification
After installation, verify that the CLI is working:
```bash
kt version
```
Expected output:
```
KTransformers CLI v0.x.x
Python: 3.11.x
Platform: Linux 5.15.0-xxx-generic
CUDA: 12.x
kt-kernel: 0.x.x (amx)
sglang: 0.x.x
```
You can also verify the Python module directly:
```bash
python -c "from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')"
```
## KT CLI Overview
The `kt` command-line tool provides a unified interface for running and managing KTransformers models:
| Command | Description |
|---------|-------------|
| `kt run <model>` | Start model inference server with auto-optimized parameters |
| `kt chat` | Interactive chat with a running model server |
| `kt model` | Manage models and storage paths |
| `kt doctor` | Diagnose environment issues and check system compatibility |
| `kt config` | Manage CLI configuration |
| `kt version` | Show version information |
**Quick Start Example:**
```bash
# Start a model server (auto-detects hardware and applies optimal settings)
kt run m2
# In another terminal, chat with the model
kt chat
# Check system compatibility
kt doctor
```
Run `kt --help` for more options, or `kt <command> --help` for command-specific help.
## Integration with SGLang
KT-Kernel can be used standalone via [Direct Python API](#direct-python-api-usage) or integrated with SGLang for production deployment. This section describes SGLang integration to enable CPU-GPU heterogeneous inference, where "hot" experts run on GPU and "cold" experts run on CPU for optimal resource utilization.
@@ -361,13 +409,13 @@ python -m sglang.launch_server \
| Parameter | Description | Example Value |
|-----------|-------------|---------------|
-| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, or `LLAMAFILE` |
+| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8`, or `LLAMAFILE` |
| `--kt-weight-path` | Path to quantized CPU weights | `/path/to/cpu-weights` |
| `--kt-cpuinfer` | Number of CPU inference threads | `64` (adjust based on CPU cores) |
| `--kt-threadpool-count` | Number of thread pools for parallel execution | `2` (typically 1-4) |
| `--kt-num-gpu-experts` | Number of experts to keep on GPU | `32` (remaining experts go to CPU) |
| `--kt-max-deferred-experts-per-token` | Number of experts per token to defer for pipelined execution | `2` (0 to disable, 1-4 recommended) |
-| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (RAWINT4 only) | ~`400` |
+| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (FP8 and RAWINT4 only) | ~`1024` |
**Parameter Guidelines:**
@@ -375,6 +423,7 @@ python -m sglang.launch_server \
- `AMXINT4`: Best performance on AMX CPUs with INT4 quantized weights (may cause a large accuracy drop for some models, e.g., Qwen3-30B-A3B)
- `AMXINT8`: Higher accuracy with INT8 quantized weights on AMX CPUs
- `RAWINT4`: Native INT4 weights shared by CPU and GPU (AMX backend only, currently supports Kimi-K2-Thinking model). See [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details.
+- `FP8`: FP8 weights shared by CPU and GPU
- `LLAMAFILE`: GGUF-based backend
- **`kt-cpuinfer`**: Set to the number of **physical CPU cores** (not hyperthreads).
@@ -400,10 +449,10 @@ python -m sglang.launch_server \
- `1-4`: Deferred execution (recommended range; good latency/quality balance, requires tuning)
- `5-7`: Highest latency reduction but may introduce noticeable accuracy loss; use with care
-- **`kt-gpu-prefill-token-threshold`** (RAWINT4 only): Controls prefill strategy for native INT4 inference:
+- **`kt-gpu-prefill-token-threshold`** (FP8 and RAWINT4 only): Controls prefill strategy for native FP8 and INT4 inference:
  - **≤ threshold**: Uses hybrid CPU+GPU prefill. No extra VRAM needed, but performance degrades slowly as token count increases.
-  - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires ~9GB+ extra VRAM.
-  - Only applicable when `--kt-method RAWINT4` is used. Currently supports Kimi-K2-Thinking model only.
+  - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires one MoE layer of extra VRAM (e.g., ~9 GB+ for Kimi-K2-Thinking and ~3.6 GB for MiniMax-M2.1).
+  - Only applicable when `--kt-method RAWINT4` or `--kt-method FP8` is used.
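The threshold logic above can be sketched as a small decision function (illustrative only; the actual dispatch happens inside kt-kernel):

```python
def choose_prefill_strategy(num_tokens: int, threshold: int = 1024) -> str:
    # <= threshold: hybrid CPU+GPU prefill, no extra VRAM needed
    # >  threshold: layerwise GPU prefill, needs ~one MoE layer of extra VRAM
    return "hybrid" if num_tokens <= threshold else "layerwise"

print(choose_prefill_strategy(512))    # short prompt: hybrid
print(choose_prefill_strategy(32768))  # long prompt: layerwise
```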
## Direct Python API Usage


@@ -0,0 +1,286 @@
"""
Performance benchmark for FP8 MoE kernel (AVX implementation).
This benchmark measures the performance of the FP8 MoE operator with:
- FP8 (E4M3) weights with 128x128 block-wise scaling
- BF16 activations
- AVX-512 DPBF16 compute path
"""
import os
import sys
import time
import json
import subprocess
import platform
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
import torch
import kt_kernel_ext
from tqdm import tqdm
# Test parameters
expert_num = 256
hidden_size = 7168
intermediate_size = 2048
num_experts_per_tok = 8
fp8_group_size = 128
max_len = 25600
layer_num = 2
qlen = 1024
warm_up_iter = 10
test_iter = 30
CPUINFER_PARAM = 80
CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)
# Result file path
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
json_path = os.path.join(script_dir, "bench_results.jsonl")
def get_git_commit():
    """Get current git commit info"""
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        result["dirty"] = bool(dirty_output)
        if dirty_output:
            result["dirty_files"] = dirty_output.splitlines()
    except Exception as e:
        result["commit"] = None
        result["error"] = str(e)
    return result
def get_system_info():
    """Get system information"""
    info = {}
    uname = platform.uname()
    info["system_name"] = uname.system
    info["node_name"] = uname.node
    cpu_model = None
    if os.path.exists("/proc/cpuinfo"):
        try:
            with open("/proc/cpuinfo", "r") as f:
                for line in f:
                    if "model name" in line:
                        cpu_model = line.split(":", 1)[1].strip()
                        break
        except Exception:
            pass
    info["cpu_model"] = cpu_model
    info["cpu_core_count"] = os.cpu_count()
    return info
def record_results(result, filename=json_path):
    """Append result to JSON file"""
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")
def generate_fp8_weights_direct(shape: tuple, group_size: int = 128):
    """
    Directly generate random FP8 weights and e8m0 format scale_inv.

    Args:
        shape: (expert_num, n, k) - weight tensor shape
        group_size: block size for scaling (128x128 blocks)

    Returns:
        fp8_weights: uint8 tensor with random FP8 E4M3 values
        scale_inv: fp32 tensor with e8m0 format (powers of 2)
    """
    e, n, k = shape
    n_blocks = n // group_size
    k_blocks = k // group_size
    # Directly generate random FP8 weights as uint8
    # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa
    # Valid range for normal numbers: exp 1-14 (0 is subnormal, 15 is special)
    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device="cuda").to("cpu").contiguous()
    # Generate e8m0 format scale_inv (powers of 2)
    # e8m0: 8-bit exponent only, no mantissa, bias = 127
    # Generate random exponents in a reasonable range (e.g., -8 to 8)
    exponents = torch.randint(-8, 9, (e, n_blocks, k_blocks), dtype=torch.int32, device="cuda").to("cpu").contiguous()
    scale_inv = (2.0 ** exponents.float()).to(torch.float32).contiguous()
    return fp8_weights, scale_inv
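
# Illustrative helper (not used by the benchmark): decode a single FP8 E4M3
# byte to a Python float, making the "1 sign + 4 exp + 3 mantissa" layout
# above concrete. Bias is 7; exp=0 is subnormal; S.1111.111 encodes NaN.
def decode_fp8_e4m3(byte: int) -> float:
    sign = (byte >> 7) & 0x1
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        val = (mant / 8.0) * 2.0 ** -6  # subnormal
    else:
        val = (1.0 + mant / 8.0) * 2.0 ** (exp - 7)  # normal
    return -val if sign else val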
def bench_fp8_moe():
    """Benchmark FP8 MoE performance"""
    with torch.inference_mode():
        print("=" * 70)
        print("FP8 MoE Kernel Performance Benchmark")
        print("=" * 70)

        # Generate FP8 weights directly (no quantization from fp32)
        print("\nGenerating FP8 weights directly...")
        torch.manual_seed(42)
        gate_fp8, gate_scales = generate_fp8_weights_direct(
            (expert_num, intermediate_size, hidden_size), fp8_group_size
        )
        up_fp8, up_scales = generate_fp8_weights_direct((expert_num, intermediate_size, hidden_size), fp8_group_size)
        down_fp8, down_scales = generate_fp8_weights_direct(
            (expert_num, hidden_size, intermediate_size), fp8_group_size
        )
        physical_to_logical_map = torch.tensor(range(expert_num), device="cpu", dtype=torch.int64).contiguous()

        # Build MoE layers
        print("Building FP8 MoE layers...")
        moes = []
        for _ in tqdm(range(layer_num), desc="Initializing MOEs"):
            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
            config.max_len = max_len
            config.quant_config.bits = 8
            config.quant_config.group_size = fp8_group_size
            config.quant_config.zero_point = False
            config.gate_proj = gate_fp8.data_ptr()
            config.up_proj = up_fp8.data_ptr()
            config.down_proj = down_fp8.data_ptr()
            config.gate_scale = gate_scales.data_ptr()
            config.up_scale = up_scales.data_ptr()
            config.down_scale = down_scales.data_ptr()
            config.pool = CPUInfer.backend_
            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
            CPUInfer.sync()
            moes.append(moe)

        # Generate input data
        print("Generating input data...")
        gen_iter = 1000
        expert_ids = (
            torch.rand(gen_iter * qlen, expert_num, device="cpu")
            .argsort(dim=-1)[:, :num_experts_per_tok]
            .reshape(gen_iter, qlen * num_experts_per_tok)
            .contiguous()
        )
        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device="cpu").contiguous()
        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device="cpu").contiguous()
        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)

        # Warmup
        print(f"Warming up ({warm_up_iter} iterations)...")
        for i in tqdm(range(warm_up_iter), desc="Warm-up"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
        CPUInfer.sync()

        # Benchmark
        print(f"Running benchmark ({test_iter} iterations)...")
        start = time.perf_counter()
        for i in tqdm(range(test_iter), desc="Testing"):
            CPUInfer.submit(
                moes[i % layer_num].forward_task(
                    qlen_tensor.data_ptr(),
                    num_experts_per_tok,
                    expert_ids[i % gen_iter].data_ptr(),
                    weights[i % gen_iter].data_ptr(),
                    input_tensor[i % layer_num].data_ptr(),
                    output_tensor[i % layer_num].data_ptr(),
                    False,
                )
            )
        CPUInfer.sync()
        end = time.perf_counter()
        total_time = end - start

        # Calculate metrics
        time_per_iter_us = total_time / test_iter * 1e6
        # FLOPS calculation:
        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)
        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)
        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix
        flops_per_expert = (
            2 * intermediate_size * hidden_size  # gate
            + 2 * intermediate_size * hidden_size  # up
            + 2 * hidden_size * intermediate_size  # down
        )
        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter
        tflops = total_flops / total_time / 1e12

        # Bandwidth calculation (FP8 = 1 byte per element)
        bytes_per_elem = 1.0
        # Weight memory: gate + up + down per expert; the parenthesized factor is
        # the expected number of distinct experts activated across qlen tokens
        bandwidth = (
            hidden_size
            * intermediate_size
            * 3
            * num_experts_per_tok
            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))
            * bytes_per_elem
            * test_iter
            / total_time
            / 1e9
        )  # GB/s

        # Print results
        print("\n" + "=" * 70)
        print("Benchmark Results")
        print("=" * 70)
        print(f"Quant mode: FP8 (E4M3) with {fp8_group_size}x{fp8_group_size} block scaling")
        print(f"Total time: {total_time:.4f} s")
        print(f"Iterations: {test_iter}")
        print(f"Time per iteration: {time_per_iter_us:.2f} us")
        print(f"Bandwidth: {bandwidth:.2f} GB/s")
        print(f"TFLOPS: {tflops:.4f}")
        print("")

        # Record results
        result = {
            "test_name": os.path.basename(__file__),
            "quant_mode": "fp8_e4m3",
            "total_time_seconds": total_time,
            "iterations": test_iter,
            "time_per_iteration_us": time_per_iter_us,
            "bandwidth_GBs": bandwidth,
            "flops_TFLOPS": tflops,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            "test_parameters": {
                "expert_num": expert_num,
                "hidden_size": hidden_size,
                "intermediate_size": intermediate_size,
                "num_experts_per_tok": num_experts_per_tok,
                "fp8_group_size": fp8_group_size,
                "layer_num": layer_num,
                "qlen": qlen,
                "warm_up_iter": warm_up_iter,
                "test_iter": test_iter,
                "CPUInfer_parameter": CPUINFER_PARAM,
            },
        }
        result.update(get_git_commit())
        result.update(get_system_info())
        record_results(result)
        return tflops, bandwidth


if __name__ == "__main__":
    bench_fp8_moe()


@@ -0,0 +1,294 @@
#!/usr/bin/env python
# coding=utf-8
"""
Benchmark write_weight_scale_to_buffer for AMX_FP8_MOE_TP (FP8 weights + float32 scales).
Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.
"""
import json
import os
import platform
import subprocess
import sys
import time
from tqdm import tqdm
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
from kt_kernel import kt_kernel_ext
from kt_kernel_ext.moe import AMXFP8_MOE
import torch
# Benchmark parameters
expert_num = 256
num_experts_per_tok = 8
gpu_tp_count = 2
warm_up_iter = 3
test_iter = 7
gpu_experts_num = expert_num
hidden_size = 7168
intermediate_size = 2048
group_size = 128 # FP8 uses 128x128 block-wise scales
max_len = 1
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(80)
def get_git_commit():
    result = {}
    try:
        commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
        commit_msg = subprocess.check_output(["git", "log", "-1", "--pretty=%B"]).decode("utf-8").strip()
        result["commit"] = commit
        result["commit_message"] = commit_msg
        dirty_output = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()
        result["dirty"] = bool(dirty_output)
        if dirty_output:
            result["dirty_files"] = dirty_output.splitlines()
    except Exception as e:
        result["error"] = str(e)
    return result
def get_system_info():
    info = {}
    info["system_name"] = platform.uname().system
    info["node_name"] = platform.uname().node
    info["cpu_core_count"] = os.cpu_count()
    if os.path.exists("/proc/cpuinfo"):
        with open("/proc/cpuinfo", "r") as f:
            for line in f:
                if "model name" in line:
                    info["cpu_model"] = line.split(":", 1)[1].strip()
                    break
    if os.path.exists("/proc/meminfo"):
        with open("/proc/meminfo", "r") as f:
            for line in f:
                if "MemTotal" in line:
                    mem_kb = float(line.split(":", 1)[1].split()[0])
                    info["memory_size_GB"] = round(mem_kb / (1024 * 1024), 2)
                    break
    return info
script_path = os.path.abspath(__file__)
script_dir = os.path.dirname(script_path)
script_name = os.path.splitext(os.path.basename(script_path))[0]
json_path = os.path.join(script_dir, script_name + ".jsonl")
def record_results(result, filename=json_path):
    with open(filename, "a") as f:
        f.write(json.dumps(result) + "\n")
def allocate_weights():
    per_mat_weight_bytes = hidden_size * intermediate_size
    n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
    n_blocks_k = (hidden_size + group_size - 1) // group_size
    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k
    per_mat_scale_elems_down = n_blocks_k * n_blocks_n_gate_up
    gate_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    up_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    down_q = (
        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device="cuda")
        .to("cpu")
        .contiguous()
    )
    gate_scale = (
        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )
    up_scale = (
        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )
    down_scale = (
        torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device="cuda").to("cpu").contiguous()
    )
    return (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    )
def build_moe(layer_idx=0):
    """Build a single MOE instance with the given layer_idx."""
    (
        gate_q,
        up_q,
        down_q,
        gate_scale,
        up_scale,
        down_scale,
        per_mat_weight_bytes,
        per_mat_scale_elems_gate_up,
        per_mat_scale_elems_down,
    ) = allocate_weights()
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = max_len
    config.layer_idx = layer_idx
    config.quant_config.bits = 8
    config.quant_config.group_size = group_size
    config.quant_config.zero_point = False
    config.pool = CPUInfer.backend_
    config.gate_proj = gate_q.data_ptr()
    config.up_proj = up_q.data_ptr()
    config.down_proj = down_q.data_ptr()
    config.gate_scale = gate_scale.data_ptr()
    config.up_scale = up_scale.data_ptr()
    config.down_scale = down_scale.data_ptr()
    moe = AMXFP8_MOE(config)
    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    CPUInfer.sync()
    keep_tensors = {
        "gate_q": gate_q,
        "up_q": up_q,
        "down_q": down_q,
        "gate_scale": gate_scale,
        "up_scale": up_scale,
        "down_scale": down_scale,
    }
    buffer_shapes = {
        "per_mat_weight_bytes": per_mat_weight_bytes,
        "per_mat_scale_elems_gate_up": per_mat_scale_elems_gate_up,
        "per_mat_scale_elems_down": per_mat_scale_elems_down,
    }
    return moe, buffer_shapes, keep_tensors
def allocate_buffers(buffer_shapes):
    """Allocate shared output buffers for single expert."""
    per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
    per_mat_scale_elems_gate_up = buffer_shapes["per_mat_scale_elems_gate_up"]
    per_mat_scale_elems_down = buffer_shapes["per_mat_scale_elems_down"]
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count
    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down // gpu_tp_count
    # Each buffer stores data for a single expert
    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w13_scale_bufs = [
        torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)
    ]
    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]
    buffer_ptrs = {
        "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
        "w13_scale_ptrs": [buf.data_ptr() for buf in w13_scale_bufs],
        "w2_weight_ptrs": [buf.data_ptr() for buf in w2_weight_bufs],
        "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
    }
    keep_tensors = {
        "w13_weight_bufs": w13_weight_bufs,
        "w13_scale_bufs": w13_scale_bufs,
        "w2_weight_bufs": w2_weight_bufs,
        "w2_scale_bufs": w2_scale_bufs,
    }
    return buffer_ptrs, keep_tensors
def bench_write_buffer():
    # Build two MOE instances with different layer_idx
    moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)
    moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)
    moes = [moe_0, moe_1]

    # Allocate shared buffers
    buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)
    total_weights = hidden_size * intermediate_size * expert_num * 3
    total_scale_bytes = (
        (buffer_shapes["per_mat_scale_elems_gate_up"] * 2 + buffer_shapes["per_mat_scale_elems_down"]) * expert_num * 4
    )
    bytes_per_call = total_weights + total_scale_bytes

    # Warm-up: alternate between two MOEs
    for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)
                )
        CPUInfer.sync()

    total_time = 0
    for iter_idx in tqdm(range(test_iter), desc="Testing"):
        start = time.perf_counter()
        # Alternate between two MOEs
        for moe_idx, moe in enumerate(moes):
            for expert_id in range(gpu_experts_num):
                CPUInfer.submit(
                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)
                )
        CPUInfer.sync()
        end = time.perf_counter()
        iter_time = end - start
        total_time += iter_time
        print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms")
        time.sleep(0.3)

    # bytes_per_call is for one MOE, we have 2 MOEs
    bytes_per_iter = bytes_per_call * 2
    time_per_iter_ms = total_time / test_iter * 1000
    bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9
    print(f"\n{'='*60}")
    print("FP8 write_weight_scale_to_buffer benchmark (2 MOEs alternating)")
    print(f"{'='*60}")
    print(f"Time per iteration: {time_per_iter_ms:.2f} ms")
    print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
    print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2")
    print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us")
    result = {
        "op": "write_weight_scale_to_buffer_fp8",
        "time_per_iteration_ms": time_per_iter_ms,
        "bandwidth_GBs": bandwidth_gbs,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "test_parameters": {
            "expert_num": expert_num,
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "group_size": group_size,
            "gpu_tp_count": gpu_tp_count,
            "bytes_per_iter": bytes_per_iter,
            "num_moes": 2,
        },
    }
    result.update(get_git_commit())
    result.update(get_system_info())
    record_results(result)


if __name__ == "__main__":
    bench_write_buffer()


@@ -2,6 +2,8 @@
# coding=utf-8
"""
Benchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).
Uses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.
"""
import json
import os
@@ -17,7 +19,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "build"))
from kt_kernel import kt_kernel_ext
import torch
-# Benchmark parameters (single MoE, mirror examples/test_k2_write_buffer.py)
+# Benchmark parameters
expert_num = 384
num_experts_per_tok = expert_num
gpu_tp_count = 4
@@ -33,7 +35,7 @@ group_size = 32
max_len = 1
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
-CPUInfer = kt_kernel_ext.CPUInfer(96)
+CPUInfer = kt_kernel_ext.CPUInfer(80)
def get_git_commit():
@@ -140,7 +142,8 @@ def allocate_weights():
)
-def build_moe():
+def build_moe(layer_idx=0):
+    """Build a single MOE instance with the given layer_idx."""
    (
        gate_q,
        up_q,
@@ -154,6 +157,7 @@ def build_moe():
    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
    config.max_len = max_len
+    config.layer_idx = layer_idx
    config.quant_config.bits = 4
    config.quant_config.group_size = group_size
    config.quant_config.zero_point = False
@@ -170,16 +174,36 @@ def build_moe():
    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
    CPUInfer.sync()
-    # Buffer sizing per TP
+    keep_tensors = {
+        "gate_q": gate_q,
+        "up_q": up_q,
+        "down_q": down_q,
+        "gate_scale": gate_scale,
+        "up_scale": up_scale,
+        "down_scale": down_scale,
+    }
+    buffer_shapes = {
+        "per_mat_weight_bytes": per_mat_weight_bytes,
+        "per_mat_scale_elems": per_mat_scale_elems,
+    }
+    return moe, buffer_shapes, keep_tensors
+
+
+def allocate_buffers(buffer_shapes):
+    """Allocate shared output buffers for single expert."""
+    per_mat_weight_bytes = buffer_shapes["per_mat_weight_bytes"]
+    per_mat_scale_elems = buffer_shapes["per_mat_scale_elems"]
    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
-    total_weight_bytes_per_tp = gpu_experts_num * weight_bytes_per_expert_per_tp
-    total_scale_elems_per_tp = gpu_experts_num * scale_elems_per_expert_per_tp
-    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
-    w13_scale_bufs = [torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
-    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
-    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
+    # Each buffer stores data for a single expert
+    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w13_scale_bufs = [torch.empty(2 * scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
+    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]
+    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]
    buffer_ptrs = {
        "w13_weight_ptrs": [buf.data_ptr() for buf in w13_weight_bufs],
@@ -188,97 +212,89 @@ def build_moe():
        "w2_scale_ptrs": [buf.data_ptr() for buf in w2_scale_bufs],
    }
-    buffer_shapes = {
-        "per_mat_weight_bytes": per_mat_weight_bytes,
-        "per_mat_scale_elems": per_mat_scale_elems,
-        "weight_bytes_per_expert_per_tp": weight_bytes_per_expert_per_tp,
-        "scale_elems_per_expert_per_tp": scale_elems_per_expert_per_tp,
-        "total_weight_bytes_per_tp": total_weight_bytes_per_tp,
-        "total_scale_elems_per_tp": total_scale_elems_per_tp,
-    }
-    keep_tensors = {
-        "gate_q": gate_q,
"up_q": up_q,
"down_q": down_q,
"gate_scale": gate_scale,
"up_scale": up_scale,
"down_scale": down_scale,
"w13_weight_bufs": w13_weight_bufs,
"w13_scale_bufs": w13_scale_bufs,
"w2_weight_bufs": w2_weight_bufs,
"w2_scale_bufs": w2_scale_bufs,
}
return moe, buffer_ptrs, buffer_shapes, keep_tensors
return buffer_ptrs, keep_tensors
def bench_write_buffer():
moe, buffer_ptrs, buffer_shapes, keep_tensors = build_moe()
# Build two MOE instances with different layer_idx
moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)
moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)
moes = [moe_0, moe_1]
# Allocate shared buffers
buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)
total_weights = hidden_size * intermediate_size * expert_num * 3
# Throughput accounting consistent with examples/test_k2_write_buffer.py
bytes_per_call = total_weights // group_size + total_weights // 2
# Throughput accounting: scale bytes (bf16) + weight bytes (int4 packed)
bytes_per_call = total_weights // group_size * 2 + total_weights // 2
# Warm-up
# Warm-up: alternate between two MOEs
for _ in tqdm(range(warm_up_iter), desc="Warm-up"):
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
for moe_idx, moe in enumerate(moes):
for expert_id in range(gpu_experts_num):
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
**buffer_ptrs,
)
)
CPUInfer.sync()
total_time = 0
for _ in tqdm(range(test_iter), desc="Testing"):
for iter_idx in tqdm(range(test_iter), desc="Testing"):
start = time.perf_counter()
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts_num,
**buffer_ptrs,
)
)
CPUInfer.sync()
# Alternate between two MOEs
for moe_idx, moe in enumerate(moes):
for expert_id in range(gpu_experts_num):
CPUInfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
**buffer_ptrs,
)
)
CPUInfer.sync()
end = time.perf_counter()
total_time += end - start
time.sleep(0.6)
print(end - start)
iter_time = end - start
total_time += iter_time
print(f"Iter {iter_idx}: {iter_time*1000:.2f} ms")
time.sleep(0.3)
time_per_iter_us = total_time / test_iter * 1e6
bandwidth_gbs = bytes_per_call * test_iter / total_time / 1e9
# bytes_per_call is for one MOE, we have 2 MOEs
bytes_per_iter = bytes_per_call * 2
time_per_iter_ms = total_time / test_iter * 1000
bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9
print("write_weight_scale_to_buffer benchmark")
print("Time(s): ", total_time)
print("Iteration: ", test_iter)
print("Time(us) per iteration: ", time_per_iter_us)
print("Bandwidth: ", bandwidth_gbs, "GB/s")
print("")
print(f"\n{'='*60}")
print("K2 write_weight_scale_to_buffer benchmark (2 MOEs alternating)")
print(f"{'='*60}")
print(f"Time per iteration: {time_per_iter_ms:.2f} ms")
print(f"Bandwidth: {bandwidth_gbs:.2f} GB/s")
print(f"Experts per MOE: {gpu_experts_num}, MOEs: 2")
print(f"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us")
result = {
"op": "write_weight_scale_to_buffer",
"total_time_seconds": total_time,
"iterations": test_iter,
"time_per_iteration_us": time_per_iter_us,
"op": "write_weight_scale_to_buffer_k2",
"time_per_iteration_ms": time_per_iter_ms,
"bandwidth_GBs": bandwidth_gbs,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"test_parameters": {
"expert_num": expert_num,
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"group_size": group_size,
"max_len": max_len,
"num_experts_per_tok": num_experts_per_tok,
"gpu_tp_count": gpu_tp_count,
"gpu_experts_num": gpu_experts_num,
"warm_up_iter": warm_up_iter,
"test_iter": test_iter,
"bytes_per_call": bytes_per_call,
"bytes_per_iter": bytes_per_iter,
"num_moes": 2,
},
"buffer_shapes": buffer_shapes,
"keep_tensors_alive": list(keep_tensors.keys()),
}
result.update(get_git_commit())
result.update(get_system_info())
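The warm-up / timed-loop / bandwidth accounting pattern used in `bench_write_buffer` can be distilled into a standalone harness sketch; the lambda below is a hypothetical stand-in workload, not the MOE task:

```python
import time

def bench(fn, bytes_per_iter, warmup=3, iters=10):
    """Warm-up then timed loop; returns (ms per iteration, GB/s), as in the script above."""
    for _ in range(warmup):          # warm caches / JIT before measuring
        fn()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - t0
    return elapsed / iters * 1000, bytes_per_iter * iters / elapsed / 1e9

# Hypothetical workload; in the benchmark fn would be a submit()+sync() closure.
ms, gbs = bench(lambda: sum(range(100_000)), bytes_per_iter=8 * 100_000)
assert ms > 0 and gbs > 0
```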


@@ -0,0 +1,457 @@
"""
Test script for GemmKernel224FP8 (FP8 MoE) kernel validation.
This script:
1. Generates random BF16 weights
2. Quantizes them to FP8 format with 128x128 block-wise scales
3. Runs the FP8 MoE kernel
4. Compares results with PyTorch reference using dequantized BF16 weights
FP8 format notes:
- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]
- Scale: FP32, shape [expert_num, n // group_size, k // group_size], group_size=128
"""
import os
import sys
sys.path.insert(0, os.path.dirname(__file__) + "/../build")
import torch
from kt_kernel import kt_kernel_ext
torch.manual_seed(42)
# Model config
hidden_size = 3072
intermediate_size = 1536
max_len = 25600
expert_num = 16
num_experts_per_tok = 8
qlen = 100
layer_num = 1
CPUInfer = kt_kernel_ext.CPUInfer(40)
validation_iter = 1
fp8_group_size = 128 # FP8 uses 128x128 block quantization
debug_print_count = 16
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
def act_fn(x):
"""SiLU activation function"""
return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
"""Reference MLP computation in PyTorch"""
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
"""Reference MoE computation in PyTorch"""
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
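The argsort-based token regrouping inside `moe_torch` can be seen in isolation with a tiny hand-checked example (values worked by hand; `stable=True` pins the order of ties, which the reference code does not require):

```python
import torch

# Sort-by-expert gather as in moe_torch: flatten the [token, top_k] expert
# assignments, sort so each expert's slots are contiguous, then map each
# sorted slot back to the token that produced it.
expert_ids = torch.tensor([[1, 0], [0, 2]])          # 2 tokens, top-2 experts each
idxs = torch.argsort(expert_ids.view(-1), stable=True)
token_of_slot = idxs // expert_ids.shape[1]          # sorted_tokens = input[token_of_slot]

assert expert_ids.view(-1)[idxs].tolist() == [0, 0, 1, 2]  # slots grouped by expert
assert token_of_slot.tolist() == [0, 1, 0, 1]
```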
# FP8 E4M3 constants
FP8_E4M3_MAX = 448.0 # Maximum representable value in FP8 E4M3
def fp8_e4m3_to_float(fp8_val: int) -> float:
"""
Convert FP8 E4M3 value to float.
FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits
"""
sign = (fp8_val >> 7) & 1
exp = (fp8_val >> 3) & 0xF
mant = fp8_val & 0x7
if exp == 0:
# Subnormal or zero
if mant == 0:
return -0.0 if sign else 0.0
# Subnormal: value = (-1)^sign * 2^(-6) * (0.mant)
return ((-1) ** sign) * (2**-6) * (mant / 8.0)
elif exp == 15 and mant == 7:
# NaN (E4M3 has no Inf; only exp=15, mant=111 encodes NaN — S.1111.110 = 448 is finite)
return float("nan")
else:
# Normal: value = (-1)^sign * 2^(exp-7) * (1.mant)
return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)
def float_to_fp8_e4m3(val: float) -> int:
"""
Convert float to FP8 E4M3 value.
"""
if val != val: # NaN
return 0x7F # NaN representation
sign = 1 if val < 0 else 0
val = abs(val)
if val == 0:
return sign << 7
# Clamp to max representable value
val = min(val, FP8_E4M3_MAX)
# Find exponent
import math
if val < 2**-6: # below the smallest normal (2^-6): subnormal, matching the vectorized path
mant = int(round(val / (2**-9)))
if mant > 7: # rounded up into the smallest normal number
return (sign << 7) | (1 << 3)
return (sign << 7) | mant
exp = int(math.floor(math.log2(val))) + 7
exp = max(1, min(exp, 15)) # E4M3 normal exponents run to 15 (exp=15 with mant < 7 is finite)
# Calculate mantissa, handling rounding overflow into the next exponent
mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))
if mant > 7:
mant = 0
exp += 1
if exp > 15 or (exp == 15 and mant > 6):
exp = 15
mant = 6 # saturate at 448.0, the E4M3 maximum (0x7E)
mant = max(0, mant)
return (sign << 7) | (exp << 3) | mant
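As a sanity check on the bit layout, a minimal standalone decoder (following the OCP E4M3 convention, where only exp=15 with mant=111 is NaN) reproduces a few hand-derived encodings:

```python
# Minimal standalone check of the E4M3 bit layout (1 sign, 4 exp, 3 mantissa, bias 7).
# Values re-derived by hand; independent of the helpers above.
def decode_e4m3(b: int) -> float:
    sign, exp, mant = (b >> 7) & 1, (b >> 3) & 0xF, b & 0x7
    if exp == 0:            # subnormal: mant * 2^-9
        val = mant * 2.0**-9
    else:                   # normal: 2^(exp-7) * (1 + mant/8)
        val = 2.0 ** (exp - 7) * (1 + mant / 8.0)
    return -val if sign else val

assert decode_e4m3(0x38) == 1.0       # exp=7, mant=0 -> 2^0
assert decode_e4m3(0x7E) == 448.0     # exp=15, mant=6 -> E4M3 maximum
assert decode_e4m3(0x01) == 2.0**-9   # smallest positive subnormal
assert decode_e4m3(0xB8) == -1.0      # sign bit set
```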
def quantize_to_fp8_blockwise(weights: torch.Tensor, group_size: int = 128):
"""
Quantize BF16/FP32 weights to FP8 with block-wise scaling.
Args:
weights: [expert_num, n, k] tensor in BF16/FP32
group_size: Block size for quantization (default 128 for DeepSeek)
Returns:
fp8_weights: [expert_num, n, k] uint8 tensor
scales: [expert_num, n // group_size, k // group_size] BF16 tensor (scale_inv)
"""
weights_f32 = weights.to(torch.float32)
e, n, k = weights_f32.shape
assert n % group_size == 0, f"n ({n}) must be divisible by group_size ({group_size})"
assert k % group_size == 0, f"k ({k}) must be divisible by group_size ({group_size})"
n_blocks = n // group_size
k_blocks = k // group_size
# Reshape to [e, n_blocks, group_size, k_blocks, group_size]
reshaped = weights_f32.view(e, n_blocks, group_size, k_blocks, group_size)
# Move to [e, n_blocks, k_blocks, group_size, group_size] for block processing
reshaped = reshaped.permute(0, 1, 3, 2, 4)
# Calculate max abs per block
max_abs = reshaped.abs().amax(dim=(-2, -1), keepdim=True)
max_abs = torch.clamp(max_abs, min=1e-12)
# Scale to FP8 range: scale = max_abs / FP8_MAX
# We store scale_inv = scale (for dequantization: fp8 * scale)
scales = (max_abs / FP8_E4M3_MAX).squeeze(-1).squeeze(-1) # [e, n_blocks, k_blocks]
# Quantize: q = round(val / scale)
scaled = reshaped / (scales.unsqueeze(-1).unsqueeze(-1) + 1e-12)
# Convert to FP8 E4M3 using vectorized approach
# Clamp to FP8 representable range
scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
# Simple quantization: round to nearest representable FP8 value
# For simplicity, we use a lookup table approach
fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)
# Vectorized FP8 quantization
sign_mask = (scaled < 0).to(torch.uint8) << 7
abs_scaled = scaled.abs()
# Handle different ranges
# Subnormal: 0 < |x| < 2^-6
subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)
subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)
# Normal values
normal_mask = abs_scaled >= 2**-6
log2_val = torch.log2(abs_scaled.clamp(min=2**-9))
exp = (log2_val.floor() + 7).clamp(1, 15).to(torch.int32) # E4M3 normal exponents run to 15; inputs are pre-clamped to ±448, so mant stays <= 6 at exp=15
mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8)
# Combine
fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)
fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)
# Reshape back to [e, n, k]
fp8_q = fp8_q.permute(0, 1, 3, 2, 4).reshape(e, n, k)
# Scales shape: [e, n_blocks, k_blocks] -> store as [e, n_blocks, k_blocks]
scales_fp32 = scales.to(torch.float32).contiguous()
return fp8_q.contiguous(), scales_fp32
def dequantize_fp8_blockwise(fp8_weights: torch.Tensor, scales: torch.Tensor, group_size: int = 128):
"""
Dequantize FP8 weights back to BF16 for reference computation.
Args:
fp8_weights: [expert_num, n, k] uint8 tensor
scales: [expert_num, n // group_size, k // group_size] BF16 tensor
group_size: Block size
Returns:
dequantized: [expert_num, n, k] BF16 tensor
"""
e, n, k = fp8_weights.shape
n_blocks = n // group_size
k_blocks = k // group_size
# Convert FP8 to float
# Build lookup table for FP8 E4M3 -> float
fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)
# Use lookup table
fp8_float = fp8_lut[fp8_weights.to(torch.int64)]
# Reshape for block-wise scaling
fp8_reshaped = fp8_float.view(e, n_blocks, group_size, k_blocks, group_size)
fp8_reshaped = fp8_reshaped.permute(0, 1, 3, 2, 4) # [e, n_blocks, k_blocks, group_size, group_size]
# Apply scales
scales_f32 = scales.to(torch.float32).unsqueeze(-1).unsqueeze(-1) # [e, n_blocks, k_blocks, 1, 1]
dequantized = fp8_reshaped * scales_f32
# Reshape back
dequantized = dequantized.permute(0, 1, 3, 2, 4).reshape(e, n, k)
return dequantized.to(torch.bfloat16).contiguous()
def build_random_fp8_weights():
"""
Generate random BF16 weights and quantize to FP8.
Returns:
dict with fp8 weights, scales, and original bf16 for reference
"""
torch.manual_seed(42)
# Generate random BF16 weights with small values
gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(
torch.bfloat16
)
# Quantize to FP8
gate_fp8, gate_scales = quantize_to_fp8_blockwise(gate_proj, fp8_group_size)
up_fp8, up_scales = quantize_to_fp8_blockwise(up_proj, fp8_group_size)
down_fp8, down_scales = quantize_to_fp8_blockwise(down_proj, fp8_group_size)
# Dequantize for reference computation
gate_deq = dequantize_fp8_blockwise(gate_fp8, gate_scales, fp8_group_size)
up_deq = dequantize_fp8_blockwise(up_fp8, up_scales, fp8_group_size)
down_deq = dequantize_fp8_blockwise(down_fp8, down_scales, fp8_group_size)
print(f"FP8 weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}")
print(f"Scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}")
# Debug: Print FP8 weight and scale info for expert 0
print("\n=== DEBUG: FP8 Weight and Scale Info (Expert 0) ===")
print(f"gate_fp8[0] first 8x8 block:")
for i in range(8):
print(f" row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
print(f"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}")
print(f"gate_scales[0] first 4x4 block:\n{gate_scales[0, :4, :4]}")
print(f"gate_scales[0] stats: min={gate_scales[0].min()}, max={gate_scales[0].max()}")
print(f"\nup_fp8[0] first 8x8 block:")
for i in range(8):
print(f" row {i}: {up_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
print(f"up_fp8[0] stats: min={up_fp8[0].min()}, max={up_fp8[0].max()}")
print(f"up_scales[0] first 4x4 block:\n{up_scales[0, :4, :4]}")
print(f"up_scales[0] stats: min={up_scales[0].min()}, max={up_scales[0].max()}")
print(f"\ndown_fp8[0] first 8x8 block:")
for i in range(8):
print(f" row {i}: {down_fp8[0, i, :8].numpy().tobytes().hex(' ')}")
print(f"down_fp8[0] stats: min={down_fp8[0].min()}, max={down_fp8[0].max()}")
print(f"down_scales[0] first 4x4 block:\n{down_scales[0, :4, :4]}")
print(f"down_scales[0] stats: min={down_scales[0].min()}, max={down_scales[0].max()}")
return {
"gate_fp8": gate_fp8.contiguous(),
"up_fp8": up_fp8.contiguous(),
"down_fp8": down_fp8.contiguous(),
"gate_scales": gate_scales.contiguous(),
"up_scales": up_scales.contiguous(),
"down_scales": down_scales.contiguous(),
"gate_deq": gate_deq.contiguous(),
"up_deq": up_deq.contiguous(),
"down_deq": down_deq.contiguous(),
}
def build_moes_from_fp8_data(fp8_data: dict):
"""
Build FP8 MoE modules from quantized data.
"""
moes = []
with torch.inference_mode(mode=True):
for _ in range(layer_num):
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.quant_config.bits = 8
config.quant_config.group_size = fp8_group_size
config.quant_config.zero_point = False
# Set FP8 weight pointers
config.gate_proj = fp8_data["gate_fp8"].data_ptr()
config.up_proj = fp8_data["up_fp8"].data_ptr()
config.down_proj = fp8_data["down_fp8"].data_ptr()
# Set scale pointers
config.gate_scale = fp8_data["gate_scales"].data_ptr()
config.up_scale = fp8_data["up_scales"].data_ptr()
config.down_scale = fp8_data["down_scales"].data_ptr()
config.pool = CPUInfer.backend_
moe = kt_kernel_ext.moe.AMXFP8_MOE(config)
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
CPUInfer.sync()
moes.append(moe)
return moes
def run_fp8_moe_test():
"""
Run FP8 MoE validation test.
"""
print("\n" + "=" * 70)
print("FP8 MoE Kernel Validation Test")
print("=" * 70)
# Build FP8 weights
print("\nGenerating and quantizing weights...")
fp8_data = build_random_fp8_weights()
# Build MoE modules
print("\nBuilding FP8 MoE modules...")
moes = build_moes_from_fp8_data(fp8_data)
# Get dequantized weights for reference
gate_deq = fp8_data["gate_deq"]
up_deq = fp8_data["up_deq"]
down_deq = fp8_data["down_deq"]
diffs = []
with torch.inference_mode(mode=True):
for i in range(validation_iter):
torch.manual_seed(100 + i)
bsz_tensor = torch.tensor([qlen], device="cpu")
expert_ids = torch.stack(
[torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]
).contiguous()
weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100
input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5
output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()
moe = moes[i % layer_num]
CPUInfer.submit(
moe.forward_task(
bsz_tensor.data_ptr(),
num_experts_per_tok,
expert_ids.data_ptr(),
weights.data_ptr(),
input_tensor.data_ptr(),
output.data_ptr(),
False,
)
)
CPUInfer.sync()
assert not torch.isnan(output).any(), "NaN values detected in CPU expert output."
assert not torch.isinf(output).any(), "Inf values detected in CPU expert output."
# Reference computation using dequantized weights
t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)
t_output_flat = t_output.flatten()
output_flat = output.flatten()
diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)
diffs.append(diff.item())
print(f"Iteration {i}: relative L1 diff = {diff:.6f}")
if i < 3: # Print detailed output for first few iterations
print(f" kernel output: {output_flat[:debug_print_count]}")
print(f" torch output: {t_output_flat[:debug_print_count]}")
mean_diff = float(sum(diffs) / len(diffs))
max_diff = float(max(diffs))
min_diff = float(min(diffs))
print("\n" + "=" * 70)
print("FP8 MoE Test Results")
print("=" * 70)
print(f"Mean relative L1 diff: {mean_diff*100:.4f}%")
print(f"Max relative L1 diff: {max_diff*100:.4f}%")
print(f"Min relative L1 diff: {min_diff*100:.4f}%")
# Pass/Fail criteria
threshold = 15.0 # 15% relative error threshold for FP8
if mean_diff * 100 < threshold:
print(f"\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold")
else:
print(f"\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold")
return {"mean": mean_diff, "max": max_diff, "min": min_diff}
if __name__ == "__main__":
run_fp8_moe_test()
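The pass/fail metric above is a relative L1 error: mean absolute deviation normalized by the reference magnitude, which makes the threshold independent of the output's scale:

```python
import torch

# Relative L1 error as used for the pass/fail check above.
def rel_l1(out: torch.Tensor, ref: torch.Tensor) -> float:
    return (torch.mean(torch.abs(out - ref)) / (torch.mean(torch.abs(ref)) + 1e-12)).item()

ref = torch.ones(1000)
assert rel_l1(ref, ref) == 0.0
assert abs(rel_l1(ref * 1.01, ref) - 0.01) < 1e-6  # 1% deviation -> 0.01
```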


@@ -0,0 +1,389 @@
import os
import sys
import time
import torch
import numpy as np
from kt_kernel import kt_kernel_ext
CPUInfer = kt_kernel_ext.CPUInfer  # access via attribute; the extension is not a top-level module
AMXFP8_MOE = kt_kernel_ext.moe.AMXFP8_MOE
def make_cpu_infer(thread_num=80):
return CPUInfer(thread_num)
def build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):
cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)
cfg.max_len = 1
cfg.quant_config.bits = 8 # FP8
cfg.quant_config.group_size = group_size
cfg.quant_config.zero_point = False
cfg.pool = cpuinfer.backend_
return cfg
def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
"""Allocate FP8 weights and scales for testing"""
# FP8 weights: 1 byte per element
per_mat_weight_bytes = hidden_size * intermediate_size
# FP8 scales: block-wise (group_size x group_size blocks), stored as float32
n_blocks_n_gate_up = (intermediate_size + group_size - 1) // group_size
n_blocks_k = (hidden_size + group_size - 1) // group_size
per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k
# For down: n=hidden_size, k=intermediate_size
n_blocks_n_down = n_blocks_k
n_blocks_k_down = n_blocks_n_gate_up
per_mat_scale_elems_down = n_blocks_n_down * n_blocks_k_down
gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)
# FP8 scales are float32
gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)
down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)
return (
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems_gate_up,
per_mat_scale_elems_down,
)
def test_with_tp(gpu_tp_count):
"""Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
torch.manual_seed(123)
expert_num = 256 # Reduced for debugging
gpu_experts = expert_num # Number of experts on GPU
num_experts_per_tok = 8
hidden_size = 3072
intermediate_size = 1536 # Changed from 2048 to test non-aligned case
group_size = 128 # FP8 uses 128x128 block-wise scales
cpuinfer = make_cpu_infer()
cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)
(
gate_q,
up_q,
down_q,
gate_scale,
up_scale,
down_scale,
per_mat_weight_bytes,
per_mat_scale_elems_gate_up,
per_mat_scale_elems_down,
) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)
cfg.gate_proj = gate_q.data_ptr()
cfg.up_proj = up_q.data_ptr()
cfg.down_proj = down_q.data_ptr()
cfg.gate_scale = gate_scale.data_ptr()
cfg.up_scale = up_scale.data_ptr()
cfg.down_scale = down_scale.data_ptr()
moe = AMXFP8_MOE(cfg)
physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device="cpu").contiguous()
cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))
cpuinfer.sync()
# TP configuration
# Calculate sizes per TP part (per expert) - must match C++ code which uses div_up
def div_up(a, b):
return (a + b - 1) // b
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
# For W13 (gate/up): n=intermediate_size/gpu_tp, k=hidden_size
gpu_n_w13 = intermediate_size // gpu_tp_count
gpu_k_w13 = hidden_size
scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)
# For W2 (down): n=hidden_size, k=intermediate_size/gpu_tp
gpu_n_w2 = hidden_size
gpu_k_w2 = intermediate_size // gpu_tp_count
scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)
# Total sizes for all gpu_experts
total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp
total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up
total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down
# Create buffer lists for w13 (gate+up) and w2 (down)
# These hold all experts' data for each GPU TP
w13_weight_bufs = []
w13_scale_bufs = []
w2_weight_bufs = []
w2_scale_bufs = []
for tp_idx in range(gpu_tp_count):
# w13 combines gate and up, so needs 2x the size
w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32))
w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32))
print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
print(f"GPU TP count: {gpu_tp_count}")
print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
print(f"Original per matrix scale elements (gate/up): {per_mat_scale_elems_gate_up}")
print(f"Original per matrix scale elements (down): {per_mat_scale_elems_down}")
print(f"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}")
print(f"Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}")
print(f"Scale elements per expert per TP (down): {scale_elems_per_expert_per_tp_down}")
print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
# Helper function to get pointers with expert offset
# write_weights_to_buffer writes one expert at a time, so we need to pass
# pointers that already point to the correct location for each expert
def get_expert_ptrs(expert_id):
w13_weight_ptrs = []
w13_scale_ptrs = []
w2_weight_ptrs = []
w2_scale_ptrs = []
for tp_idx in range(gpu_tp_count):
# Calculate byte offsets for this expert
# w13: gate_weight + up_weight interleaved by expert
# Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up
w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down
w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4) # float32 = 4 bytes
w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4) # float32 = 4 bytes
return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
# Warm up
for i in range(2):
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
# Timing
begin_time = time.perf_counter_ns()
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1000000
# Calculate throughput
total_weights = hidden_size * intermediate_size * gpu_experts * 3
total_scale_bytes = (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4 # float32
total_bytes = total_weights + total_scale_bytes
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
def split_expert_tensor(tensor, chunk):
"""Split tensor by experts"""
return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]
# Split by experts first
gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)
up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)
down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)
gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)
# For down matrix
n_blocks_n = (hidden_size + group_size - 1) // group_size
n_blocks_k = (intermediate_size + group_size - 1) // group_size
n_blocks_k_per_tp = n_blocks_k // gpu_tp_count
# Verify buffers for each TP part
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
expected_w13_scales = []
expected_w2_weights = []
expected_w2_scales = []
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count
# Process each GPU expert
for expert_id in range(gpu_experts):
# For w13 (gate and up), the slicing is along intermediate_size (n direction)
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
# Gate
gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
# Up
up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
# Down matrix needs special handling because it's sliced column-wise
# down is (hidden_size, intermediate_size) in n-major format
down_weight_tp_parts = []
down_scale_tp_parts = []
# Iterate through each row to extract the corresponding parts
for row_idx in range(hidden_size):
row_weight_start = row_idx * intermediate_size
# Direct mapping: each CPU TP corresponds to a GPU TP
tp_slice_weight_size = intermediate_size // gpu_tp_count
tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size
down_weight_tp_parts.append(
down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
)
# For scale: only process at block boundaries
for bn in range(n_blocks_n):
row_scale_start = bn * n_blocks_k
tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp
down_scale_tp_parts.append(
down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]
)
# Concatenate all slices for this TP
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = torch.cat(down_scale_tp_parts)
# Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
expected_w13_scales.append(up_scale_tp)
expected_w2_weights.append(down_weight_tp)
expected_w2_scales.append(down_scale_tp)
# Concatenate all experts for this TP part
expected_w13_weight = torch.cat(expected_w13_weights)
expected_w13_scale = torch.cat(expected_w13_scales)
expected_w2_weight = torch.cat(expected_w2_weights)
expected_w2_scale = torch.cat(expected_w2_scales)
print(f"=== Checking TP part {tp_idx} ===")
print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")
# Assert all checks pass
if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
# Find first mismatch
diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
print(f" w13 weight mismatch at index {first_diff_idx}")
print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")
if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
diff = torch.abs(w13_scale_bufs[tp_idx] - expected_w13_scale)
max_diff_idx = diff.argmax().item()
print(f" w13 scale mismatch, max diff at index {max_diff_idx}")
print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
print(f" expected: {expected_w13_scale[max_diff_idx]}")
raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")
if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
print(f" w2 weight mismatch at index {first_diff_idx}")
print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
diff = torch.abs(w2_scale_bufs[tp_idx] - expected_w2_scale)
max_diff_idx = diff.argmax().item()
print(f" w2 scale mismatch, max diff at index {max_diff_idx}")
print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
print(f" expected: {expected_w2_scale[max_diff_idx]}")
raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")
print(
f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
)
return True
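The column-wise TP slicing that the verification above reconstructs for the down matrix (per row, take the `tp_idx`-th contiguous chunk, then concatenate) can be sketched standalone. This is a minimal illustration with a hypothetical 4x4 row-major matrix and 2 TP parts, not the kernel's actual layout code:

```python
# Sketch (hypothetical sizes): column-wise TP slicing of a flat row-major
# matrix. For each row, take the tp_idx-th contiguous chunk of cols//tp_count
# elements, then concatenate the chunks - mirroring down_weight_tp_parts above.
def slice_down_for_tp(flat, rows, cols, tp_count, tp_idx):
    per_tp = cols // tp_count
    parts = []
    for r in range(rows):
        start = r * cols + tp_idx * per_tp
        parts.append(flat[start:start + per_tp])
    return [x for part in parts for x in part]

flat = list(range(16))  # row-major 4x4: row r is [4r, 4r+1, 4r+2, 4r+3]
print(slice_down_for_tp(flat, 4, 4, 2, 1))  # [2, 3, 6, 7, 10, 11, 14, 15]
```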
def main():
"""Run tests for gpu_tp_count values: 1, 2, 4"""
tp_values = [1, 2, 4]  # TP=8 skipped
all_passed = True
results = {}
print("=" * 60)
print("Testing FP8 write_weight_scale_to_buffer for TP = ", tp_values)
print("=" * 60)
for tp in tp_values:
print(f"\n{'='*60}")
print(f"Testing with gpu_tp_count = {tp}")
print(f"{'='*60}")
try:
test_with_tp(tp)
results[tp] = "PASSED"
print(f"✓ TP={tp} PASSED")
except Exception as e:
results[tp] = f"FAILED: {e}"
all_passed = False
print(f"✗ TP={tp} FAILED: {e}")
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for tp, result in results.items():
status = "✓" if "PASSED" in result else "✗"
print(f" {status} TP={tp}: {result}")
if all_passed:
print("\n✓ ALL TESTS PASSED")
else:
print("\n✗ SOME TESTS FAILED")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -6,11 +6,6 @@ import torch
import numpy as np
# Ensure we can import the local extension
# REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))
# if REPO_ROOT not in sys.path:
# sys.path.insert(0, REPO_ROOT)
from kt_kernel import kt_kernel_ext
from kt_kernel_ext import CPUInfer
@@ -54,12 +49,12 @@ def allocate_weights(expert_num, hidden_size, intermediate_size, group_size):
)
def main():
def test_with_tp(gpu_tp_count):
"""Test write_weight_scale_to_buffer with a specific gpu_tp_count"""
torch.manual_seed(123)
expert_num = 256 # Total experts
expert_num = 8 # Reduced for faster testing
gpu_experts = expert_num # Number of experts on GPU
gpu_tp_count = 2 # Number of TP parts
num_experts_per_tok = 8
hidden_size = 7168
@@ -94,11 +89,7 @@ def main():
cpuinfer.sync()
# TP configuration
# Since weights are col-major, we can directly divide the total size by tp_count
# Each matrix is divided into gpu_tp_count parts in memory order
# Calculate sizes per TP part (direct division since col-major)
# Calculate sizes per TP part (per expert)
weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count
scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count
@@ -107,24 +98,19 @@ def main():
total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp
# Create buffer lists for w13 (gate+up) and w2 (down)
# These hold all experts' data for each GPU TP
w13_weight_bufs = []
w13_scale_bufs = []
w2_weight_bufs = []
w2_scale_bufs = []
for tp_idx in range(gpu_tp_count):
# w13 combines gate and up, so needs 2x the size
# w13 combines gate and up, so needs 2x the size per expert
w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))
w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))
w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))
# Get data pointers for all buffers
w13_weight_ptrs = [buf.data_ptr() for buf in w13_weight_bufs]
w13_scale_ptrs = [buf.data_ptr() for buf in w13_scale_bufs]
w2_weight_ptrs = [buf.data_ptr() for buf in w2_weight_bufs]
w2_scale_ptrs = [buf.data_ptr() for buf in w2_scale_bufs]
print(f"Total experts: {expert_num}, GPU experts: {gpu_experts}")
print(f"GPU TP count: {gpu_tp_count}")
print(f"Original per matrix weight bytes: {per_mat_weight_bytes}")
@@ -133,14 +119,56 @@ def main():
print(f"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}")
print(f"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}")
print(f"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}")
print(f"Total scale elements per TP (w13): {2 * total_scale_elems_per_tp}")
print(f"Total scale elements per TP (w2): {total_scale_elems_per_tp}")
for i in range(5):
# Helper function to get pointers with expert offset
# K2 write_weight_scale_to_buffer writes one expert at a time, so we need to pass
# pointers that already point to the correct location for each expert
def get_expert_ptrs(expert_id):
w13_weight_ptrs = []
w13_scale_ptrs = []
w2_weight_ptrs = []
w2_scale_ptrs = []
for tp_idx in range(gpu_tp_count):
# Calculate byte offsets for this expert
# w13: gate_weight + up_weight interleaved by expert
# Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]
w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp
w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp
w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp
w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp
w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 2) # bf16 = 2 bytes
w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 2) # bf16 = 2 bytes
return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs
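The offset arithmetic in `get_expert_ptrs` can be sanity-checked in isolation. The sketch below uses hypothetical per-expert sizes (`W` weight bytes and `S` bf16 scale elements per matrix per TP part) and returns byte offsets only; the bf16 factor of 2 and the gate/up interleaving factor of 2 mirror the helper above:

```python
# Sketch (hypothetical sizes): per-expert byte offsets into the TP buffers.
# w13 interleaves gate and up per expert (factor 2); scales are bf16 (2 bytes).
def expert_offsets(expert_id, W, S):
    w13_weight_off = expert_id * 2 * W        # [gate, up] pair per expert
    w13_scale_off = expert_id * 2 * S * 2     # scale elements -> bytes (bf16)
    w2_weight_off = expert_id * W
    w2_scale_off = expert_id * S * 2
    return w13_weight_off, w13_scale_off, w2_weight_off, w2_scale_off

print(expert_offsets(0, 1024, 64))  # (0, 0, 0, 0)
print(expert_offsets(1, 1024, 64))  # (2048, 256, 1024, 128)
```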
# Warm up
for i in range(2):
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
# Timing
begin_time = time.perf_counter_ns()
for expert_id in range(gpu_experts):
w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
expert_id=expert_id,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
@@ -148,23 +176,10 @@ def main():
)
)
cpuinfer.sync()
begin_time = time.perf_counter_ns()
cpuinfer.submit(
moe.write_weight_scale_to_buffer_task(
gpu_tp_count=gpu_tp_count,
gpu_experts_num=gpu_experts,
w13_weight_ptrs=w13_weight_ptrs,
w13_scale_ptrs=w13_scale_ptrs,
w2_weight_ptrs=w2_weight_ptrs,
w2_scale_ptrs=w2_scale_ptrs,
)
)
cpuinfer.sync()
end_time = time.perf_counter_ns()
elapsed_ms = (end_time - begin_time) / 1000000
total_weights = hidden_size * intermediate_size * expert_num * 3
total_bytes = total_weights // group_size + total_weights // 2
total_weights = hidden_size * intermediate_size * gpu_experts * 3
total_bytes = total_weights // group_size * 2 + total_weights // 2 # scale (bf16) + weight (int4)
print(f"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms")
print(f"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s")
@@ -181,9 +196,6 @@ def main():
up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)
down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)
# CPU TP count is always 2 in this test setup (one TP per NUMA node)
cpu_tp_count = 2
# Verify buffers for each TP part
for tp_idx in range(gpu_tp_count):
expected_w13_weights = []
@@ -193,22 +205,22 @@ def main():
weight13_per_tp = per_mat_weight_bytes // gpu_tp_count
scale13_per_tp = per_mat_scale_elems // gpu_tp_count
# Process each GPU expert
for expert_idx in range(gpu_experts):
# For w13 (gate and up), the slicing is straightforward
# Process each GPU expert
for expert_id in range(gpu_experts):
# For w13 (gate and up), the slicing is straightforward
start_weight = tp_idx * weight13_per_tp
end_weight = (tp_idx + 1) * weight13_per_tp
start_scale = tp_idx * scale13_per_tp
end_scale = (tp_idx + 1) * scale13_per_tp
# Gate
gate_weight_tp = gate_q_experts[expert_idx][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_idx][start_scale:end_scale]
gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]
gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]
# Up
up_weight_tp = up_q_experts[expert_idx][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_idx][start_scale:end_scale]
up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]
up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]
# Down matrix needs special handling because it's sliced column-wise
# We need to reconstruct it from column slices
@@ -228,16 +240,17 @@ def main():
tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size
down_weight_tp_parts.append(
down_q_experts[expert_idx][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]
)
down_scale_tp_parts.append(
down_scale_experts[expert_idx][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]
)
# Concatenate all column slices for this TP
down_weight_tp = torch.cat(down_weight_tp_parts)
down_scale_tp = torch.cat(down_scale_tp_parts)
# Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]
expected_w13_weights.append(gate_weight_tp)
expected_w13_weights.append(up_weight_tp)
expected_w13_scales.append(gate_scale_tp)
@@ -252,16 +265,85 @@ def main():
expected_w2_scale = torch.cat(expected_w2_scales)
print(f"=== Checking TP part {tp_idx} ===")
print(f" w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}")
print(f" w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, expected={expected_w13_scale.shape}")
print(f" w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}")
print(f" w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}")
# Assert all checks pass
assert torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight), f"w13 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale), f"w13 scale values mismatch for TP {tp_idx}"
assert torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight), f"w2 weight bytes mismatch for TP {tp_idx}"
assert torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale), f"w2 scale values mismatch for TP {tp_idx}"
if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):
diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
print(f" w13 weight mismatch at index {first_diff_idx}")
print(f" actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
print(f" expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}")
raise AssertionError(f"w13 weight bytes mismatch for TP {tp_idx}")
if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):
diff = torch.abs(w13_scale_bufs[tp_idx].float() - expected_w13_scale.float())
max_diff_idx = diff.argmax().item()
print(f" w13 scale mismatch, max diff at index {max_diff_idx}")
print(f" actual: {w13_scale_bufs[tp_idx][max_diff_idx]}")
print(f" expected: {expected_w13_scale[max_diff_idx]}")
raise AssertionError(f"w13 scale values mismatch for TP {tp_idx}")
if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):
diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight
first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1
print(f" w2 weight mismatch at index {first_diff_idx}")
print(f" actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}")
print(f" expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}")
raise AssertionError(f"w2 weight bytes mismatch for TP {tp_idx}")
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):
diff = torch.abs(w2_scale_bufs[tp_idx].float() - expected_w2_scale.float())
max_diff_idx = diff.argmax().item()
print(f" w2 scale mismatch, max diff at index {max_diff_idx}")
print(f" actual: {w2_scale_bufs[tp_idx][max_diff_idx]}")
print(f" expected: {expected_w2_scale[max_diff_idx]}")
raise AssertionError(f"w2 scale values mismatch for TP {tp_idx}")
print(
f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts from total {expert_num} experts"
f"\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts"
)
return True
def main():
"""Run tests for all gpu_tp_count values: 1, 2, 4, 8"""
tp_values = [1, 2, 4, 8]
all_passed = True
results = {}
print("=" * 60)
print("Testing K2 write_weight_scale_to_buffer for TP = 1, 2, 4, 8")
print("=" * 60)
for tp in tp_values:
print(f"\n{'='*60}")
print(f"Testing with gpu_tp_count = {tp}")
print(f"{'='*60}")
try:
test_with_tp(tp)
results[tp] = "PASSED"
print(f"✓ TP={tp} PASSED")
except Exception as e:
results[tp] = f"FAILED: {e}"
all_passed = False
print(f"✗ TP={tp} FAILED: {e}")
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for tp, result in results.items():
status = "✓" if "PASSED" in result else "✗"
print(f" {status} TP={tp}: {result}")
if all_passed:
print("\n✓ ALL TESTS PASSED")
else:
print("\n✗ SOME TESTS FAILED")
sys.exit(1)
if __name__ == "__main__":

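The throughput figure printed by the K2 timing loop divides a byte count by the elapsed time. The accounting behind `total_bytes` (one bf16 scale per `group_size` int4 weights, three matrices per expert) can be checked in isolation; the sizes below are hypothetical:

```python
# Sketch of the byte accounting behind the throughput print (hypothetical
# sizes). Packed int4 weights cost half a byte per element; each group of
# group_size elements shares one bf16 scale (2 bytes); gate, up and down
# contribute the factor of 3.
def moe_copy_bytes(hidden_size, intermediate_size, gpu_experts, group_size):
    total_weights = hidden_size * intermediate_size * gpu_experts * 3
    scale_bytes = total_weights // group_size * 2  # bf16 scales
    weight_bytes = total_weights // 2              # packed int4 weights
    return scale_bytes + weight_bytes

print(moe_copy_bytes(7168, 2048, 8, 32))  # bytes moved per full buffer copy
```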
View File

@@ -36,6 +36,7 @@ static const bool _is_plain_ = false;
#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)
#include "operators/amx/awq-moe.hpp"
#include "operators/amx/fp8-moe.hpp"
#include "operators/amx/k2-moe.hpp"
#include "operators/amx/la/amx_kernels.hpp"
#include "operators/amx/moe.hpp"
@@ -255,7 +256,7 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int gpu_experts_num;
int expert_id;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
@@ -265,12 +266,12 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
args_->gpu_experts_num, args_->w13_weight_ptrs, args_->w13_scale_ptrs,
args_->w2_weight_ptrs, args_->w2_scale_ptrs);
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
int gpu_experts_num, py::list w13_weight_ptrs,
int expert_id, py::list w13_weight_ptrs,
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
@@ -281,15 +282,59 @@ void bind_moe_module(py::module_& moe_module, const char* name) {
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, gpu_experts_num,
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("gpu_experts_num"), py::arg("w13_weight_ptrs"),
py::arg("w13_scale_ptrs"), py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
// FP8 MoE: processes one expert at a time (expert_id instead of gpu_experts_num)
if constexpr (std::is_same_v<MoeTP, AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>) {
struct WriteWeightScaleToBufferBindings {
struct Args {
CPUInfer* cpuinfer;
MoeClass* moe;
int gpu_tp_count;
int expert_id;
std::vector<uintptr_t> w13_weight_ptrs;
std::vector<uintptr_t> w13_scale_ptrs;
std::vector<uintptr_t> w2_weight_ptrs;
std::vector<uintptr_t> w2_scale_ptrs;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,
args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,
args_->w2_scale_ptrs);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,
int expert_id, py::list w13_weight_ptrs,
py::list w13_scale_ptrs, py::list w2_weight_ptrs,
py::list w2_scale_ptrs) {
// Convert Python lists to std::vector<uintptr_t>
std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;
for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));
for (auto item : w2_scale_ptrs) w2_scale_vec.push_back(py::cast<uintptr_t>(item));
Args* args = new Args{nullptr, moe.get(), gpu_tp_count, expert_id,
w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
moe_cls.def("write_weight_scale_to_buffer_task", &WriteWeightScaleToBufferBindings::cpuinfer_interface,
py::arg("gpu_tp_count"), py::arg("expert_id"), py::arg("w13_weight_ptrs"), py::arg("w13_scale_ptrs"),
py::arg("w2_weight_ptrs"), py::arg("w2_scale_ptrs"));
}
#endif
}
@@ -562,6 +607,7 @@ PYBIND11_MODULE(kt_kernel_ext, m) {
bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, "AMXInt4_1_MOE");
bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, "AMXInt4_1KGroup_MOE");
bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, "AMXInt4_KGroup_MOE");
bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, "AMXFP8_MOE");
#endif
#if defined(USE_MOE_KERNEL)
bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, "Int8_KERNEL_MOE");

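The bindings above ferry raw integer addresses (a `py::list` of `data_ptr()` values, possibly plus a byte offset) into `std::vector<uintptr_t>`, which the C++ side casts back to typed pointers. A minimal ctypes sketch of that pointer-plus-byte-offset convention, independent of `kt_kernel_ext`:

```python
import ctypes

# Eight bytes 0..7; the C++ side would receive `base + offset` as a uintptr_t
# and cast it back to a typed pointer, just as the per-expert pointers do.
buf = (ctypes.c_uint8 * 8)(*range(8))
base = ctypes.addressof(buf)
ptr = ctypes.cast(base + 4, ctypes.POINTER(ctypes.c_uint8))
print(ptr.contents.value)  # 4
```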
View File

@@ -1,73 +1,49 @@
/**
* @Description :
* @Author : chenht2022
* @Description : AWQ Int4 AMX MoE operator with KGroup quantization and zero-point support
* @Author : chenht2022, oql
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:10
* @Version : 2.0.0
* @LastEditors : oql
* @LastEditTime : 2025-12-10
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* This file implements AWQ Int4 MoE using CRTP pattern, inheriting from moe_base.hpp.
* AWQ weights are stored with group-wise scales and zero-points (KGroup Int4 with zeros).
**/
#ifndef CPUINFER_OPERATOR_AMX_AWQ_MOE_H
#define CPUINFER_OPERATOR_AMX_AWQ_MOE_H
// #define CHECK
#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#include <immintrin.h>
#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
#include "moe_base.hpp"
/**
* @brief AWQ Int4 MoE operator using CRTP pattern
* @tparam T Kernel type for AWQ quantization
*
* This class provides AWQ-specific implementations:
* - do_gate_up_gemm: Int4 weight with KGroup scale + zeros + AMX GEMM
* - do_down_gemm: Same Int4 KGroup GEMM
* - load_weights: Load Int4 weights with group-wise scales and zero-points
*/
template <class T>
class AMX_AWQ_MOE_TP {
class AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {
private:
int tp_part_idx;
using Base = AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>>;
using Base::config_;
using Base::tp_part_idx;
using Base::gate_bb_;
using Base::up_bb_;
using Base::down_bb_;
using Base::gate_up_ba_;
using Base::gate_bc_;
using Base::up_bc_;
using Base::down_ba_;
using Base::down_bc_;
using Base::m_local_num_;
std::filesystem::path prefix;
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
// quantized)]
void* up_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
// quantized)]
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if
// quantized)]
ggml_bf16_t* m_local_input_; // [num_experts_per_tok * max_len * hidden_size]
ggml_bf16_t* m_local_gate_output_; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_up_output_; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_down_output_; // [num_experts_per_tok * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, num_experts_per_tok]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
char verify_bb[100000000];
char check_bb[100000000];
@@ -274,32 +250,35 @@ class AMX_AWQ_MOE_TP {
zeros_size / mat_split);
zeros_file.close();
}
#ifdef CHECK
inline void load_check() {
memcpy(check_bb, (char*)down_bb_[compare_expers]->b,
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));
}
void verify_load_right() {
// printf("varify down bb_0 %d\n", tp_part_idx);
memcpy(verify_bb, (char*)down_bb_[compare_expers]->b,
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
// check if verify_bb_0 equal to check_bb_0
if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) {
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));
if (memcmp(verify_bb, check_bb,
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
config_.quant_config.group_size)) != 0) {
printf("verify error\n");
for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) {
for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,
config_.quant_config.group_size);
++i) {
if (verify_bb[i] != check_bb[i]) {
printf("Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\n", i, compare_expers, i,
(unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]);
break; // find the first difference and exit
break;
}
}
assert(0);
} else {
printf("pass verify\n");
// pick out the 100th~150th byte of scale to see
printf("numa %d, verify_bb_%d:\n", tp_part_idx, compare_expers);
size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);
size_t size =
T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size);
size_t scale_size = config_.hidden_size * sizeof(float);
for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) {
printf("%02x ", (unsigned char)verify_bb[i]);
@@ -392,7 +371,7 @@ class AMX_AWQ_MOE_TP {
}
// AVX-optimized function to convert INT4 zeros to float mins
// mins = zeros * scales (element-wise), where scales is float format
// mins = -(zeros * scales) (element-wise), where scales is float format
inline void convert_zeros_to_mins_avx(const uint32_t* zeros_int4_packed, const float* scales, float* mins,
size_t num_elements) {
constexpr size_t simd_width = 8;  // decode 8 int4 values per iteration
@@ -408,30 +387,25 @@ class AMX_AWQ_MOE_TP {
}
}
#ifdef FORWARD_TIME_REPORT
std::chrono::time_point<std::chrono::high_resolution_clock> last_now;
#endif
public:
using input_t = ggml_bf16_t;
using output_t = float;
GeneralMOEConfig config_;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
using typename Base::input_t;
using typename Base::output_t;
AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx) {
auto& quant_config = config.quant_config;
int& group_size = quant_config.group_size;
AMX_AWQ_MOE_TP() = default;
AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || !quant_config.zero_point) {
throw std::runtime_error("AWQ-Quantization AMX MoE only support KGroup Int4_1");
}
auto& load = config.load;
auto& save = config.save;
if (load && config.path == "") {
load = false;
}
prefix = config.path;
prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
printf("Creating AMX_AWQ_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
auto& load = config_.load;
auto& save = config_.save;
prefix = config_.path;
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx_));
if (save) {
std::cout << "Creating " << prefix << std::endl;
std::filesystem::create_directories(prefix);
@@ -443,77 +417,74 @@ class AMX_AWQ_MOE_TP {
throw std::runtime_error("Path not found: " + prefix.string());
}
}
this->tp_part_idx = tp_part_idx;
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, group_size, nullptr));
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(
std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, group_size, nullptr));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,
group_size, gate_bb_ptr));
void* up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, group_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, group_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, group_size));
down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,
group_size, down_bb_ptr));
}
for (int i = 0; i < config_.expert_num; i++) {
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.hidden_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.intermediate_size, group_size));
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.hidden_size));
}
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AMX_AWQ_MOE_TP() {
// shared_mem_buffer_numa.dealloc(this);
~AMX_AWQ_MOE_TP() = default;
// ============================================================================
// CRTP buffer creation - with group_size (AWQ uses zero-point)
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const {
return T::BufferA::required_size(m, k, config_.quant_config.group_size);
}
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const {
return T::BufferC::required_size(m, n);
}
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// ============================================================================
// CRTP virtual points - GEMM dispatch (uses kgroup with zeros)
// ============================================================================
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
// Dispatch based on qlen threshold
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
} else {
amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
}
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
} else {
amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
}
}
/**
* @brief Load Int4 weights with scales and zero-points
*
* AWQ weights include:
* - Quantized INT4 weights
* - FP16 scales (converted to FP32)
* - INT4 zeros (converted to FP32 mins = -scale * zero)
*/
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_projs.size()) {
throw std::runtime_error("AMX load weights from gate_projs is not supported");
} else {
// AWQ Load from file implementation
int nth = T::recommended_nth(config_.intermediate_size);
static uint8_t mat_type_all = 3, mat_split = 1;
if (config_.load) {
throw std::runtime_error("AMX load weights from file is not supported");
}
// check process, store down matrix to check
#ifdef CHECK
load_check();
#endif
else if (config_.gate_scale != nullptr)
#endif
{
// Loading quantized weights with scales and zeros
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map](int task_id) {
(ggml_fp16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),
scale_elem_count);
// Convert INT4 zeros to FP32 mins: mins = -(scale * zero)
convert_zeros_to_mins_avx(
(const uint32_t*)((uint8_t*)config_.gate_zero + ((logical_expert_id * scale_elem_count) >> 1)),
gate_bb_[expert_idx]->d, gate_bb_[expert_idx]->mins, scale_elem_count);
}
}
else {
// Online Quantization from BF16
assert(config_.gate_proj != nullptr);
pool->do_work_stealing_job(
}
}
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
#define DIRECT_OR_POOL_BY_QLEN(var, fn) \
do { \
if (qlen < 10) { \
for (int i = 0; i < (var); i++) { \
(fn)(i); \
} \
} else { \
pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
} \
} while (0)
#define MATMUL_OR_VECMUL_KGROUP_BY_QLEN(...) \
do { \
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { \
amx::mat_mul_kgroup(__VA_ARGS__); \
} else { \
amx::vec_mul_kgroup(__VA_ARGS__); \
} \
} while (0)
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
auto pool = config_.pool->get_subpool(tp_part_idx);
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage elapsed times (microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;  // track the largest per-expert row count
#endif
int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
// activated_expert counting is complete
size_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int& group_size = config_.quant_config.group_size;
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
group_size, gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx],
ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
group_size, gate_up_ba_[expert_idx], gate_bb_[expert_idx],
gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
auto up_gate_fn = [this, nth](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
};
DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int& group_size = config_.quant_config.group_size;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
MATMUL_OR_VECMUL_KGROUP_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
group_size, down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx],
ith, nth);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
qlen, nullptr,
[this, nth, output, k, expert_ids, weights](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function, along with max_local_num and qlen
printf(
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
"%d, qlen: %d\n",
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage elapsed times (microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;  // track the largest per-expert row count
#endif
int activated_expert = 0;
for (int i = 0; i < k; i++) {
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int& group_size = config_.quant_config.group_size;
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
amx::vec_mul_kgroup(qlen, config_.intermediate_size, config_.hidden_size, group_size, gate_up_ba_[0],
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < qlen; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int& group_size = config_.quant_config.group_size;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
amx::vec_mul_kgroup(qlen, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
for (int i = 0; i < qlen; i++) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
}
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function
printf(
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
forward_total_time);
#endif
}
// forward, forward_prefill, forward_decode, warm_up are inherited from Base
};
// ============================================================================
// TP_MOE specialization for AMX_AWQ_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation
// ============================================================================
template <typename K>
class TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>> {
public:
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),
((sizeof(uint8_t) * weight_elem_count) >> 1));
// down scales and zeros TP-slicing
memcpy((ggml_fp16_t*)tpc.down_scale + (expert_id * scales_elem_count),
(ggml_fp16_t*)config.down_scale +
(expert_id * (config.intermediate_size / group_size) * config.hidden_size +
(sizeof(uint8_t) * scales_elem_count) >> 1);
for (size_t kg = 0; kg < config.hidden_size / group_size; kg++) {
// copy gate/up scales
memcpy((ggml_fp16_t*)tpc.gate_scale + (expert_id * scales_elem_count) + kg * tpc.intermediate_size,
(ggml_fp16_t*)config.gate_scale +
(expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +
kg * config.intermediate_size + i * tpc.intermediate_size),
(sizeof(ggml_fp16_t) * tpc.intermediate_size));
// copy gate/up zeros TP-slicing
memcpy(
(uint8_t*)tpc.gate_zero + (((expert_id * scales_elem_count) + kg * tpc.intermediate_size) >> 1),
(uint8_t*)config.gate_zero +
((sizeof(uint8_t) * tpc.intermediate_size) >> 1));
}
// down weights TP-slicing (column-wise)
for (size_t col = 0; col < config.hidden_size; col++) {
memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),
(uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +
}
}
void merge_results(int qlen, void* output, bool incremental) {
auto pool = this->config.pool;
auto merge_fn = [this, output, incremental](int token_nth) {
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto& tp_count = this->tp_count;
auto& config = this->config;
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0, x1;
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
}
}
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
}
}
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0 = *(__m512*)(merge_to + e);
__m512 x1 = *(__m512*)(merge_to + e + 16);
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
}
};
DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn);
}
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
// merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>
};
#endif


/**
* @Description : FP8 AMX MoE operator for DeepSeek V3.2 native inference
* @Author : oql, Codex and Claude
* @Date : 2025-12-09
* @Version : 1.0.0
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
*
* This file implements FP8 MoE using CRTP pattern, inheriting from moe_base.hpp.
* FP8 weights are stored with 128x128 block-wise scales.
**/
#ifndef CPUINFER_OPERATOR_AMX_FP8_MOE_H
#define CPUINFER_OPERATOR_AMX_FP8_MOE_H
// #define DEBUG_FP8_MOE
#include <immintrin.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "la/amx_raw_buffers.hpp"
#include "la/amx_raw_kernels.hpp"
#include "moe_base.hpp"
/**
* @brief FP8 MoE operator using CRTP pattern
* @tparam T Kernel type, defaults to GemmKernel224FP8
*
* This class provides FP8-specific implementations:
* - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul
* - load_weights: Load FP8 weights with 128x128 block scales
*/
template <class T = amx::GemmKernel224FP8>
class AMX_FP8_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>> {
using Base = AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>>;
using Base::config_;
using Base::down_ba_;
using Base::down_bb_;
using Base::down_bc_;
using Base::gate_bb_;
using Base::gate_bc_;
using Base::gate_up_ba_;
using Base::m_local_num_;
using Base::tp_part_idx;
using Base::up_bb_;
using Base::up_bc_;
public:
using typename Base::input_t;
using typename Base::output_t;
AMX_FP8_MOE_TP() = default;
AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {
auto& quant_config = config_.quant_config;
if (quant_config.group_size == 0 || quant_config.zero_point) {
throw std::runtime_error("KT-Kernel FP8 MoE only supports block-wise FP8. group_size = " +
std::to_string(quant_config.group_size) +
", zero_point = " + std::to_string(quant_config.zero_point));
}
printf("Created AMX_FP8_MOE_TP %d at numa %d\n", tp_part_idx_, numa_node_of_cpu(sched_getcpu()));
}
~AMX_FP8_MOE_TP() = default;
// ============================================================================
// CRTP buffer creation - group_size applies only to BufferB block scales
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k, config_.quant_config.group_size);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// ============================================================================
// CRTP virtual points - GEMM dispatch
// ============================================================================
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
// FP8 path currently always uses the vector kernel; qlen does not affect dispatch
amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
auto& group_size = config_.quant_config.group_size;
int m = m_local_num_[expert_idx];
amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],
down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
}
#ifdef DEBUG_FP8_MOE
// Function to dump Buffer B data for debugging FP8 quantization results
inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type,
typename T::BufferB* buffer) {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
printf("[DUMP_BUFFER_B] TP%d %s Expert%d %s:\n", tp_part_idx, quantization_type.c_str(), expert_idx,
matrix_type.c_str());
// Calculate dimensions based on matrix type
int rows, cols;
size_t scale_elem_count;
if (matrix_type == "gate" || matrix_type == "up") {
rows = config_.intermediate_size;
cols = config_.hidden_size;
} else { // down
rows = config_.hidden_size;
cols = config_.intermediate_size;
}
int n_blocks_n = (rows + group_size - 1) / group_size;
int n_blocks_k = (cols + group_size - 1) / group_size;
scale_elem_count = n_blocks_n * n_blocks_k;
// Dump scales (as BF16 converted to float)
printf(" Scales[first 16]: ");
for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
if (scale_elem_count > 16) {
printf(" Scales[last 16]: ");
int start_idx = std::max(0, (int)scale_elem_count - 16);
for (int i = start_idx; i < (int)scale_elem_count; i++) {
printf("%.6f ", buffer->d[i]);
}
printf("\n");
}
// Dump FP8 weights (as hex uint8)
size_t weight_size = (size_t)rows * cols; // FP8 is 1 byte per element
uint8_t* weight_ptr = (uint8_t*)buffer->b;
printf(" FP8 Weights[first 32 bytes]: ");
for (int i = 0; i < std::min(32, (int)weight_size); i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
if (weight_size > 32) {
printf(" FP8 Weights[last 32 bytes]: ");
int start_idx = std::max(32, (int)weight_size - 32);
for (int i = start_idx; i < (int)weight_size; i++) {
printf("%02x ", weight_ptr[i]);
}
printf("\n");
}
printf(" Matrix dimensions: %dx%d (n x k), Scale blocks: %dx%d, Group size: %d, Scale elements: %zu\n", rows, cols,
n_blocks_n, n_blocks_k, group_size, scale_elem_count);
}
#endif
/**
* @brief Load FP8 weights from contiguous memory layout
*
* Loads weights from config_.gate_proj, up_proj, down_proj with scales
* from config_.gate_scale, up_scale, down_scale.
*/
void load_weights() {
auto& quant_config = config_.quant_config;
int& group_size = quant_config.group_size;
const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;
auto pool = config_.pool->get_subpool(tp_part_idx);
if (config_.gate_scale == nullptr) {
throw std::runtime_error("FP8 AMX MoE only supports native FP8 weights with scales.");
}
// load weight
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, group_size](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// gate part
gate_bb_[expert_idx]->from_mat(
(uint8_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
(float*)config_.gate_scale +
(logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
ith, nth);
// up part
up_bb_[expert_idx]->from_mat(
(uint8_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
(float*)config_.up_scale +
(logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
ith, nth);
},
nullptr);
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * config_.expert_num, nullptr,
[this, nth, physical_to_logical_map, group_size](int task_id) {
uint64_t expert_idx = task_id / nth;
uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);
int ith = task_id % nth;
// down part
down_bb_[expert_idx]->from_mat(
(uint8_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),
(float*)config_.down_scale +
(logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),
ith, nth);
},
nullptr);
#ifdef DEBUG_FP8_MOE
dump_buffer_b("Native FP8", 0, "gate", gate_bb_[0].get());
dump_buffer_b("Native FP8", 0, "down", down_bb_[0].get());
#endif
}
// Fast 64-byte (512-bit) memcpy using AVX512
static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {
__m512i data = _mm512_loadu_si512(src);
_mm512_storeu_si512(dst, data);
}
// Fast memcpy for arbitrary sizes using AVX512
static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {
uint8_t* d = (uint8_t*)dst;
const uint8_t* s = (const uint8_t*)src;
size_t chunks = bytes / 64;
for (size_t i = 0; i < chunks; i++) {
fast_memcpy_64(d, s);
d += 64;
s += 64;
}
bytes -= chunks * 64;
if (bytes > 0) {
std::memcpy(d, s, bytes);
}
}
/**
* @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format
*
* This is the inverse of the packing done in BufferBFP8Impl::from_mat.
* Optimized with AVX512 gather for efficient non-contiguous reads.
*
* @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout)
* @param dst Pointer to destination in n-major layout
* @param dst_row_stride Row stride in destination buffer (number of columns in full matrix)
*/
static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) {
// row_map[packed_i] gives the base row for packed index packed_i
static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
const uint64_t* src64 = reinterpret_cast<const uint64_t*>(src);
// Gather indices: src64[8*j + packed_i] for j = 0..7
// Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group)
const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0);
// Process each packed group (8 groups of 4 rows each = 32 rows total)
for (int packed_i = 0; packed_i < 8; packed_i++) {
const int base_row = row_map[packed_i];
const uint64_t* base_src = src64 + packed_i;
// Gather 8 values for j=0..7 and j=8..15
__m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8);
__m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8);
// Extract 4 rows from each set of 8 values
// Row 0: bits 0-15
__m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF)));
__m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF)));
// Row 1: bits 16-31
__m128i row1_lo =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF)));
__m128i row1_hi =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF)));
// Row 2: bits 32-47
__m128i row2_lo =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF)));
__m128i row2_hi =
_mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF)));
// Row 3: bits 48-63
__m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48));
__m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48));
// Store 32 bytes (16 x uint16) to each row
// Combine two 128-bit values into 256-bit for more efficient stores
uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride;
uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride;
uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride;
uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride;
// Combine lo and hi into 256-bit and store
__m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo);
__m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo);
__m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo);
__m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo);
_mm256_storeu_si256((__m256i*)row0_dst, row0_256);
_mm256_storeu_si256((__m256i*)row1_dst, row1_256);
_mm256_storeu_si256((__m256i*)row2_dst, row2_256);
_mm256_storeu_si256((__m256i*)row3_dst, row3_256);
}
}
/**
* @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization
*
* Processing 4 blocks together means each row write is 128 bytes = 2 cache lines,
* which greatly improves write efficiency compared to 32 bytes per row.
*
* @param src Array of 4 source pointers (each pointing to a 32x32 packed block)
* @param dst Destination pointer in n-major layout
* @param dst_row_stride Row stride in destination buffer
*/
static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {
static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};
constexpr int K_STEP = T::K_STEP; // 32
// Reinterpret as uint64 arrays for efficient access
const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);
const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);
const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);
const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);
// Process all 32 rows, writing 128 bytes (4 x 32) per row
for (int packed_i = 0; packed_i < 8; packed_i++) {
const int base_row = row_map[packed_i];
// Process 4 rows at a time
for (int r = 0; r < 4; r++) {
uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);
const int shift = r * 16;
// Unroll: process all 4 blocks x 16 columns = 64 uint16 values
// Block 0: columns 0-15
for (int j = 0; j < 16; j++) {
row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);
}
// Block 1: columns 16-31
for (int j = 0; j < 16; j++) {
row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);
}
// Block 2: columns 32-47
for (int j = 0; j < 16; j++) {
row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);
}
// Block 3: columns 48-63
for (int j = 0; j < 16; j++) {
row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);
}
}
}
}
/**
* @brief Reconstruct weights for a single expert to the output buffers (no temp buffer version)
*
* Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.
* Optimized version with coarse-grained task splitting for better cache utilization.
*
* Key optimizations:
* - Reduced task count (~40 vs ~350) to minimize scheduling overhead
* - Larger chunks per task for better cache line utilization
* - Process multiple N_STEPs per task for better write locality
*
* @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)
* @param cpu_tp_count Number of CPU TP parts
* @param expert_id Expert index to process
* @param full_config Full configuration (before CPU TP split)
* @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)
* @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)
* @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)
* @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)
*/
void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,
const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) const {
auto& config = config_;
const int group_size = config.quant_config.group_size;
auto pool = config.pool->get_subpool(tp_part_idx);
constexpr int N_STEP = T::N_STEP;
constexpr int K_STEP = T::K_STEP;
constexpr int N_BLOCK = T::N_BLOCK;
constexpr int K_BLOCK = T::K_BLOCK;
// ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========
const int cpu_n_w13 = config.intermediate_size;
const int cpu_k_w13 = config.hidden_size;
const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;
const int gpu_k_w13 = full_config.hidden_size;
const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;
const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;
const size_t gpu_w13_scale_per_mat = (size_t)div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size);
const int cpu_scale_k_blocks_w13 = div_up(cpu_k_w13, group_size);
const int gpu_scale_k_blocks_w13 = div_up(gpu_k_w13, group_size);
// ========= W2 (down): Shape [hidden, intermediate], split by K =========
const int cpu_n_w2 = config.hidden_size;
const int cpu_k_w2 = config.intermediate_size;
const int gpu_n_w2 = full_config.hidden_size;
const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;
const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;
const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2;
const size_t gpu_w2_scale_per_mat = (size_t)div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size);
const int cpu_scale_k_blocks_w2 = div_up(cpu_k_w2, group_size);
const int gpu_scale_k_blocks_w2 = div_up(gpu_k_w2, group_size);
// ========= Scale dimensions =========
const int cpu_scale_n_blocks_w13 = div_up(cpu_n_w13, group_size);
const int gpu_scale_n_blocks_w13 = div_up(gpu_n_w13, group_size);
const int cpu_scale_n_blocks_w2 = div_up(cpu_n_w2, group_size);
// ========= Optimized job layout =========
// Use task count slightly above CPU core count for good work stealing
// For 80-core system, ~100 tasks provides good balance
constexpr int NUM_W13_TASKS = 32; // Per matrix (gate or up), total 64 for w13
constexpr int NUM_W2_TASKS = 32; // For down matrix
constexpr int SCALE_TASKS = 3; // gate_scale, up_scale, down_scale
const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS;
// Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing)
const int w13_n_steps = div_up(cpu_n_w13, N_STEP);
const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);
const int w2_n_steps = div_up(cpu_n_w2, N_STEP);
const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);
pool->do_work_stealing_job(
total_tasks, nullptr,
[=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {
if (task_id < NUM_W13_TASKS * 2) {
// ========= W13 weight task: process chunk of rows x full K =========
const bool is_up = task_id >= NUM_W13_TASKS;
const int chunk_idx = task_id % NUM_W13_TASKS;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
// Calculate row range for this task (N_STEP aligned)
const int step_start = chunk_idx * w13_steps_per_task;
const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);
if (step_start >= w13_n_steps) return;
const int chunk_n_start = step_start * N_STEP;
const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);
// Process each N_STEP within this chunk
for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
// Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries)
const int global_n = global_n_offset_w13 + local_n_start;
const int target_gpu = global_n / gpu_n_w13;
const int n_in_gpu = global_n % gpu_n_w13;
uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu];
// Pointer already points to current expert's location, only add offset for up matrix
const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;
// Calculate N_BLOCK info for source addressing
const int n_block_idx = local_n_start / N_BLOCK;
const int n_block_begin = n_block_idx * N_BLOCK;
const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);
const int n_in_block = local_n_start - n_block_begin;
// Process all K in groups of 4 K_STEPs when possible for cache efficiency
for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {
const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);
// Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row)
int k_begin = 0;
for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) {
const uint8_t* src_ptrs[4];
for (int i = 0; i < 4; i++) {
src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size +
(size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP;
}
uint8_t* dst =
weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13);
}
// Handle remaining K_STEPs one by one
for (; k_begin < k_block_size; k_begin += K_STEP) {
const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +
(size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
(size_t)k_begin * N_STEP;
uint8_t* dst =
weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;
unpack_nk_block(src, dst, gpu_k_w13);
}
}
}
} else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) {
// ========= W2 weight task: process chunk of rows x all K slices =========
const int chunk_idx = task_id - NUM_W13_TASKS * 2;
const auto& bb = down_bb_[expert_id];
// Calculate row range for this task (N_STEP aligned)
const int step_start = chunk_idx * w2_steps_per_task;
const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);
if (step_start >= w2_n_steps) return;
const int chunk_n_start = step_start * N_STEP;
const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);
// Process each N_STEP within this chunk
for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {
// Calculate N_BLOCK info for source addressing
const int n_block_idx = local_n_start / N_BLOCK;
const int n_block_begin = n_block_idx * N_BLOCK;
const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);
const int n_in_block = local_n_start - n_block_begin;
// Process all K slices (each slice goes to a different GPU TP)
for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {
const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);
const int global_k_start = global_k_offset_w2 + k_slice_start;
const int target_gpu = global_k_start / gpu_k_w2;
const int k_in_gpu_base = global_k_start % gpu_k_w2;
uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu];
// Pointer already points to current expert's location
const size_t expert_weight_off = 0;
// Process K within this slice, trying 4 K_STEPs at once when aligned
for (int k_abs = k_slice_start; k_abs < k_slice_end;) {
const int k_block_idx = k_abs / K_BLOCK;
const int k_block_begin = k_block_idx * K_BLOCK;
const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);
const int k_in_block = k_abs - k_block_begin;
const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);
// Check if we can process 4 K_STEPs at once
const int remaining_in_block = k_block_size - k_in_block;
const int remaining_in_slice = k_slice_end - k_abs;
if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) {
const uint8_t* src_ptrs[4];
for (int i = 0; i < 4; i++) {
src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size +
(size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP;
}
uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2);
k_abs += 4 * K_STEP;
} else {
const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +
(size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +
(size_t)k_in_block * N_STEP;
uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;
unpack_nk_block(src, dst, gpu_k_w2);
k_abs += K_STEP;
}
}
}
}
} else {
// ========= Scale copy task: simple linear copy with fast_memcpy =========
const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS;
if (scale_task_id < 2) {
// Gate (0) or Up (1) scale copy
const bool is_up = scale_task_id == 1;
const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];
// W13 scales: copy N blocks corresponding to this CPU TP
// Note: when gpu_tp > cpu_tp, scale blocks may span multiple GPU TPs
const int bn_start_global = global_n_offset_w13 / group_size;
for (int bn = 0; bn < cpu_scale_n_blocks_w13; bn++) {
const int global_bn = bn_start_global + bn;
const int target_gpu = global_bn / gpu_scale_n_blocks_w13;
const int gpu_bn = global_bn % gpu_scale_n_blocks_w13;
float* scale_dst = (float*)w13_scale_ptrs[target_gpu];
// Pointer already points to current expert's location, only add offset for up matrix
const size_t expert_scale_off = is_up ? gpu_w13_scale_per_mat : 0;
fast_memcpy(scale_dst + expert_scale_off + (size_t)gpu_bn * gpu_scale_k_blocks_w13,
bb->d + (size_t)bn * cpu_scale_k_blocks_w13, cpu_scale_k_blocks_w13 * sizeof(float));
}
} else {
// Down scale copy (scale_task_id == 2)
const auto& bb = down_bb_[expert_id];
// W2 scales: K dimension is split, copy to each GPU TP
for (int k_slice_idx = 0; k_slice_idx < div_up(cpu_k_w2, gpu_k_w2); k_slice_idx++) {
const int k_slice_start = k_slice_idx * gpu_k_w2;
const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);
const int global_k_start = global_k_offset_w2 + k_slice_start;
const int target_gpu = global_k_start / gpu_k_w2;
const int bk_gpu_base = (global_k_start % gpu_k_w2) / group_size;
float* scale_dst = (float*)w2_scale_ptrs[target_gpu];
// Pointer already points to current expert's location
const size_t expert_scale_off = 0;
const int bk_start = k_slice_start / group_size;
const int bk_end = div_up(k_slice_end, group_size);
const int bk_count = bk_end - bk_start;
for (int bn = 0; bn < cpu_scale_n_blocks_w2; bn++) {
fast_memcpy(scale_dst + expert_scale_off + (size_t)bn * gpu_scale_k_blocks_w2 + bk_gpu_base,
bb->d + (size_t)bn * cpu_scale_k_blocks_w2 + bk_start, bk_count * sizeof(float));
}
}
}
}
},
nullptr);
}
};
template <typename K>
class TP_MOE<AMX_FP8_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>> {
public:
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
auto pool = config.pool;
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
const int group_size = config.quant_config.group_size;
if (group_size == 0 || config.quant_config.zero_point) {
throw std::runtime_error("FP8 MoE requires group_size > 0 and zero_point=false");
}
if (config.gate_projs.empty() && config.gate_proj == nullptr) {
throw std::runtime_error("no weight source");
}
const bool use_per_expert_ptrs = !config.gate_projs.empty();
const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;
const size_t full_scale_elems =
(size_t)div_up(config.hidden_size, group_size) * div_up(config.intermediate_size, group_size);
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;
const size_t tp_scale_elems =
(size_t)div_up(tpc.intermediate_size, group_size) * div_up(tpc.hidden_size, group_size);
tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];
tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems];
tpc.up_scale = new float[tpc.expert_num * tp_scale_elems];
tpc.down_scale = new float[tpc.expert_num * tp_scale_elems];
const size_t tp_idx = (size_t)i;
const size_t gate_up_weight_src_offset = i * tp_weight_elems;
const size_t gate_up_scale_src_offset = i * tp_scale_elems;
const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;
const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size;
pool->get_subpool(i)->do_work_stealing_job(
tpc.expert_num, nullptr,
[&, &tpc](int expert_id_) {
const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);
uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;
uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;
uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;
float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems;
float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems;
float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems;
const uint8_t* gate_src;
const uint8_t* up_src;
const uint8_t* down_src;
const float* gate_scale_src;
const float* up_scale_src;
const float* down_scale_src;
if (use_per_expert_ptrs) {
gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_projs[0][expert_id];
gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scales[0][expert_id];
} else {
gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;
down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;
gate_scale_src =
(const float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;
down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems;
}
std::memcpy(gate_dst, gate_src, tp_weight_elems);
std::memcpy(up_dst, up_src, tp_weight_elems);
std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems);
std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems);
for (int row = 0; row < config.hidden_size; row++) {
const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;
const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;
std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);
}
const int n_blocks_n = div_up(config.hidden_size, group_size);
const int full_n_blocks_k = div_up(config.intermediate_size, group_size);
const int tp_n_blocks_k = div_up(tpc.intermediate_size, group_size);
for (int bn = 0; bn < n_blocks_n; bn++) {
const float* src = down_scale_src + (size_t)bn * (size_t)full_n_blocks_k + down_scale_src_block_k_offset;
float* dst = down_scale_dst + (size_t)bn * (size_t)tp_n_blocks_k;
std::memcpy(dst, src, sizeof(float) * (size_t)tp_n_blocks_k);
}
},
nullptr);
});
DO_TPS_LOAD_WEIGHTS(pool);
pool->dispense_backend()->do_numa_job([&, this](int i) {
auto& tpc = tps[i]->config_;
delete[] (uint8_t*)tpc.gate_proj;
delete[] (uint8_t*)tpc.up_proj;
delete[] (uint8_t*)tpc.down_proj;
delete[] (float*)tpc.gate_scale;
delete[] (float*)tpc.up_scale;
delete[] (float*)tpc.down_scale;
});
this->weights_loaded = true;
}
void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
if (this->weights_loaded == false) {
throw std::runtime_error("Not Loaded");
}
if (this->tps.empty()) {
throw std::runtime_error("No TP parts initialized");
}
if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count ||
(int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) {
throw std::runtime_error("Pointer arrays size must match gpu_tp_count");
}
this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {
this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,
w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);
});
}
};
#endif // CPUINFER_OPERATOR_AMX_FP8_MOE_H
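The `mat_offset` reorder used when packing (in `BufferBFP8Impl::from_mat` below) and the `row_map` table used in `unpack_nk_block` above are inverses of each other. A scalar sketch of the round trip for one isolated 32×32 block can illustrate this (illustrative only; assumes `N_STEP == K_STEP == 32` and a row stride of 32 bytes, i.e. a single detached block rather than the kernel's full blocked buffer):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Scalar round trip for one isolated 32x32 FP8 block:
// pack_block mirrors the mat_offset reorder in BufferBFP8Impl::from_mat,
// unpack_block mirrors unpack_nk_block (kRowMap is kMatOffset's inverse:
// kRowMap[kMatOffset[i]] == 4 * i).
static constexpr int kMatOffset[8] = {0, 2, 4, 6, 1, 3, 5, 7};
static constexpr int kRowMap[8] = {0, 16, 4, 20, 8, 24, 12, 28};

static void pack_block(const uint8_t* src, uint64_t* dst64) {
    for (int i = 0; i < 8; i++) {        // 8 groups of 4 consecutive rows
        for (int j = 0; j < 16; j++) {   // 16 column pairs (2 bytes each)
            uint64_t val = 0;
            for (int r = 0; r < 4; r++) {
                uint16_t v;
                std::memcpy(&v, src + (i * 4 + r) * 32 + 2 * j, 2);
                val |= (uint64_t)v << (16 * r);  // one row per 16-bit lane
            }
            dst64[8 * j + kMatOffset[i]] = val;
        }
    }
}

static void unpack_block(const uint64_t* src64, uint8_t* dst) {
    for (int p = 0; p < 8; p++) {
        const int base_row = kRowMap[p];
        for (int j = 0; j < 16; j++) {
            uint64_t val = src64[8 * j + p];
            for (int r = 0; r < 4; r++) {
                uint16_t v = (uint16_t)(val >> (16 * r));
                std::memcpy(dst + (base_row + r) * 32 + 2 * j, &v, 2);
            }
        }
    }
}
```

The AVX512 gather version in `unpack_nk_block` performs the same index mapping, reading 8 packed `uint64` values per gather instead of one at a time.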

File diff suppressed because it is too large


@@ -46,6 +46,9 @@ static inline __m512 exp_avx512(__m512 x) {
static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {
__m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);
// Clamp neg_gate_val to avoid exp overflow (exp(88) overflows for float32)
const __m512 max_exp_input = _mm512_set1_ps(88.0f);
neg_gate_val = _mm512_min_ps(neg_gate_val, max_exp_input);
__m512 exp_neg_gate = exp_avx512(neg_gate_val);
__m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);
__m512 act_val = _mm512_div_ps(gate_val, denom);
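The clamp above can be cross-checked against a scalar reference of the same math (an illustrative sketch, not the kernel itself; the final multiply by `up_val` is assumed from the function signature, as the hunk is truncated):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Scalar reference for the vectorized act_fn: SiLU(gate) * up,
// i.e. gate / (1 + exp(-gate)) * up. Clamping -gate at 88 mirrors
// the AVX512 overflow guard: expf(x) overflows float32 for x > ~88.7.
static float act_fn_ref(float gate_val, float up_val) {
    float neg_gate = std::min(-gate_val, 88.0f);
    float denom = 1.0f + std::exp(neg_gate);
    return gate_val / denom * up_val;
}
```

Without the clamp, a large negative gate value would drive `exp(-gate)` to infinity; with it, the denominator stays finite and the quotient safely underflows toward zero.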


@@ -762,6 +762,16 @@ struct GemmKernel224BF {
struct BufferC {
float* c;
int max_m, n;
// Physical layout (in float elements):
// The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,
// and n is tiled by N_BLOCK.
// Storage order:
//   n_block (N_BLOCK cols) → m_block (M_STEP rows) → n_step (N_STEP cols) → (M_STEP×N_STEP) row-major tile.
// It can therefore be viewed as 5D:
//   c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP]
//   n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,
//   n_steps = N_BLOCK / N_STEP (the tail block may be smaller).
// get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP×N_STEP) tile.
static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }
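As a sanity check of the 5D layout in the comment, a scalar index helper can map each `(m_begin, n_begin)` tile to a unique, in-bounds, tile-aligned offset (the step/block sizes below are illustrative assumptions, not the kernel's actual template parameters):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Sketch of BufferC tile addressing per the layout comment:
// c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP].
// Step/block sizes are illustrative assumptions.
constexpr int M_STEP = 32, N_STEP = 32, N_BLOCK = 128;

static size_t tile_offset(int max_m, int n, int m_begin, int n_begin) {
    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;    // outermost: n_block
    int n_block_size = std::min(N_BLOCK, n - n_block_begin);
    return (size_t)n_block_begin * max_m                // skip full n_blocks
         + (size_t)m_begin * n_block_size               // then m_blocks
         + (size_t)(n_begin - n_block_begin) * M_STEP;  // then n_steps
}
```

This mirrors the addressing pattern of the other blocked buffers (BufferA/BufferB), with the roles of the blocked and unblocked dimensions swapped.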


@@ -0,0 +1,488 @@
#ifndef AMX_RAW_BUFFERS_HPP
#define AMX_RAW_BUFFERS_HPP
/**
 * @file amx_raw_buffers.hpp
 * @brief Raw data format buffer management (FP8, BF16, etc.)
 *
 * This file implements buffer management for native-precision formats,
 * used for native-precision inference of models such as DeepSeek V3.2.
 *
 * Buffer types:
 * - BufferAFP8Impl: input activation buffer with dynamic FP8 quantization
 * - BufferBFP8Impl: weight buffer, FP8 format + 128x128 block scales
 * - BufferBFP8BlockImpl: optimized block-quantized weight buffer
 *
 * Memory layout:
 * - FP8 data: 1 byte/element
 * - Scales: 4 bytes/block (one per 128x128 block for BufferB, one per 128 rows for BufferA)
 */
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>
#include <vector>
#include "amx_config.hpp"
#include "amx_utils.hpp"
#include "llama.cpp/ggml-impl.h"
#include "pack.hpp"
#include "utils.hpp"
namespace amx {
// ============================================================================
// BufferAFP8Impl: FP8 activation buffer (supports dynamic quantization)
// ============================================================================
/* Physical layout (in bf16 elements):
 * The logical matrix A is (m, k) row-major; m is padded to max_m (= m_block_size, a multiple of M_STEP).
 * Storage order:
 *   k_block (K_BLOCK cols) → m_block (M_STEP rows) → k_step (K_STEP cols) → (M_STEP×K_STEP) row-major tile.
 * It can therefore be viewed as 5D:
 *   a[k_blocks][m_blocks][k_steps][M_STEP][K_STEP]
 *   k_blocks = ceil(k / K_BLOCK), m_blocks = max_m / M_STEP,
 *   k_steps = K_BLOCK / K_STEP (the last k_block may be smaller).
 * get_submat(m_begin, k_begin) returns a contiguous (M_STEP×K_STEP) tile.
 */
template <typename K>
struct BufferABF16Impl {
ggml_bf16_t* a;
int max_m, k;
static constexpr int M_STEP = K::M_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr int K_BLOCK = K::K_BLOCK;
static size_t required_size(int max_m, int k) { return sizeof(ggml_bf16_t) * max_m * k; }
BufferABF16Impl(int max_m, int k, void* ptr) : max_m(max_m), k(k) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(max_m % M_STEP == 0);
assert(k % K_STEP == 0);
a = reinterpret_cast<ggml_bf16_t*>(ptr);
}
void set_data(void* new_ptr) { a = reinterpret_cast<ggml_bf16_t*>(new_ptr); }
void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {
assert(m <= max_m);
assert(ith == 0 && nth == 1);
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
__m512i* s = (__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin);
__m512i* d =
(__m512i*)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP);
avx512_copy_32xbf16(s, d);
}
}
}
}
}
ggml_bf16_t* get_submat(int m, int k, int m_begin, int k_begin) {
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
k_begin -= k_block_begin;
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;
}
};
// ============================================================================
// BufferB
// ============================================================================
/**
 * @brief BF16 BufferB
 * Physical layout (in bf16 elements):
 * The logical matrix B is (n, k) row-major (for NT GEMM); n is tiled by N_BLOCK.
 * Storage order:
 *   n_block (N_BLOCK rows) → k_block (K_BLOCK cols) → n_step (N_STEP rows) → k_step (K_STEP cols)
 *   → (N_STEP×K_STEP) tile; within each tile the two 16×16 sub-blocks are additionally transposed
 *   to match the VNNI layout of the AMX BTile (TILE_K/VNNI_BLK × TILE_N*VNNI_BLK).
 * It can therefore be viewed as 6D:
 *   b[n_blocks][k_blocks][n_steps][k_steps][N_STEP][K_STEP]
 *   n_blocks = ceil(n / N_BLOCK), k_blocks = ceil(k / K_BLOCK),
 *   n_steps = N_BLOCK / N_STEP, k_steps = K_BLOCK / K_STEP (tail blocks may be smaller).
 * get_submat(n_begin, k_begin) returns the start address of a contiguous (N_STEP×K_STEP) tile.
 * @tparam K Kernel type
 */
template <typename K>
struct BufferBBF16Impl {
ggml_bf16_t* b;
int n, k;
static constexpr bool SCALE = false;
static constexpr int N_STEP = K::N_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
static constexpr int K_BLOCK = K::K_BLOCK;
static constexpr int TILE_N = K::TILE_N;
static size_t required_size(int n, int k) { return sizeof(ggml_bf16_t) * n * k; }
BufferBBF16Impl(int n, int k, void* ptr) : n(n), k(k) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(n % N_STEP == 0);
assert(k % K_STEP == 0);
b = reinterpret_cast<ggml_bf16_t*>(ptr);
}
void set_data(void* new_ptr) { b = reinterpret_cast<ggml_bf16_t*>(new_ptr); }
void from_mat(ggml_bf16_t* src, int ith, int nth) {
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
for (int i = 0; i < N_STEP; i++) {
__m512i* s = (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);
__m512i* d = (__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +
k_begin * N_STEP + i * K_STEP);
avx512_copy_32xbf16(s, d);
}
transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +
n_begin * k_block_size + k_begin * N_STEP));
transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +
n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));
}
}
}
}
ggml_bf16_t* get_submat(int n, int k, int n_begin, int k_begin) {
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
n_begin -= n_block_begin;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
k_begin -= k_block_begin;
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;
}
};
/**
 * @brief FP8 weight buffer
 *
 * Stores the weight matrix in FP8 format with one scale factor per 128x128 block.
 * This matches DeepSeek V3.2's native-precision format.
 *
 * @tparam K Kernel type
 */
template <typename K>
struct BufferBFP8Impl {
uint8_t* b; // FP8 weight
float* d; // scale_inv [n / k_group_size, k / k_group_size]
int n, k, k_group_size; // k_group_size = 128 in DeepSeek
static constexpr int N_STEP = K::N_STEP;
static constexpr int K_STEP = K::K_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
static constexpr int K_BLOCK = K::K_BLOCK;
static constexpr bool SCALE = true;
/**
 * @brief Compute the required memory size
 */
static size_t required_size(int n, int k, int k_group_size) {
int n_blocks_n = (n + k_group_size - 1) / k_group_size;
int n_blocks_k = (k + k_group_size - 1) / k_group_size;
return sizeof(uint8_t) * n * k + sizeof(float) * n_blocks_n * n_blocks_k;
}
/**
 * @brief Constructor
 */
BufferBFP8Impl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) { set_data(ptr); }
void set_data(void* ptr) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
b = reinterpret_cast<uint8_t*>(ptr);
d = reinterpret_cast<float*>(b + (size_t)n * k);
}
static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7}; // fp8 matrix offset for reordering
/**
 * @brief Load from raw FP8 weights (already in quantized format)
 *
 * @param b_src FP8 weight source data (n-major, n×k)
 * @param d_src FP32 scale_inv source data (n-major, ceil(n/128)×ceil(k/128))
 */
void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {
assert(b != nullptr && d != nullptr);
assert(N_STEP == 32 && K_STEP == 32); // from mat block copy assumes this
// Copy scales (per 128x128 block). Each thread copies its own n-block range.
const int n_blocks_k = (k + k_group_size - 1) / k_group_size;
if (d_src != nullptr) {
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int bn_start = n_start / k_group_size;
int bn_end = (n_end + k_group_size - 1) / k_group_size;
memcpy(d + bn_start * n_blocks_k, d_src + bn_start * n_blocks_k,
sizeof(float) * (bn_end - bn_start) * n_blocks_k);
}
// Reorder FP8 weights into KT block-major layout (same panel->tile order as BF16 BufferB).
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
int n_step_size = std::min(N_STEP, n_block_size - n_begin);
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
int k_step_size = std::min(K_STEP, k_block_size - k_begin);
// [k_step_size, n_step_size] block copy
const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;
uint64_t* block_b_dst =
reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
(size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);
for (int i = 0; i < 8; i++) {
const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);
for (int j = 0; j < 16; j++) {
uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |
(((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);
block_b_dst[8 * j + mat_offset[i]] = val;
}
}
}
}
}
}
/**
* @brief get scale_inv
*/
float* get_scale(int n, int n_begin, int k, int k_begin) {
int n_blocks_k = (k + k_group_size - 1) / k_group_size;
int bn = n_begin / k_group_size;
int bk = k_begin / k_group_size;
return d + bn * n_blocks_k + bk;
}
/**
* @brief Get a pointer to the submatrix tile at (n_begin, k_begin)
*/
uint8_t* get_submat(int n, int k, int n_begin, int k_begin) {
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
n_begin -= n_block_begin;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
int k_block_begin = k_begin / K_BLOCK * K_BLOCK;
k_begin -= k_block_begin;
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size +
(size_t)k_begin * N_STEP;
}
/**
* @brief Inverse mapping for mat_offset used in to_mat
* mat_offset = {0, 2, 4, 6, 1, 3, 5, 7}
* inv_mat_offset[mat_offset[i]] = i
*/
static constexpr int inv_mat_offset[8] = {0, 4, 1, 5, 2, 6, 3, 7};
/**
* @brief Unpack FP8 weights from KT block-major layout back to n-major layout
*
* This is the inverse operation of from_mat.
*
* @param b_dst FP8 output buffer (n-major, n×k)
* @param d_dst FP32 scale_inv output buffer (n-major, ceil(n/128)×ceil(k/128))
* @param ith Thread index
* @param nth Total number of threads
*/
void to_mat(uint8_t* b_dst, float* d_dst, int ith, int nth) const {
assert(b != nullptr && d != nullptr);
assert(N_STEP == 32 && K_STEP == 32);
// Calculate N_BLOCK range for this thread
// Unlike split_range_n which gives one N_BLOCK per thread, we need to handle
// the case where nth < n/N_BLOCK (fewer threads than blocks)
int total_n_blocks = (n + N_BLOCK - 1) / N_BLOCK;
int blocks_per_thread = (total_n_blocks + nth - 1) / nth;
int start_n_block_idx = ith * blocks_per_thread;
int end_n_block_idx = std::min((ith + 1) * blocks_per_thread, total_n_blocks);
// Copy scales (per 128x128 block). Each thread copies its own n-block range.
const int n_blocks_k = (k + k_group_size - 1) / k_group_size;
if (d_dst != nullptr) {
int bn_start = start_n_block_idx;
int bn_end = end_n_block_idx;
memcpy(d_dst + bn_start * n_blocks_k, d + bn_start * n_blocks_k,
sizeof(float) * (bn_end - bn_start) * n_blocks_k);
}
// Reorder FP8 weights back to n-major layout (inverse of from_mat)
// Process each N_BLOCK assigned to this thread
for (int n_block_idx = start_n_block_idx; n_block_idx < end_n_block_idx; n_block_idx++) {
int n_block_begin = n_block_idx * N_BLOCK;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {
int k_block_size = std::min(K_BLOCK, k - k_block_begin);
for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {
// Source: packed layout (KT block-major)
const uint64_t* block_b_src =
reinterpret_cast<const uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +
(size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);
// Destination: n-major layout
uint8_t* block_b_dst = b_dst + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;
// Inverse of from_mat transformation
for (int packed_i = 0; packed_i < 8; packed_i++) {
int i = inv_mat_offset[packed_i];
uint16_t* d_row = reinterpret_cast<uint16_t*>(block_b_dst + (size_t)i * k * 4);
for (int j = 0; j < 16; j++) {
uint64_t val = block_b_src[8 * j + packed_i];
d_row[j] = (uint16_t)(val & 0xFFFF);
d_row[j + (k / 2) * 1] = (uint16_t)((val >> 16) & 0xFFFF);
d_row[j + (k / 2) * 2] = (uint16_t)((val >> 32) & 0xFFFF);
d_row[j + (k / 2) * 3] = (uint16_t)((val >> 48) & 0xFFFF);
}
}
}
}
}
}
}
};
// ============================================================================
// BufferCFP32Impl: FP32 output buffer
// ============================================================================
/**
* @brief FP32 output buffer
*
* Stores FP32-format accumulators and supports conversion to BF16 output.
*
* @tparam K Kernel type
*/
template <typename K>
struct BufferCFP32Impl {
float* c;
int max_m, n;
static constexpr int M_STEP = K::M_STEP;
static constexpr int N_STEP = K::N_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
// Physical layout (in units of float elements):
// The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,
// and n is tiled by N_BLOCK.
// Storage order:
// n_block (N_BLOCK cols) → m_block (M_STEP rows) → n_step (N_STEP cols) → row-major (M_STEP×N_STEP) tile.
// So it can be viewed as a 5D array:
// c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP]
// n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,
// n_steps = N_BLOCK / N_STEP (the tail block may be smaller).
// get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP×N_STEP) tile.
static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }
BufferCFP32Impl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
assert(max_m % M_STEP == 0);
assert(n % N_STEP == 0);
c = reinterpret_cast<float*>(ptr);
}
void set_data(void* new_ptr) { c = reinterpret_cast<float*>(new_ptr); }
void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
assert(m <= max_m);
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
__m512* x0 =
(__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
__m512* x1 =
(__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);
avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));
}
}
}
}
float* get_submat(int m, int n, int m_begin, int n_begin) {
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
n_begin -= n_block_begin;
return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;
}
};
template <typename K>
struct BufferCFP32ReduceImpl {
float* c;
float* reduce_buf;
int max_m, n;
static constexpr int M_STEP = K::M_STEP;
static constexpr int N_STEP = K::N_STEP;
static constexpr int N_BLOCK = K::N_BLOCK;
static size_t required_size(int max_m, int n) { return sizeof(float) * (size_t)max_m * n * 2; }
BufferCFP32ReduceImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {
assert(max_m % M_STEP == 0);
assert(n % N_STEP == 0);
set_data(ptr);
}
void set_data(void* ptr) {
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);
c = reinterpret_cast<float*>(ptr);
reduce_buf = c + (size_t)max_m * n;
}
void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {
assert(m <= max_m);
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int n_block_begin = n_start;
int n_block_size = n_end - n_block_begin;
for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {
for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {
for (int i = 0; i < M_STEP && m_begin + i < m; i++) {
__m512* x0 =
(__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);
__m512* x1 =
(__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);
avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));
}
}
}
}
float* get_submat(int m, int n, int m_begin, int n_begin) {
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
n_begin -= n_block_begin;
return c + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size + (size_t)n_begin * M_STEP;
}
float* get_reduce_submat(int m, int n, int m_begin, int n_begin) {
int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;
int n_block_begin = n_begin / N_BLOCK * N_BLOCK;
int n_block_size = std::min(N_BLOCK, n - n_block_begin);
n_begin -= n_block_begin;
return reduce_buf + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size +
(size_t)n_begin * M_STEP;
}
};
} // namespace amx
#endif // AMX_RAW_BUFFERS_HPP


@@ -0,0 +1,464 @@
#ifndef AMX_RAW_KERNELS_HPP
#define AMX_RAW_KERNELS_HPP
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include "amx_config.hpp"
#include "amx_raw_buffers.hpp"
#include "amx_utils.hpp"
#include "llama.cpp/ggml-impl.h"
namespace amx {
struct GemmKernel224BF16 {
using dt = ggml_bf16_t;
using output_t = float;
static constexpr double ELEMENT_SIZE = 2;
static const int TILE_M = 16;
static const int TILE_K = 32;
static const int TILE_N = 16;
static const int VNNI_BLK = 2;
static const int M_STEP = TILE_M * 2;
static const int N_STEP = TILE_N * 2;
static const int K_STEP = TILE_K;
static inline const int N_BLOCK = 256;
static inline const int K_BLOCK = 1792;
static std::string name() { return "BF16"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {
#ifdef HAVE_AMX
enable_amx();
TileConfig tile_config;
// size is 16 x 32
for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));
// size is 16 x 32
for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));
// size is 16 x 16
for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));
tile_config.set_config();
#endif
}
static void load_a(dt* a, size_t lda) {
#ifdef HAVE_AMX
_tile_loadd(0, a, lda);
_tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);
#else
(void)a;
(void)lda;
#endif
}
static void load_b(dt* b, size_t ldb) {
#ifdef HAVE_AMX
_tile_loadd(2, b, ldb);
_tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);
#else
(void)b;
(void)ldb;
#endif
}
static void clean_c() {
#ifdef HAVE_AMX
_tile_zero(4);
_tile_zero(5);
_tile_zero(6);
_tile_zero(7);
#endif
}
static void load_c(output_t* c, size_t ldc) {
#ifdef HAVE_AMX
_tile_loadd(4, c, ldc);
_tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
_tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);
_tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
#else
(void)c;
(void)ldc;
#endif
}
static void store_c(output_t* c, size_t ldc) {
#ifdef HAVE_AMX
_tile_stored(4, c, ldc);
_tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);
_tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);
_tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);
#else
(void)c;
(void)ldc;
#endif
}
static void run_tile() {
#ifdef HAVE_AMX
_tile_dpbf16ps(4, 0, 2);
_tile_dpbf16ps(5, 0, 3);
_tile_dpbf16ps(6, 1, 2);
_tile_dpbf16ps(7, 1, 3);
#endif
}
using BufferA = BufferABF16Impl<GemmKernel224BF16>;
using BufferB = BufferBBF16Impl<GemmKernel224BF16>;
using BufferC = BufferCFP32Impl<GemmKernel224BF16>;
};
// FP8 (e4m3) AMX kernel that mirrors the GemmKernel224BF16 interface.
struct GemmKernel224FP8 {
using fp8_t = uint8_t;
using output_t = float;
static constexpr double ELEMENT_SIZE = 1.0;
static const int TILE_M = 16;
static const int TILE_K = 32;
static const int TILE_N = 16;
static const int VNNI_BLK = 2;
static const int M_STEP = TILE_M * 2;
static const int N_STEP = TILE_N * 2;
static const int K_STEP = TILE_K;
static inline const int BLOCK_SIZE = 128; // 128 x 128 block quantization
static inline const int N_BLOCK = 128;
static inline const int K_BLOCK = 7168;
static std::string name() { return "FP8"; }
static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }
static std::pair<int, int> split_range_n(int n, int ith, int nth) {
int n_start = N_BLOCK * ith;
int n_end = std::min(n, N_BLOCK * (ith + 1));
return {n_start, n_end};
}
static void config() {}
private:
alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = {
0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c,
0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d,
0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e,
0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,
};
alignas(64) static constexpr uint8_t bf16_hi_1_val[64] = {
0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42,
0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43,
};
alignas(64) static constexpr uint8_t bf16_lo_0_val[64] = {
0x00, 0x00, 0x80, 0xc0, 0x00, 0x20, 0x40, 0x60, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
};
alignas(64) static constexpr uint8_t bf16_lo_1_val[64] = {
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,
};
// _mm512_set1_epi8 is not constexpr, so the sign mask is materialized on
// demand in sign_mask() instead of being stored as a constexpr table.
static inline __m512i bf16_hi_0_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_0_val); }
static inline __m512i bf16_hi_1_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_1_val); }
static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); }
static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); }
static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); }
public:
using BufferA = BufferABF16Impl<GemmKernel224FP8>;
using BufferB = BufferBFP8Impl<GemmKernel224FP8>;
using BufferC = BufferCFP32ReduceImpl<GemmKernel224FP8>;
static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {
// fp8->bf16
__m512i b_hi = _mm512_permutex2var_epi8(bf16_hi_0_mask(), bfp8_512, bf16_hi_1_mask());
__m512i b_lo = _mm512_permutex2var_epi8(bf16_lo_0_mask(), bfp8_512, bf16_lo_1_mask());
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask(), bfp8_512), b_hi);
__m512i bbf16_0 = _mm512_unpacklo_epi8(b_lo, b_hi);
__m512i bbf16_1 = _mm512_unpackhi_epi8(b_lo, b_hi);
return {bbf16_0, bbf16_1};
}
// Optimized AVX kernel: process entire k_group_size
// Load all data first, then convert all, then compute all
// This gives compiler more freedom to schedule instructions
static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,
BufferB* bb, int k_group_size) {
const __m512i bf16_hi_0_val = bf16_hi_0_mask();
const __m512i bf16_hi_1_val = bf16_hi_1_mask();
const __m512i bf16_lo_0_val = bf16_lo_0_mask();
const __m512i bf16_lo_1_val = bf16_lo_1_mask();
const __m512i sign_mask_val = sign_mask();
__m512* c512 = (__m512*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
// Zero out accumulator at the start
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_ps();
c512[m_i * 2 + 1] = _mm512_setzero_ps();
}
// Process entire k_group_size
for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {
ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);
__m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);
for (int m_i = 0; m_i < m_block_end; m_i++) {
// Process 2 k_i per iteration
for (int k_i = 0; k_i < 16; k_i += 2) {
// Load A vectors
__m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
__m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
// Load B matrices
__m512i bfp8_0 = bfp8_512[k_i];
__m512i bfp8_1 = bfp8_512[k_i + 1];
// Convert FP8 -> BF16 for all
__m512i b_hi_0 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_0, bf16_hi_1_val);
__m512i b_lo_0 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_0, bf16_lo_1_val);
b_hi_0 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_0), b_hi_0);
__m512i b_hi_1 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_1, bf16_hi_1_val);
__m512i b_lo_1 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_1, bf16_lo_1_val);
b_hi_1 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_1), b_hi_1);
// Compute dpbf16 for all
__m512bh bbf16_0_0 = (__m512bh)_mm512_unpacklo_epi8(b_lo_0, b_hi_0);
__m512bh bbf16_1_0 = (__m512bh)_mm512_unpackhi_epi8(b_lo_0, b_hi_0);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_0);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_1_0);
__m512bh bbf16_0_1 = (__m512bh)_mm512_unpacklo_epi8(b_lo_1, b_hi_1);
__m512bh bbf16_1_1 = (__m512bh)_mm512_unpackhi_epi8(b_lo_1, b_hi_1);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_0_1);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_1);
}
}
}
}
// Optimized AVX kernel: process 4 k_i at once, convert B once and reuse for all m rows
// This version achieved ~493 GB/s - restoring as baseline for further optimization
static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,
BufferB* bb, int k_group_size) {
const __m512i bf16_hi_0 = bf16_hi_0_mask();
const __m512i bf16_hi_1 = bf16_hi_1_mask();
const __m512i bf16_lo_0 = bf16_lo_0_mask();
const __m512i bf16_lo_1 = bf16_lo_1_mask();
const __m512i sign_mask_v = sign_mask();
__m512* c512 = (__m512*)c;
int m_block_end = std::min(m - m_begin, M_STEP);
// Zero out accumulator
for (int m_i = 0; m_i < m_block_end; m_i++) {
c512[m_i * 2] = _mm512_setzero_ps();
c512[m_i * 2 + 1] = _mm512_setzero_ps();
}
// Process entire k_group_size
for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {
ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);
__m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);
// Process 4 k_i at once - convert B and reuse across all m rows
for (int k_i = 0; k_i < 16; k_i += 4) {
// Load 4 B vectors
__m512i bfp8_0 = bfp8_512[k_i];
__m512i bfp8_1 = bfp8_512[k_i + 1];
__m512i bfp8_2 = bfp8_512[k_i + 2];
__m512i bfp8_3 = bfp8_512[k_i + 3];
// Convert all 4 FP8 -> BF16
__m512i b_hi, b_lo;
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1);
__m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1);
__m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1);
__m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3),
_mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1));
b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1);
__m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);
__m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);
// Process m rows - unroll by 2 for better ILP
int m_i = 0;
for (; m_i + 1 < m_block_end; m_i += 2) {
// Load A values for 2 rows
__m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
__m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
__m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
__m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);
__m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]);
__m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]);
__m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]);
__m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]);
// Process row 0, then row 1 - sequential to avoid dependencies
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi);
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo);
c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi);
}
// Handle remaining row
for (; m_i < m_block_end; m_i++) {
__m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);
__m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);
__m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);
__m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi);
c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo);
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi);
}
}
}
}
static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_block_begin, float* c, float* reduce_c,
BufferA* ba, BufferB* bb, int k, int k_group_size) {
using K = GemmKernel224FP8;
int to = std::min(m - m_begin, K::M_STEP);
for (int i = 0; i < to; i++) {
// Get scale for this k_group
__m512 bs = _mm512_set1_ps(*bb->get_scale(n, n_begin, k, k_block_begin));
__m512 now = _mm512_load_ps(reduce_c + i * K::N_STEP);
__m512 result = _mm512_mul_ps(now, bs);
__m512 existing = _mm512_load_ps(c + i * K::N_STEP);
result = _mm512_add_ps(result, existing);
_mm512_store_ps(c + i * K::N_STEP, result);
now = _mm512_load_ps(reduce_c + i * K::N_STEP + K::TILE_N);
result = _mm512_mul_ps(now, bs);
existing = _mm512_load_ps(c + i * K::N_STEP + K::TILE_N);
result = _mm512_add_ps(result, existing);
_mm512_store_ps(c + i * K::N_STEP + K::TILE_N, result);
}
}
};
// Assumes all step sizes are 32 (M_STEP = N_STEP = K_STEP = 32).
template <typename K, bool amx_or_avx = false>
void float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,
typename K::BufferC* bc, int ith, int nth) {
assert(n % K::N_STEP == 0);
assert(k % k_group_size == 0);
assert(k_group_size % K::K_STEP == 0);
auto [n_start, n_end] = K::split_range_n(n, ith, nth);
// Process by k_groups
for (int k_group_begin = 0; k_group_begin < k; k_group_begin += k_group_size) {
for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {
for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {
float* c = bc->get_submat(m, n, m_begin, n_begin);
float* reduce_c = bc->get_reduce_submat(m, n, m_begin, n_begin);
if (k_group_begin == 0) {
for (int i = 0; i < K::M_STEP && m_begin + i < m; i++) {
for (int j = 0; j < K::N_STEP; j++) {
c[i * K::N_STEP + j] = 0.0f;
}
}
}
// avx_kernel_4 now processes entire k_group_size internally (like INT8's avx_kernel)
if constexpr (amx_or_avx && AMX_AVAILABLE) {
for (int k_begin = k_group_begin; k_begin < std::min(k, k_group_begin + k_group_size); k_begin += K::K_STEP) {
K::amx_kernel(m, n, k, m_begin, n_begin, k_begin, reduce_c, ba, bb, k_group_size);
}
} else {
// Single call processes entire k_group
K::avx_kernel_4(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size);
}
K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, reduce_c, ba, bb, k, k_group_size);
}
}
}
}
// inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
// std::shared_ptr<GemmKernel224BF16::BufferB> bb,
// std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
// float_mat_mul_kgroup<GemmKernel224BF16, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
// }
// inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224BF16::BufferA> ba,
// std::shared_ptr<GemmKernel224BF16::BufferB> bb,
// std::shared_ptr<GemmKernel224BF16::BufferC> bc, int ith, int nth) {
// float_mat_mul_kgroup<GemmKernel224BF16, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
// }
inline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,
std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,
int ith, int nth) {
float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
inline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,
std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,
int ith, int nth) {
float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);
}
} // namespace amx
#endif // AMX_RAW_KERNELS_HPP


@@ -11,30 +11,27 @@
#define CPUINFER_OPERATOR_AMX_MOE_H
// #define CHECK
#include <cstddef>
#include <cstdint>
#include <cstring>
// #define FORWARD_TIME_PROFILE
// #define FORWARD_TIME_REPORT
#include <cmath>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
#include "moe_base.hpp"
template <class T>
class AMX_MOE_TP {
class AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {
private:
int tp_part_idx;
using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;
using Base::config_;
using Base::tp_part_idx;
using Base::gate_bb_;
using Base::up_bb_;
using Base::down_bb_;
using Base::gate_up_ba_;
using Base::gate_bc_;
using Base::up_bc_;
using Base::down_ba_;
using Base::down_bc_;
using Base::m_local_num_;
std::filesystem::path prefix;
void* gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if
@@ -44,27 +41,6 @@ class AMX_MOE_TP {
void* down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if
// quantized)]
ggml_bf16_t* m_local_input_; // [num_experts_per_tok * max_len * hidden_size]
ggml_bf16_t* m_local_gate_output_; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_up_output_; // [num_experts_per_tok * max_len * intermediate_size]
ggml_bf16_t* m_local_down_output_; // [num_experts_per_tok * max_len * hidden_size]
std::vector<std::vector<int>> m_local_pos_; // [max_len, num_experts_per_tok]
std::vector<int> m_local_num_; // [expert_num]
std::vector<int> m_expert_id_map_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_input_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_up_output_ptr_; // [expert_num]
std::vector<ggml_bf16_t*> m_local_down_output_ptr_; // [expert_num]
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
#ifdef CHECK
char verify_bb[100000000];
char check_bb[100000000];
@@ -161,21 +137,15 @@ class AMX_MOE_TP {
#endif
public:
using input_t = ggml_bf16_t;
using output_t = float;
GeneralMOEConfig config_;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AMX_MOE_TP() = default;
AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx) {
AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) {
printf("Creating AMX_MOE_TP %d at numa %d\n", tp_part_idx, numa_node_of_cpu(sched_getcpu()));
auto& load = config.load;
auto& save = config.save;
if (load && config.path == "") {
load = false;
}
auto& load = config_.load;
auto& save = config_.save;
prefix = config.path;
prefix = prefix / ("_layer_" + std::to_string(config.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
prefix = config_.path;
prefix = prefix / ("_layer_" + std::to_string(config_.layer_idx)) / ("_numa_" + std::to_string(tp_part_idx));
if (save) {
std::cout << "Creating " << prefix << std::endl;
std::filesystem::create_directories(prefix);
@@ -188,78 +158,65 @@ class AMX_MOE_TP {
}
}
this->tp_part_idx = tp_part_idx;
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
// printf("tp part %d alloc layer %d, %f GB, on numa %d\n", tp_part_idx, config_.layer_idx,
// 1e-9 * config_.expert_num *
// (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 +
// T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)),
// numa_node_of_cpu(sched_getcpu()));
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, nullptr));
gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, nullptr));
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
gate_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void* up_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));
up_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));
down_bb_.push_back(
std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
}
for (int i = 0; i < config_.expert_num; i++) {
mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.hidden_size));
mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },
T::BufferA::required_size(config_.max_len, config_.intermediate_size));
mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },
T::BufferC::required_size(config_.max_len, config_.hidden_size));
}
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AMX_MOE_TP() {
// shared_mem_buffer_numa.dealloc(this);
~AMX_MOE_TP() = default;
// ============================================================================
// CRTP buffer creation - no group_size
// ============================================================================
size_t buffer_a_required_size_impl(size_t m, size_t k) const {
return T::BufferA::required_size(m, k);
}
size_t buffer_b_required_size_impl(size_t n, size_t k) const {
return T::BufferB::required_size(n, k);
}
size_t buffer_c_required_size_impl(size_t m, size_t n) const {
return T::BufferC::required_size(m, n);
}
std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {
return std::make_shared<typename T::BufferA>(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {
return std::make_shared<typename T::BufferB>(n, k, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {
return std::make_shared<typename T::BufferC>(m, n, data);
}
// ============================================================================
// CRTP override points - GEMM dispatch
// ============================================================================
void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = gate_up_ba_[expert_idx];
auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];
auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
} else {
amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);
}
}
void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {
int m = m_local_num_[expert_idx];
auto& ba = down_ba_[expert_idx];
auto& bb = down_bb_[expert_idx];
auto& bc = down_bc_[expert_idx];
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {
amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
} else {
amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);
}
}
void load_weights() {
auto pool = config_.pool->get_subpool(tp_part_idx);
@@ -401,434 +358,21 @@ class AMX_MOE_TP {
}
}
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
// Run tasks inline for short sequences (qlen < 10) to skip thread-pool overhead;
// otherwise dispatch them to the work-stealing pool.
#define DIRECT_OR_POOL_BY_QLEN(var, fn) \
do { \
if (qlen < 10) { \
for (int i = 0; i < (var); i++) { \
(fn)(i); \
} \
} else { \
pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \
} \
} while (0)
// Use the tiled mat_mul when the batch is large enough to keep each expert busy;
// fall back to vec_mul for small batches.
#define MATMUL_OR_VECMUL_BY_QLEN(...) \
do { \
if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) { \
amx::mat_mul(__VA_ARGS__); \
} else { \
amx::vec_mul(__VA_ARGS__); \
} \
} while (0)
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage elapsed times (in microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;  // track the largest per-expert token count
#endif
int activated_expert = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_num_[i] = 0;
}
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
// activated_expert counting is complete
size_t offset = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
DIRECT_OR_POOL_BY_QLEN(qlen, [&](int i) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
DIRECT_OR_POOL_BY_QLEN(activated_expert, [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,
gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
auto up_gate_fn = [this, nth](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
};
DIRECT_OR_POOL_BY_QLEN(nth * activated_expert, up_gate_fn);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
MATMUL_OR_VECMUL_BY_QLEN(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,
down_ba_[expert_idx], down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
qlen, nullptr,
[this, nth, output, k, expert_ids, weights](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function, along with max_local_num and qlen
printf(
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
"%d, qlen: %d\n",
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
// Per-stage elapsed times (in microseconds)
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;  // track the largest per-expert token count
#endif
int activated_expert = 0;
for (int i = 0; i < k; i++) {
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
gate_up_ba_[0]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
if (do_up) {
amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], up_bb_[expert_idx],
up_bc_[expert_idx], ith, nth);
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
amx::vec_mul(qlen, config_.intermediate_size, config_.hidden_size, gate_up_ba_[0], gate_bb_[expert_idx],
gate_bc_[expert_idx], ith, nth);
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < qlen; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
amx::vec_mul(qlen, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],
down_bc_[expert_idx], ith, nth);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
for (int i = 0; i < qlen; i++) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
}
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
// Print all stage timings once at the end of the function
printf(
"Profiling Results (numa[%d]) decode: activated_expert: %d, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
forward_total_time);
#endif
}
// forward, forward_prefill, forward_decode, warm_up are inherited from Base
};
// ============================================================================
// TP_MOE specialization for AMX_MOE_TP
// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation
// ============================================================================
template <typename K>
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {
public:
using TP_MOE_Common<AMX_MOE_TP<K>>::TP_MOE_Common;
void load_weights() {
using Base = TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>;
using Base::Base;
void load_weights() override {
auto& config = this->config;
auto& tps = this->tps;
auto& tp_count = this->tp_count;
@@ -836,7 +380,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;
if (config.gate_projs.empty() == false) {
printf("TP Load from loader\n");
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else if (config.gate_proj != nullptr) {
@@ -872,7 +415,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
}
}
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
for (auto i = 0; i < tp_count; i++) {
@@ -885,7 +427,6 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
this->weights_loaded = true;
} else if (config.path != "") {
printf("TP Load from file\n");
// pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });
DO_TPS_LOAD_WEIGHTS(pool);
this->weights_loaded = true;
} else {
@@ -893,37 +434,7 @@ class TP_MOE<AMX_MOE_TP<K>> : public TP_MOE_Common<AMX_MOE_TP<K>> {
}
}
void merge_results(int qlen, void* output, bool incremental) {
auto pool = this->config.pool;
auto merge_fn = [this, output, incremental](int token_nth) {
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto& tp_count = this->tp_count;
auto& config = this->config;
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0, x1;
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
}
}
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
}
}
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0 = *(__m512*)(merge_to + e);
__m512 x1 = *(__m512*)(merge_to + e + 16);
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
}
};
DIRECT_OR_POOL_BY_QLEN(qlen, merge_fn);
}
void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }
// merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>
};
#endif

@@ -0,0 +1,763 @@
/**
* @Description : Common AMX MoE base class extracted from K2 implementation.
* @Author : oql, Codex and Claude
* @Date : 2025-12-09
* @Version : 0.1.0
* @LastEditors : oql, Codex and Claude
* @LastEditTime : 2025-12-09
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_OPERATOR_AMX_MOE_BASE_H
#define CPUINFER_OPERATOR_AMX_MOE_BASE_H
// #define FORWARD_TIME_PROFILE
#include <immintrin.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <string>
#include <utility>
#include <vector>
#include "../../cpu_backend/shared_mem_buffer.h"
#include "../../cpu_backend/worker_pool.h"
#include "../common.hpp"
#include "../moe-tp.hpp"
#include "la/amx.hpp"
#include "llama.cpp/ggml.h"
template <class T, class Derived>
class AMX_MOE_BASE {
public:
int tp_part_idx = 0;
ggml_bf16_t* m_local_input_ = nullptr;
ggml_bf16_t* m_local_gate_output_ = nullptr;
ggml_bf16_t* m_local_up_output_ = nullptr;
ggml_bf16_t* m_local_down_output_ = nullptr;
std::vector<std::vector<int>> m_local_pos_;
std::vector<int> m_local_num_;
std::vector<int> m_expert_id_map_;
std::vector<ggml_bf16_t*> m_local_input_ptr_;
std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;
std::vector<ggml_bf16_t*> m_local_up_output_ptr_;
std::vector<ggml_bf16_t*> m_local_down_output_ptr_;
std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;
std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;
std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;
std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;
std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;
size_t pool_count_ = 0;
size_t gate_up_ba_pool_bytes_ = 0;
size_t gate_bc_pool_bytes_ = 0;
size_t up_bc_pool_bytes_ = 0;
size_t down_ba_pool_bytes_ = 0;
size_t down_bc_pool_bytes_ = 0;
void* gate_up_ba_pool_ = nullptr;
void* gate_bc_pool_ = nullptr;
void* up_bc_pool_ = nullptr;
void* down_ba_pool_ = nullptr;
void* down_bc_pool_ = nullptr;
GeneralMOEConfig config_;
using input_t = ggml_bf16_t;
using output_t = float;
static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;
AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) { init(); }
void init() {
if (config_.load && config_.path == "") {
config_.load = false;
}
MemoryRequest mem_requests;
mem_requests.append_pointer(
&m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);
mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.intermediate_size);
mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *
config_.max_len * config_.hidden_size);
m_local_pos_.resize(config_.max_len);
for (int i = 0; i < config_.max_len; i++) {
m_local_pos_[i].resize(config_.num_experts_per_tok);
}
m_expert_id_map_.resize(config_.expert_num);
m_local_num_.resize(config_.expert_num);
m_local_input_ptr_.resize(config_.expert_num);
m_local_gate_output_ptr_.resize(config_.expert_num);
m_local_up_output_ptr_.resize(config_.expert_num);
m_local_down_output_ptr_.resize(config_.expert_num);
for (size_t i = 0; i < config_.expert_num; i++) {
gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));
gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));
down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));
down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));
void* gate_bb_ptr =
std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));
void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));
up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));
void* down_bb_ptr =
std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));
down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));
}
// TODO: propagate this update to all *.hpp implementations
// The (config_.expert_num * T::M_STEP) term in pool_count_ guarantees padding for each expert.
pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;
gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
down_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;
down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;
mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);
mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);
mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);
mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);
mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);
shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);
}
~AMX_MOE_BASE() = default;
void warm_up() {
int qlen = config_.max_len;
std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);
std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);
std::vector<float> weights(qlen * config_.num_experts_per_tok);
for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {
expert_ids[i] = i % config_.expert_num;
weights[i] = 0.01;
}
forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());
}
void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
if (qlen > 1) {
forward_prefill(qlen, k, expert_ids, weights, input, output);
} else {
forward_decode(k, expert_ids, weights, input, output);
}
}
template <typename... Args>
void load_weights(Args&&... args) {
derived()->load_weights(std::forward<Args>(args)...);
}
template <typename... Args>
void write_weights_to_buffer(Args&&... args) const {
derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);
}
void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,
void* output) {
auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;
long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
int max_local_num = 0;
#endif
int activated_expert = 0;
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
for (int i = 0; i < qlen; i++) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;
}
}
for (int i = 0; i < config_.expert_num; i++) {
if (m_local_num_[i] > 0) {
#ifdef FORWARD_TIME_PROFILE
max_local_num = std::max(max_local_num, m_local_num_[i]);
#endif
m_expert_id_map_[activated_expert] = i;
activated_expert++;
}
}
size_t offset = 0;
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
void* gate_bc_pool_ptr = gate_bc_pool_;
void* up_bc_pool_ptr = up_bc_pool_;
void* down_ba_pool_ptr = down_ba_pool_;
void* down_bc_pool_ptr = down_bc_pool_;
constexpr size_t M_STEP = T::M_STEP;
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
size_t used_pool_m = 0;
size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
used_pool_bytes_bc_down = 0;
for (int i = 0; i < config_.expert_num; i++) {
m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;
offset += m_local_num_[i];
if (m_local_num_[i] == 0) {
continue;
}
size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[i]->max_m = max_m;
gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);
size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
gate_bc_[i]->max_m = max_m;
gate_bc_[i]->set_data(gate_bc_pool_ptr);
size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
up_bc_[i]->max_m = max_m;
up_bc_[i]->set_data(up_bc_pool_ptr);
size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
down_ba_[i]->max_m = max_m;
down_ba_[i]->set_data(down_ba_pool_ptr);
size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
down_bc_[i]->max_m = max_m;
down_bc_[i]->set_data(down_bc_pool_ptr);
size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
used_pool_m += max_m;
used_pool_bytes_a += ba_size;
used_pool_bytes_bc_gate += bc_gate_size;
used_pool_bytes_bc_up += bc_up_size;
used_pool_bytes_ba_down += ba_down_size;
used_pool_bytes_bc_down += bc_down_size;
}
assert(used_pool_m <= pool_count_);
assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen < 10) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
direct_or_pool(qlen, [&](int i) {
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,
(ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);
}
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
direct_or_pool(activated_expert, [this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);
});
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
if (do_up) {
up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
apply_activation(activated_expert, nth, qlen);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
qlen, nullptr,
[this, output, k, expert_ids, weights](int i) {
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[i * k + j] < config_.num_gpu_experts || expert_ids[i * k + j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[i * k + j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +
m_local_pos_[i][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);
f32out[0] = x0;
f32out[1] = x1;
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
printf(
"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: "
"%d, qlen: %d\n",
tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,
down_time, weight_time, forward_total_time, max_local_num, qlen);
#endif
}
void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {
int qlen = 1;
auto pool = config_.pool->get_subpool(tp_part_idx);
#ifdef FORWARD_TIME_PROFILE
auto start_time = std::chrono::high_resolution_clock::now();
auto last = start_time;
long q_input_time = 0, up_gate_time = 0, act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;
#endif
int activated_expert = 0;
std::fill(m_local_num_.begin(), m_local_num_.end(), 0);
for (int i = 0; i < k; i++) {
if (expert_ids[i] < config_.num_gpu_experts || expert_ids[i] >= config_.expert_num) {
continue;
}
m_expert_id_map_[activated_expert] = expert_ids[i];
m_local_pos_[0][i] = 0;
m_local_num_[expert_ids[i]] = qlen;
activated_expert++;
}
size_t offset = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;
m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;
m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;
offset += qlen;
}
void* gate_bc_pool_ptr = gate_bc_pool_;
void* up_bc_pool_ptr = up_bc_pool_;
void* down_ba_pool_ptr = down_ba_pool_;
void* down_bc_pool_ptr = down_bc_pool_;
constexpr size_t M_STEP = T::M_STEP;
auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };
size_t used_pool_m = 0;
size_t used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,
used_pool_bytes_bc_down = 0;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_bc_[expert_idx]->max_m = max_m;
gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);
size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);
up_bc_[expert_idx]->max_m = max_m;
up_bc_[expert_idx]->set_data(up_bc_pool_ptr);
size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));
up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);
down_ba_[expert_idx]->max_m = max_m;
down_ba_[expert_idx]->set_data(down_ba_pool_ptr);
size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));
down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);
down_bc_[expert_idx]->max_m = max_m;
down_bc_[expert_idx]->set_data(down_bc_pool_ptr);
size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));
down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);
used_pool_m += max_m;
used_pool_bytes_bc_gate += bc_gate_size;
used_pool_bytes_bc_up += bc_up_size;
used_pool_bytes_ba_down += ba_down_size;
used_pool_bytes_bc_down += bc_down_size;
}
assert(used_pool_m <= pool_count_);
assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);
assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);
assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);
void* gate_up_ba_pool_ptr = gate_up_ba_pool_;
for (int i = 0; i < activated_expert; i++) {
auto expert_idx = m_expert_id_map_[i];
size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;
gate_up_ba_[expert_idx]->max_m = max_m;
gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);
size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));
gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);
gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
int nth = T::recommended_nth(config_.intermediate_size);
pool->do_work_stealing_job(
nth * activated_expert * 2, [](int _) { T::config(); },
[this, nth, qlen](int task_id2) {
int task_id = task_id2 / 2;
bool do_up = task_id2 % 2;
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);
if (do_up) {
up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);
} else {
gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);
}
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
apply_activation(activated_expert, nth, qlen);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
pool->do_work_stealing_job(
activated_expert, nullptr,
[this, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id];
down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
nth = T::recommended_nth(config_.hidden_size);
pool->do_work_stealing_job(
nth * activated_expert, [](int _) { T::config(); },
[this, nth, qlen](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
derived()->do_down_gemm(expert_idx, ith, nth, qlen);
down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);
},
nullptr);
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
#endif
for (int e = 0; e < config_.hidden_size; e += 32) {
__m512 x0 = _mm512_setzero_ps();
__m512 x1 = _mm512_setzero_ps();
for (int j = 0; j < k; j++) {
if (expert_ids[j] < config_.num_gpu_experts || expert_ids[j] >= config_.expert_num) {
continue;
}
__m512 weight = _mm512_set1_ps(weights[j]);
__m512 down_output0, down_output1;
avx512_32xbf16_to_32xfp32(
(__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e),
&down_output0, &down_output1);
x0 = _mm512_fmadd_ps(down_output0, weight, x0);
x1 = _mm512_fmadd_ps(down_output1, weight, x1);
}
auto f32out = (__m512*)((float*)output + e);
f32out[0] = x0;
f32out[1] = x1;
}
#ifdef FORWARD_TIME_PROFILE
{
auto now_time = std::chrono::high_resolution_clock::now();
weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();
last = now_time;
}
auto end_time = std::chrono::high_resolution_clock::now();
auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();
printf(
"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, "
"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\n",
tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,
forward_total_time);
#endif
}
protected:
Derived* derived() { return static_cast<Derived*>(this); }
const Derived* derived_const() const { return static_cast<const Derived*>(this); }
// ============================================================================
// Customization points for buffer creation and size calculation.
// Default implementations use group_size (for K-group quantization, e.g. K2).
// Derived classes (e.g. those in moe.hpp) can override them to ignore group_size.
// ============================================================================
size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }
size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }
size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }
std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {
return derived_const()->make_buffer_a_impl(m, k, data);
}
std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {
return derived_const()->make_buffer_b_impl(n, k, data);
}
std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {
return derived_const()->make_buffer_c_impl(m, n, data);
}
void apply_activation(int activated_expert, int nth, int qlen) {
auto pool = config_.pool->get_subpool(tp_part_idx);
auto fn = [this, nth](int task_id) {
int expert_idx = m_expert_id_map_[task_id / nth];
int ith = task_id % nth;
auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);
for (int i = 0; i < m_local_num_[expert_idx]; i++) {
ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];
ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];
for (int j = n_start; j < n_end; j += 32) {
__m512 gate_val0, gate_val1, up_val0, up_val1;
avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);
avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);
__m512 result0 = amx::act_fn(gate_val0, up_val0);
__m512 result1 = amx::act_fn(gate_val1, up_val1);
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));
}
}
};
if (activated_expert == 0) {
return;
}
if (qlen < 10) {
for (int task_id = 0; task_id < nth * activated_expert; task_id++) {
fn(task_id);
}
} else {
pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);
}
}
};
// ============================================================================
// TP_MOE specialization for AMX_MOE_BASE derived classes
// ============================================================================
template <class T, class Derived>
class TP_MOE<AMX_MOE_BASE<T, Derived>> : public TP_MOE_Common<AMX_MOE_BASE<T, Derived>> {
public:
using TP_MOE_Common<AMX_MOE_BASE<T, Derived>>::TP_MOE_Common;
// Default load_weights just throws; derived TP_MOE classes must override it
void load_weights() override { throw std::runtime_error("Not Implemented"); }
void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,
const std::vector<uintptr_t>& w13_weight_ptrs,
const std::vector<uintptr_t>& w13_scale_ptrs,
const std::vector<uintptr_t>& w2_weight_ptrs,
const std::vector<uintptr_t>& w2_scale_ptrs) {
throw std::runtime_error("Not Implemented");
}
void merge_results(int qlen, void* output, bool incremental) override {
auto& config = this->config;
auto& tp_count = this->tp_count;
auto& local_output_numa = this->local_output_numa;
auto& tp_configs = this->tp_configs;
auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {
float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;
if (incremental) {
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0, x1;
avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);
*((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);
}
}
for (int i = 1; i < tp_count; i++) {
float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;
for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {
*((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));
}
}
for (int e = 0; e < config.hidden_size; e += 32) {
__m512 x0 = *(__m512*)(merge_to + e);
__m512 x1 = *(__m512*)(merge_to + e + 16);
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));
}
};
auto pool = config.pool;
auto direct_or_pool = [&](int count, auto&& fn) {
if (qlen < 10) {
for (int i = 0; i < count; i++) {
fn(i);
}
} else {
pool->do_work_stealing_job(count, nullptr, fn, nullptr);
}
};
direct_or_pool(qlen, merge_fn);
}
void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }
};
#endif // CPUINFER_OPERATOR_AMX_MOE_BASE_H
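The buffer-pool carving in `forward_prefill` above bump-allocates each expert's BufferA/BufferC slices out of pre-reserved pools, rounding every slice up to a 64-byte boundary via the `align64` lambda. A minimal Python sketch of that allocation pattern (hypothetical helper names; the real code advances raw `void*` cursors):

```python
def align64(v: int) -> int:
    # Mirrors the C++ lambda: (v + 63) & ~(size_t)63 -- round up to 64 bytes.
    return (v + 63) & ~63


def carve_pool(pool_bytes: int, slice_sizes: list[int]) -> list[int]:
    """Bump-allocate one offset per requested slice from a single pool.

    The final bound check plays the role of the
    assert(used_pool_bytes_* <= *_pool_bytes_) checks in forward_prefill.
    """
    offsets, cursor = [], 0
    for size in slice_sizes:
        offsets.append(cursor)
        cursor += align64(size)
    assert cursor <= pool_bytes, "pool overrun"
    return offsets
```

Because every slice is rounded to 64 bytes, buffers for consecutive experts never share a cache line, which matters when the work-stealing threads write adjacent experts' buffers concurrently.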

View File

@@ -27,6 +27,12 @@ dependencies = [
"numpy>=1.24.0",
"triton>=2.0.0",
"gguf>=0.17.0",
# CLI dependencies
"typer[all]>=0.9.0",
"rich>=13.0.0",
"pyyaml>=6.0",
"httpx>=0.25.0",
"packaging>=23.0",
# Development dependencies
"black>=25.9.0",
]
@@ -37,19 +43,35 @@ test = [
"psutil>=5.9.0",
]
[project.scripts]
kt = "kt_kernel.cli.main:main"
[project.urls]
Homepage = "https://github.com/kvcache-ai"
[tool.setuptools]
packages = ["kt_kernel", "kt_kernel.utils"]
packages = [
"kt_kernel",
"kt_kernel.utils",
"kt_kernel.cli",
"kt_kernel.cli.commands",
"kt_kernel.cli.config",
"kt_kernel.cli.utils",
"kt_kernel.cli.completions",
]
include-package-data = true
[tool.setuptools.package-dir]
kt_kernel = "python"
"kt_kernel.utils" = "python/utils"
"kt_kernel.cli" = "python/cli"
"kt_kernel.cli.commands" = "python/cli/commands"
"kt_kernel.cli.config" = "python/cli/config"
"kt_kernel.cli.utils" = "python/cli/utils"
"kt_kernel.cli.completions" = "python/cli/completions"
[tool.setuptools.package-data]
# (empty) placeholder if you later add resources
"kt_kernel.cli.completions" = ["*.bash", "*.fish", "_kt"]
[tool.setuptools.exclude-package-data]
# (empty)

View File

@@ -37,11 +37,13 @@ from __future__ import annotations
# Detect CPU and load optimal extension variant
from ._cpu_detect import initialize as _initialize_cpu
_kt_kernel_ext, __cpu_variant__ = _initialize_cpu()
# Make the extension module available to other modules in this package
import sys
sys.modules['kt_kernel_ext'] = _kt_kernel_ext
sys.modules["kt_kernel_ext"] = _kt_kernel_ext
# Also expose kt_kernel_ext as an attribute for backward compatibility
kt_kernel_ext = _kt_kernel_ext
@@ -53,25 +55,28 @@ from .experts import KTMoEWrapper
try:
# Try to get version from installed package metadata (works in installed environment)
from importlib.metadata import version, PackageNotFoundError
try:
__version__ = version('kt-kernel')
__version__ = version("kt-kernel")
except PackageNotFoundError:
# Package not installed, try to read from source tree version.py
import os
_root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'version.py')
_root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "version.py")
if os.path.exists(_root_version_file):
_version_ns = {}
with open(_root_version_file, 'r', encoding='utf-8') as f:
with open(_root_version_file, "r", encoding="utf-8") as f:
exec(f.read(), _version_ns)
__version__ = _version_ns.get('__version__', '0.4.3')
__version__ = _version_ns.get("__version__", "0.4.3")
else:
__version__ = "0.4.3"
except ImportError:
# Python < 3.8, fallback to pkg_resources or hardcoded version
try:
from pkg_resources import get_distribution, DistributionNotFound
try:
__version__ = get_distribution('kt-kernel').version
__version__ = get_distribution("kt-kernel").version
except DistributionNotFound:
__version__ = "0.4.3"
except ImportError:

View File

@@ -17,6 +17,7 @@ Example:
>>> os.environ['KT_KERNEL_CPU_VARIANT'] = 'avx2'
>>> import kt_kernel # Will use AVX2 variant
"""
import os
import sys
from pathlib import Path
@@ -35,82 +36,82 @@ def detect_cpu_features():
str: 'amx', 'avx512', or 'avx2'
"""
# Check environment override
variant = os.environ.get('KT_KERNEL_CPU_VARIANT', '').lower()
if variant in ['amx', 'avx512', 'avx2']:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
variant = os.environ.get("KT_KERNEL_CPU_VARIANT", "").lower()
if variant in ["amx", "avx512", "avx2"]:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Using environment override: {variant}")
return variant
# Try to read /proc/cpuinfo on Linux
try:
with open('/proc/cpuinfo', 'r') as f:
with open("/proc/cpuinfo", "r") as f:
cpuinfo = f.read().lower()
# Check for AMX support (Intel Sapphire Rapids+)
# AMX requires amx_tile, amx_int8, and amx_bf16
amx_flags = ['amx_tile', 'amx_int8', 'amx_bf16']
amx_flags = ["amx_tile", "amx_int8", "amx_bf16"]
has_amx = all(flag in cpuinfo for flag in amx_flags)
if has_amx:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via /proc/cpuinfo")
return 'amx'
return "amx"
# Check for AVX512 support
# AVX512F is the foundation for all AVX512 variants
if 'avx512f' in cpuinfo:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if "avx512f" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via /proc/cpuinfo")
return 'avx512'
return "avx512"
# Check for AVX2 support
if 'avx2' in cpuinfo:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if "avx2" in cpuinfo:
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX2 support via /proc/cpuinfo")
return 'avx2'
return "avx2"
# Fallback to AVX2 (should be rare on modern CPUs)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] No AVX2/AVX512/AMX detected, using AVX2 fallback")
return 'avx2'
return "avx2"
except FileNotFoundError:
# /proc/cpuinfo doesn't exist (not Linux or in container)
# Try cpufeature package as fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] /proc/cpuinfo not found, trying cpufeature package")
try:
import cpufeature
# Check for AMX
if cpufeature.CPUFeature.get('AMX_TILE', False):
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if cpufeature.CPUFeature.get("AMX_TILE", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AMX support via cpufeature")
return 'amx'
return "amx"
# Check for AVX512
if cpufeature.CPUFeature.get('AVX512F', False):
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if cpufeature.CPUFeature.get("AVX512F", False):
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Detected AVX512 support via cpufeature")
return 'avx512'
return "avx512"
# Fallback to AVX2
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Using AVX2 fallback via cpufeature")
return 'avx2'
return "avx2"
except ImportError:
# cpufeature not available - ultimate fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] cpufeature not available, using AVX2 fallback")
return 'avx2'
return "avx2"
except Exception as e:
# Any other error - safe fallback
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Error during CPU detection: {e}, using AVX2 fallback")
return 'avx2'
return "avx2"
def load_extension(variant):
@@ -148,51 +149,53 @@ def load_extension(variant):
kt_kernel_dir = os.path.dirname(os.path.abspath(__file__))
# Try multi-variant naming first
pattern = os.path.join(kt_kernel_dir, f'_kt_kernel_ext_{variant}.*.so')
pattern = os.path.join(kt_kernel_dir, f"_kt_kernel_ext_{variant}.*.so")
so_files = glob.glob(pattern)
if not so_files:
# Try single-variant naming (fallback for builds without CPUINFER_BUILD_ALL_VARIANTS)
pattern = os.path.join(kt_kernel_dir, 'kt_kernel_ext.*.so')
pattern = os.path.join(kt_kernel_dir, "kt_kernel_ext.*.so")
so_files = glob.glob(pattern)
if so_files:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Multi-variant {variant} not found, using single-variant build")
else:
raise ImportError(f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)")
raise ImportError(
f"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)"
)
so_file = so_files[0]
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Loading {variant} from: {so_file}")
# Load the module manually
# The module exports PyInit_kt_kernel_ext, so we use that as the module name
spec = importlib.util.spec_from_file_location('kt_kernel_ext', so_file)
spec = importlib.util.spec_from_file_location("kt_kernel_ext", so_file)
if spec is None or spec.loader is None:
raise ImportError(f"Failed to create spec for {so_file}")
ext = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ext)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Successfully loaded {variant.upper()} variant")
return ext
except (ImportError, ModuleNotFoundError, FileNotFoundError) as e:
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Failed to load {variant} variant: {e}")
# Automatic fallback to next best variant
if variant == 'amx':
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if variant == "amx":
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AMX to AVX512")
return load_extension('avx512')
elif variant == 'avx512':
if os.environ.get('KT_KERNEL_DEBUG') == '1':
return load_extension("avx512")
elif variant == "avx512":
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print("[kt-kernel] Falling back from AVX512 to AVX2")
return load_extension('avx2')
return load_extension("avx2")
else:
# AVX2 is the last fallback - if this fails, we can't continue
raise ImportError(
@@ -221,13 +224,13 @@ def initialize():
# Detect CPU features
variant = detect_cpu_features()
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Selected CPU variant: {variant}")
# Load the appropriate extension
ext = load_extension(variant)
if os.environ.get('KT_KERNEL_DEBUG') == '1':
if os.environ.get("KT_KERNEL_DEBUG") == "1":
print(f"[kt-kernel] Extension module loaded: {ext.__name__}")
return ext, variant
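The detection order implemented in `_cpu_detect.py` — AMX only when all three `amx_tile`/`amx_int8`/`amx_bf16` flags are present, then AVX512F, then AVX2 as the last-resort fallback — can be exercised standalone. This sketch assumes the `/proc/cpuinfo` text has already been read into a string:

```python
def pick_variant(cpuinfo: str) -> str:
    """Return 'amx', 'avx512', or 'avx2' for a /proc/cpuinfo flags string."""
    cpuinfo = cpuinfo.lower()
    # AMX needs all three tile/int8/bf16 flags (Sapphire Rapids and newer).
    if all(flag in cpuinfo for flag in ("amx_tile", "amx_int8", "amx_bf16")):
        return "amx"
    # AVX512F is the foundation flag for every AVX512 variant.
    if "avx512f" in cpuinfo:
        return "avx512"
    # AVX2 is the unconditional fallback, as in detect_cpu_features().
    return "avx2"
```

Note that a machine advertising only part of the AMX flag set falls through to the AVX512/AVX2 checks rather than selecting the AMX kernels.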

View File

@@ -0,0 +1,8 @@
"""
KTransformers CLI - A unified command-line interface for KTransformers.
This CLI provides a user-friendly interface to all KTransformers functionality,
including model inference, fine-tuning, benchmarking, and more.
"""
__version__ = "0.1.0"

View File

@@ -0,0 +1,3 @@
"""
Command modules for kt-cli.
"""

View File

@@ -0,0 +1,274 @@
"""
Bench commands for kt-cli.
Runs benchmarks for performance testing.
"""
import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional
import typer
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
console,
print_error,
print_info,
print_step,
print_success,
)
class BenchType(str, Enum):
"""Benchmark type."""
INFERENCE = "inference"
MLA = "mla"
MOE = "moe"
LINEAR = "linear"
ATTENTION = "attention"
ALL = "all"
def bench(
type: BenchType = typer.Option(
BenchType.ALL,
"--type",
"-t",
help="Benchmark type",
),
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model to benchmark",
),
output: Optional[Path] = typer.Option(
None,
"--output",
"-o",
help="Output file for results (JSON)",
),
iterations: int = typer.Option(
10,
"--iterations",
"-n",
help="Number of iterations",
),
) -> None:
"""Run full benchmark suite."""
console.print()
print_step(t("bench_starting"))
print_info(t("bench_type", type=type.value))
console.print()
if type == BenchType.ALL:
_run_all_benchmarks(model, output, iterations)
elif type == BenchType.INFERENCE:
_run_inference_benchmark(model, output, iterations)
elif type == BenchType.MLA:
_run_component_benchmark("mla", output, iterations)
elif type == BenchType.MOE:
_run_component_benchmark("moe", output, iterations)
elif type == BenchType.LINEAR:
_run_component_benchmark("linear", output, iterations)
elif type == BenchType.ATTENTION:
_run_component_benchmark("attention", output, iterations)
console.print()
print_success(t("bench_complete"))
if output:
console.print(f" Results saved to: {output}")
console.print()
def microbench(
component: str = typer.Argument(
"moe",
help="Component to benchmark (moe, mla, linear, attention)",
),
batch_size: int = typer.Option(
1,
"--batch-size",
"-b",
help="Batch size",
),
seq_len: int = typer.Option(
1,
"--seq-len",
"-s",
help="Sequence length",
),
iterations: int = typer.Option(
100,
"--iterations",
"-n",
help="Number of iterations",
),
warmup: int = typer.Option(
10,
"--warmup",
"-w",
help="Warmup iterations",
),
output: Optional[Path] = typer.Option(
None,
"--output",
"-o",
help="Output file for results (JSON)",
),
) -> None:
"""Run micro-benchmark for specific components."""
console.print()
console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
console.print()
raise typer.Exit(0)
# NOTE: everything below is unreachable until the early-exit placeholder above is removed.
# Try to find the benchmark script
kt_kernel_path = _find_kt_kernel_path()
if kt_kernel_path is None:
print_error("kt-kernel not found. Install with: kt install inference")
raise typer.Exit(1)
bench_dir = kt_kernel_path / "bench"
# Map component to script
component_scripts = {
"moe": "bench_moe.py",
"mla": "bench_mla.py",
"linear": "bench_linear.py",
"attention": "bench_attention.py",
"mlp": "bench_mlp.py",
}
script_name = component_scripts.get(component.lower())
if script_name is None:
print_error(f"Unknown component: {component}")
console.print(f"Available: {', '.join(component_scripts.keys())}")
raise typer.Exit(1)
script_path = bench_dir / script_name
if not script_path.exists():
print_error(f"Benchmark script not found: {script_path}")
raise typer.Exit(1)
# Run benchmark
cmd = [
sys.executable,
str(script_path),
"--batch-size",
str(batch_size),
"--seq-len",
str(seq_len),
"--iterations",
str(iterations),
"--warmup",
str(warmup),
]
if output:
cmd.extend(["--output", str(output)])
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
console.print()
try:
process = subprocess.run(cmd)
if process.returncode == 0:
console.print()
print_success(t("bench_complete"))
if output:
console.print(f" Results saved to: {output}")
else:
print_error(f"Benchmark failed with exit code {process.returncode}")
raise typer.Exit(process.returncode)
except FileNotFoundError as e:
print_error(f"Failed to run benchmark: {e}")
raise typer.Exit(1)
def _find_kt_kernel_path() -> Optional[Path]:
"""Find the kt-kernel installation path."""
try:
import kt_kernel
return Path(kt_kernel.__file__).parent.parent
except ImportError:
pass
# Check common locations
possible_paths = [
Path.home() / "Projects" / "ktransformers" / "kt-kernel",
Path("/opt/ktransformers/kt-kernel"),
Path.cwd() / "kt-kernel",
]
for path in possible_paths:
if path.exists() and (path / "bench").exists():
return path
return None
def _run_all_benchmarks(model: Optional[str], output: Optional[Path], iterations: int) -> None:
"""Run all benchmarks."""
components = ["moe", "mla", "linear", "attention"]
for component in components:
console.print(f"\n[bold]Running {component} benchmark...[/bold]")
_run_component_benchmark(component, None, iterations)
def _run_inference_benchmark(model: Optional[str], output: Optional[Path], iterations: int) -> None:
"""Run inference benchmark."""
if model is None:
print_error("Model required for inference benchmark. Use --model flag.")
raise typer.Exit(1)
print_info(f"Running inference benchmark on {model}...")
console.print()
console.print("[dim]This will start the server and run test requests.[/dim]")
console.print()
# TODO: Implement actual inference benchmarking
print_error("Inference benchmarking not yet implemented.")
def _run_component_benchmark(component: str, output: Optional[Path], iterations: int) -> None:
"""Run a component benchmark."""
kt_kernel_path = _find_kt_kernel_path()
if kt_kernel_path is None:
print_error("kt-kernel not found.")
return
bench_dir = kt_kernel_path / "bench"
script_map = {
"moe": "bench_moe.py",
"mla": "bench_mla.py",
"linear": "bench_linear.py",
"attention": "bench_attention.py",
}
script_name = script_map.get(component)
if script_name is None:
print_error(f"Unknown component: {component}")
return
script_path = bench_dir / script_name
if not script_path.exists():
print_error(f"Script not found: {script_path}")
return
cmd = [sys.executable, str(script_path), "--iterations", str(iterations)]
try:
subprocess.run(cmd)
except Exception as e:
print_error(f"Benchmark failed: {e}")
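The component-to-script dispatch used by `microbench` and `_run_component_benchmark` above can be exercised standalone. A minimal sketch (the helper name `build_bench_cmd` is illustrative, not part of the CLI):

```python
import sys
from pathlib import Path

# Mirrors the component -> script mapping used by the bench commands.
SCRIPTS = {
    "moe": "bench_moe.py",
    "mla": "bench_mla.py",
    "linear": "bench_linear.py",
    "attention": "bench_attention.py",
}


def build_bench_cmd(bench_dir: Path, component: str, iterations: int) -> list:
    """Return the argv for a component benchmark; raise ValueError on unknown names."""
    script = SCRIPTS.get(component.lower())
    if script is None:
        raise ValueError(f"Unknown component: {component}")
    return [sys.executable, str(bench_dir / script), "--iterations", str(iterations)]
```

Validating the component name before spawning the subprocess keeps the error message actionable instead of surfacing a `FileNotFoundError` from the missing script.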

View File

@@ -0,0 +1,437 @@
"""
Chat command for kt-cli.
Provides interactive chat interface with running model server.
"""
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import typer
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
console,
print_error,
print_info,
print_success,
print_warning,
)
# Try to import OpenAI SDK
try:
from openai import OpenAI
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
def chat(
host: Optional[str] = typer.Option(
None,
"--host",
"-H",
help="Server host address",
),
port: Optional[int] = typer.Option(
None,
"--port",
"-p",
help="Server port",
),
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model name (if server hosts multiple models)",
),
temperature: float = typer.Option(
0.7,
"--temperature",
"-t",
help="Sampling temperature (0.0 to 2.0)",
),
max_tokens: int = typer.Option(
2048,
"--max-tokens",
help="Maximum tokens to generate",
),
system_prompt: Optional[str] = typer.Option(
None,
"--system",
"-s",
help="System prompt",
),
save_history: bool = typer.Option(
True,
"--save-history/--no-save-history",
help="Save conversation history",
),
history_file: Optional[Path] = typer.Option(
None,
"--history-file",
help="Path to save conversation history",
),
stream: bool = typer.Option(
True,
"--stream/--no-stream",
help="Enable streaming output",
),
) -> None:
"""Start interactive chat with a running model server.
Examples:
kt chat # Connect to default server
kt chat --host 127.0.0.1 -p 8080 # Connect to specific server
kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters
"""
if not HAS_OPENAI:
print_error("OpenAI Python SDK is required for chat functionality.")
console.print()
console.print("Install it with:")
console.print(" pip install openai")
raise typer.Exit(1)
settings = get_settings()
# Resolve server connection
final_host = host or settings.get("server.host", "127.0.0.1")
final_port = port or settings.get("server.port", 30000)
# Construct base URL for OpenAI-compatible API
base_url = f"http://{final_host}:{final_port}/v1"
console.print()
console.print(
Panel.fit(
f"[bold cyan]KTransformers Chat[/bold cyan]\n\n"
f"Server: [yellow]{final_host}:{final_port}[/yellow]\n"
f"Temperature: [cyan]{temperature}[/cyan] | Max tokens: [cyan]{max_tokens}[/cyan]\n\n"
f"[dim]Type '/help' for commands, '/quit' to exit[/dim]",
border_style="cyan",
)
)
console.print()
# Check for proxy environment variables
proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy", "ALL_PROXY", "all_proxy"]
detected_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)}
if detected_proxies:
proxy_info = ", ".join(f"{k}={v}" for k, v in detected_proxies.items())
console.print()
print_warning(t("chat_proxy_detected"))
console.print(f" [dim]{proxy_info}[/dim]")
console.print()
use_proxy = Confirm.ask(t("chat_proxy_confirm"), default=False)
if not use_proxy:
# Temporarily disable proxy for this connection
for var in proxy_vars:
if var in os.environ:
del os.environ[var]
print_info(t("chat_proxy_disabled"))
console.print()
# Initialize OpenAI client
try:
client = OpenAI(
base_url=base_url,
api_key="EMPTY", # SGLang doesn't require API key
)
# Test connection
print_info("Connecting to server...")
models = client.models.list()
available_models = [m.id for m in models.data]
if not available_models:
print_error("No models available on server")
raise typer.Exit(1)
# Select model
if model:
if model not in available_models:
print_warning(f"Model '{model}' not found. Available models: {', '.join(available_models)}")
selected_model = available_models[0]
else:
selected_model = model
else:
selected_model = available_models[0]
print_success(f"Connected to model: {selected_model}")
console.print()
except Exception as e:
print_error(f"Failed to connect to server: {e}")
console.print()
console.print("Make sure the model server is running:")
console.print(" kt run <model>")
raise typer.Exit(1)
# Initialize conversation history
messages = []
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Setup history file
if save_history:
if history_file is None:
history_dir = settings.config_dir / "chat_history"
history_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
history_file = history_dir / f"chat_{timestamp}.json"
else:
history_file = Path(history_file)
history_file.parent.mkdir(parents=True, exist_ok=True)
# Main chat loop
try:
while True:
# Get user input
try:
user_input = Prompt.ask("[bold green]You[/bold green]")
except (EOFError, KeyboardInterrupt):
console.print()
print_info("Goodbye!")
break
if not user_input.strip():
continue
# Handle special commands
if user_input.startswith("/"):
if _handle_command(user_input, messages, temperature, max_tokens):
continue
else:
break # Exit command
# Add user message to history
messages.append({"role": "user", "content": user_input})
# Generate response
console.print()
console.print("[bold cyan]Assistant[/bold cyan]")
try:
if stream:
# Streaming response
response_content = _stream_response(client, selected_model, messages, temperature, max_tokens)
else:
# Non-streaming response
response_content = _generate_response(client, selected_model, messages, temperature, max_tokens)
# Add assistant response to history
messages.append({"role": "assistant", "content": response_content})
console.print()
except Exception as e:
print_error(f"Error generating response: {e}")
# Remove the user message that caused the error
messages.pop()
continue
# Save history if enabled
if save_history:
_save_history(history_file, messages, selected_model)
except KeyboardInterrupt:
console.print()
console.print()
print_info("Chat interrupted. Goodbye!")
# Final history save
if save_history and messages:
_save_history(history_file, messages, selected_model)
console.print(f"[dim]History saved to: {history_file}[/dim]")
console.print()
def _stream_response(
client: "OpenAI",
model: str,
messages: list,
temperature: float,
max_tokens: int,
) -> str:
"""Generate streaming response and display in real-time."""
response_content = ""
try:
stream = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
)
for chunk in stream:
if chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
response_content += content
console.print(content, end="")
console.print() # Newline after streaming
except Exception as e:
raise RuntimeError(f"Streaming error: {e}") from e
return response_content
def _generate_response(
client: "OpenAI",
model: str,
messages: list,
temperature: float,
max_tokens: int,
) -> str:
"""Generate non-streaming response."""
try:
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
)
content = response.choices[0].message.content
# Display as markdown
md = Markdown(content)
console.print(md)
return content
except Exception as e:
raise RuntimeError(f"Generation error: {e}") from e
def _handle_command(command: str, messages: list, temperature: float, max_tokens: int) -> bool:
"""Handle special commands. Returns True to continue chat, False to exit."""
cmd = command.lower().strip()
if cmd in ["/quit", "/exit", "/q"]:
console.print()
print_info("Goodbye!")
return False
elif cmd in ["/help", "/h"]:
console.print()
console.print(
Panel(
"[bold]Available Commands:[/bold]\n\n"
"/help, /h - Show this help message\n"
"/quit, /exit, /q - Exit chat\n"
"/clear, /c - Clear conversation history\n"
"/history, /hist - Show conversation history\n"
"/info, /i - Show current settings\n"
"/retry, /r - Regenerate last response",
title="Help",
border_style="cyan",
)
)
console.print()
return True
elif cmd in ["/clear", "/c"]:
messages.clear()
console.print()
print_success("Conversation history cleared")
console.print()
return True
elif cmd in ["/history", "/hist"]:
console.print()
if not messages:
print_info("No conversation history")
else:
console.print(
Panel(
_format_history(messages),
title=f"History ({len(messages)} messages)",
border_style="cyan",
)
)
console.print()
return True
elif cmd in ["/info", "/i"]:
console.print()
console.print(
Panel(
f"[bold]Current Settings:[/bold]\n\n"
f"Temperature: [cyan]{temperature}[/cyan]\n"
f"Max tokens: [cyan]{max_tokens}[/cyan]\n"
f"Messages: [cyan]{len(messages)}[/cyan]",
title="Info",
border_style="cyan",
)
)
console.print()
return True
elif cmd in ["/retry", "/r"]:
if len(messages) >= 2 and messages[-1]["role"] == "assistant":
# Remove last assistant response
messages.pop()
print_info("Retrying last response...")
console.print()
else:
print_warning("No previous response to retry")
console.print()
return True
else:
print_warning(f"Unknown command: {command}")
console.print("[dim]Type /help for available commands[/dim]")
console.print()
return True
def _format_history(messages: list) -> str:
"""Format conversation history for display."""
lines = []
for i, msg in enumerate(messages, 1):
role = msg["role"].capitalize()
content = msg["content"]
# Truncate long messages
if len(content) > 200:
content = content[:200] + "..."
lines.append(f"[bold]{i}. {role}:[/bold] {content}")
return "\n\n".join(lines)
def _save_history(file_path: Path, messages: list, model: str) -> None:
"""Save conversation history to file."""
try:
history_data = {
"model": model,
"timestamp": datetime.now().isoformat(),
"messages": messages,
}
with open(file_path, "w", encoding="utf-8") as f:
json.dump(history_data, f, indent=2, ensure_ascii=False)
except Exception as e:
print_warning(f"Failed to save history: {e}")

View File

@@ -0,0 +1,167 @@
"""
Config command for kt-cli.
Manages kt-cli configuration.
"""
from typing import Optional
import typer
import yaml
from rich.syntax import Syntax
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import confirm, console, print_error, print_success
app = typer.Typer(help="Manage kt-cli configuration")
@app.command(name="init")
def init() -> None:
"""Initialize or re-run the first-time setup wizard."""
from kt_kernel.cli.main import _show_first_run_setup
settings = get_settings()
_show_first_run_setup(settings)
@app.command(name="show")
def show(
key: Optional[str] = typer.Argument(None, help="Configuration key to show (e.g., server.port)"),
) -> None:
"""Show current configuration."""
settings = get_settings()
if key:
value = settings.get(key)
if value is not None:
if isinstance(value, (dict, list)):
console.print(yaml.dump({key: value}, default_flow_style=False, allow_unicode=True))
else:
console.print(t("config_get_value", key=key, value=value))
else:
print_error(t("config_get_not_found", key=key))
raise typer.Exit(1)
else:
console.print(f"\n[bold]{t('config_show_title')}[/bold]\n")
console.print(f"[dim]{t('config_file_location', path=str(settings.config_path))}[/dim]\n")
config_yaml = yaml.dump(settings.get_all(), default_flow_style=False, allow_unicode=True)
syntax = Syntax(config_yaml, "yaml", theme="monokai", line_numbers=False)
console.print(syntax)
@app.command(name="set")
def set_config(
key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
value: str = typer.Argument(..., help="Value to set"),
) -> None:
"""Set a configuration value."""
settings = get_settings()
# Try to parse value as JSON/YAML for complex types
parsed_value = _parse_value(value)
settings.set(key, parsed_value)
print_success(t("config_set_success", key=key, value=parsed_value))
@app.command(name="get")
def get_config(
key: str = typer.Argument(..., help="Configuration key (e.g., server.port)"),
) -> None:
"""Get a configuration value."""
settings = get_settings()
value = settings.get(key)
if value is not None:
if isinstance(value, (dict, list)):
console.print(yaml.dump(value, default_flow_style=False, allow_unicode=True))
else:
console.print(str(value))
else:
print_error(t("config_get_not_found", key=key))
raise typer.Exit(1)
@app.command(name="reset")
def reset(
yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"),
) -> None:
"""Reset configuration to defaults."""
if not yes:
if not confirm(t("config_reset_confirm"), default=False):
raise typer.Abort()
settings = get_settings()
settings.reset()
print_success(t("config_reset_success"))
@app.command(name="path")
def path() -> None:
"""Show configuration file path."""
settings = get_settings()
console.print(str(settings.config_path))
@app.command(name="model-path-list", deprecated=True, hidden=True)
def model_path_list() -> None:
"""[Deprecated] Use 'kt model path-list' instead."""
console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-list' instead.[/yellow]\n")
import subprocess
subprocess.run(["kt", "model", "path-list"])
@app.command(name="model-path-add", deprecated=True, hidden=True)
def model_path_add(
path: str = typer.Argument(..., help="Path to add"),
) -> None:
"""[Deprecated] Use 'kt model path-add' instead."""
console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-add' instead.[/yellow]\n")
import subprocess
subprocess.run(["kt", "model", "path-add", path])
@app.command(name="model-path-remove", deprecated=True, hidden=True)
def model_path_remove(
path: str = typer.Argument(..., help="Path to remove"),
) -> None:
"""[Deprecated] Use 'kt model path-remove' instead."""
console.print("[yellow]⚠ This command is deprecated. Use 'kt model path-remove' instead.[/yellow]\n")
import subprocess
subprocess.run(["kt", "model", "path-remove", path])
def _parse_value(value: str):
"""Parse a string value into appropriate Python type."""
# Try boolean ("1"/"0" are left to the int parser below so numeric config
# values such as ports are not silently coerced to bool)
if value.lower() in ("true", "yes", "on"):
return True
if value.lower() in ("false", "no", "off"):
return False
# Try integer
try:
return int(value)
except ValueError:
pass
# Try float
try:
return float(value)
except ValueError:
pass
# Try YAML/JSON parsing for lists/dicts
try:
parsed = yaml.safe_load(value)
if isinstance(parsed, (dict, list)):
return parsed
except yaml.YAMLError:
pass
# Return as string
return value
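The type-coercion cascade in `_parse_value` (bool, then int, then float, then structured data, then plain string) can be sketched standalone. This simplified version (hypothetical name `parse_value`) uses `json` rather than `yaml` for the structured-value fallback:

```python
import json


def parse_value(value: str):
    """Coerce a CLI string into bool/int/float/list/dict, falling back to str."""
    low = value.lower()
    if low in ("true", "yes", "on"):
        return True
    if low in ("false", "no", "off"):
        return False
    # Numeric parsing before structured parsing: try the narrowest type first.
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    # Only keep containers from the structured parse; bare scalars stay strings.
    try:
        parsed = json.loads(value)
        if isinstance(parsed, (dict, list)):
            return parsed
    except json.JSONDecodeError:
        pass
    return value
```

Ordering is the key design choice: checking booleans before integers, and only accepting dicts/lists from the structured parse, keeps `"8080"` an int and `"hello"` a string.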

View File

@@ -0,0 +1,394 @@
"""
Doctor command for kt-cli.
Diagnoses environment issues and provides recommendations.
"""
import platform
import shutil
from pathlib import Path
from typing import Optional
import typer
from rich.table import Table
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_error, print_info, print_success, print_warning
from kt_kernel.cli.utils.environment import (
check_docker,
detect_available_ram_gb,
detect_cpu_info,
detect_cuda_version,
detect_disk_space_gb,
detect_env_managers,
detect_gpus,
detect_memory_info,
detect_ram_gb,
get_installed_package_version,
)
def doctor(
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed diagnostics"),
) -> None:
"""Diagnose environment issues."""
console.print(f"\n[bold]{t('doctor_title')}[/bold]\n")
issues_found = False
checks = []
# 1. Python version
python_version = platform.python_version()
python_ok = _check_python_version(python_version)
checks.append(
{
"name": t("doctor_check_python"),
"status": "ok" if python_ok else "error",
"value": python_version,
"hint": "Python 3.10+ required" if not python_ok else None,
}
)
if not python_ok:
issues_found = True
# 2. CUDA availability
cuda_version = detect_cuda_version()
checks.append(
{
"name": t("doctor_check_cuda"),
"status": "ok" if cuda_version else "warning",
"value": cuda_version or t("version_cuda_not_found"),
"hint": "CUDA is optional but recommended for GPU acceleration" if not cuda_version else None,
}
)
# 3. GPU detection
gpus = detect_gpus()
if gpus:
gpu_names = ", ".join(g.name for g in gpus)
total_vram = sum(g.vram_gb for g in gpus)
checks.append(
{
"name": t("doctor_check_gpu"),
"status": "ok",
"value": t("doctor_gpu_found", count=len(gpus), names=gpu_names),
"hint": f"Total VRAM: {total_vram}GB",
}
)
else:
checks.append(
{
"name": t("doctor_check_gpu"),
"status": "warning",
"value": t("doctor_gpu_not_found"),
"hint": "GPU recommended for best performance",
}
)
# 4. CPU information
cpu_info = detect_cpu_info()
checks.append(
{
"name": t("doctor_check_cpu"),
"status": "ok",
"value": t("doctor_cpu_info", name=cpu_info.name, cores=cpu_info.cores, threads=cpu_info.threads),
"hint": None,
}
)
# 5. CPU instruction sets (critical for kt-kernel)
isa_list = cpu_info.instruction_sets
# Check for recommended instruction sets
recommended_isa = {"AVX2", "AVX512F", "AMX-INT8"}
has_recommended = bool(set(isa_list) & recommended_isa)
has_avx2 = "AVX2" in isa_list
has_avx512 = any(isa.startswith("AVX512") for isa in isa_list)
has_amx = any(isa.startswith("AMX") for isa in isa_list)
# Determine status and build display string
if has_amx:
isa_status = "ok"
isa_hint = "AMX available - best performance for INT4/INT8"
elif has_avx512:
isa_status = "ok"
isa_hint = "AVX512 available - good performance"
elif has_avx2:
isa_status = "warning"
isa_hint = "AVX2 only - consider upgrading CPU for better performance"
else:
isa_status = "error"
isa_hint = "AVX2 required for kt-kernel"
# Show top instruction sets (prioritize important ones)
display_isa = isa_list[:8] if len(isa_list) > 8 else isa_list
isa_display = ", ".join(display_isa)
if len(isa_list) > 8:
isa_display += f" (+{len(isa_list) - 8} more)"
checks.append(
{
"name": t("doctor_check_cpu_isa"),
"status": isa_status,
"value": isa_display if isa_display else "None detected",
"hint": isa_hint,
}
)
# 6. NUMA topology
numa_detail = []
for node, cpus in sorted(cpu_info.numa_info.items()):
if len(cpus) > 6:
cpu_str = f"{cpus[0]}-{cpus[-1]}"
else:
cpu_str = ",".join(str(c) for c in cpus)
numa_detail.append(f"{node}: {cpu_str}")
numa_value = t("doctor_numa_info", nodes=cpu_info.numa_nodes)
if verbose and numa_detail:
numa_value += " (" + "; ".join(numa_detail) + ")"
checks.append(
{
"name": t("doctor_check_numa"),
"status": "ok",
"value": numa_value,
"hint": f"{cpu_info.threads // cpu_info.numa_nodes} threads per node" if cpu_info.numa_nodes > 1 else None,
}
)
# 7. System memory (with frequency if available)
mem_info = detect_memory_info()
if mem_info.frequency_mhz and mem_info.type:
mem_value = t(
"doctor_memory_freq",
available=f"{mem_info.available_gb}GB",
total=f"{mem_info.total_gb}GB",
freq=mem_info.frequency_mhz,
type=mem_info.type,
)
else:
mem_value = t("doctor_memory_info", available=f"{mem_info.available_gb}GB", total=f"{mem_info.total_gb}GB")
ram_ok = mem_info.total_gb >= 32
checks.append(
{
"name": t("doctor_check_memory"),
"status": "ok" if ram_ok else "warning",
"value": mem_value,
"hint": "32GB+ RAM recommended for large models" if not ram_ok else None,
}
)
# 8. Disk space - check all model paths
settings = get_settings()
model_paths = settings.get_model_paths()
# Check all configured model paths
for i, disk_path in enumerate(model_paths):
available_disk, total_disk = detect_disk_space_gb(str(disk_path))
disk_ok = available_disk >= 100
# For multiple paths, add index to name
path_label = f"Model Path {i+1}" if len(model_paths) > 1 else t("doctor_check_disk")
checks.append(
{
"name": path_label,
"status": "ok" if disk_ok else "warning",
"value": t("doctor_disk_info", available=f"{available_disk}GB", path=str(disk_path)),
"hint": "100GB+ free space recommended for model storage" if not disk_ok else None,
}
)
# 9. Required packages
packages = [
("kt-kernel", ">=0.4.0", False), # name, version_req, required
("ktransformers", ">=0.4.0", False),
("sglang", ">=0.4.0", False),
("torch", ">=2.4.0", True),
("transformers", ">=4.45.0", True),
]
package_issues = []
for pkg_name, version_req, required in packages:
version = get_installed_package_version(pkg_name)
if version:
package_issues.append((pkg_name, version, "ok"))
elif required:
package_issues.append((pkg_name, t("version_not_installed"), "error"))
issues_found = True
else:
package_issues.append((pkg_name, t("version_not_installed"), "warning"))
if verbose:
checks.append(
{
"name": t("doctor_check_packages"),
"status": "ok" if not any(p[2] == "error" for p in package_issues) else "error",
"value": f"{sum(1 for p in package_issues if p[2] == 'ok')}/{len(package_issues)} installed",
"packages": package_issues,
}
)
# 10. SGLang installation source check
from kt_kernel.cli.utils.sglang_checker import check_sglang_installation, check_sglang_kt_kernel_support
sglang_info = check_sglang_installation()
if sglang_info["installed"]:
if sglang_info["from_source"]:
if sglang_info["git_info"]:
git_remote = sglang_info["git_info"].get("remote", "unknown")
git_branch = sglang_info["git_info"].get("branch", "unknown")
sglang_source_value = f"Source (GitHub: {git_remote}, branch: {git_branch})"
sglang_source_status = "ok"
sglang_source_hint = None
else:
sglang_source_value = "Source (editable)"
sglang_source_status = "ok"
sglang_source_hint = None
else:
sglang_source_value = "PyPI (not recommended)"
sglang_source_status = "warning"
sglang_source_hint = t("sglang_pypi_hint")
else:
sglang_source_value = "Not installed"
sglang_source_status = "warning"
sglang_source_hint = t("sglang_install_hint")
checks.append(
{
"name": "SGLang Source",
"status": sglang_source_status,
"value": sglang_source_value,
"hint": sglang_source_hint,
}
)
# 10b. SGLang kt-kernel support check (only if SGLang is installed)
kt_kernel_support = {"supported": True} # Default to True if not checked
if sglang_info["installed"]:
# Use cache=False to force re-check in doctor, but silent=True since we show in table
kt_kernel_support = check_sglang_kt_kernel_support(use_cache=False, silent=True)
if kt_kernel_support["supported"]:
kt_kernel_value = t("sglang_kt_kernel_supported")
kt_kernel_status = "ok"
kt_kernel_hint = None
else:
kt_kernel_value = t("sglang_kt_kernel_not_supported")
kt_kernel_status = "error"
kt_kernel_hint = 'Reinstall SGLang from: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"'
issues_found = True
checks.append(
{
"name": "SGLang kt-kernel",
"status": kt_kernel_status,
"value": kt_kernel_value,
"hint": kt_kernel_hint,
}
)
# 11. Environment managers
env_managers = detect_env_managers()
docker = check_docker()
env_list = [f"{m.name} {m.version}" for m in env_managers]
if docker:
env_list.append(f"docker {docker.version}")
checks.append(
{
"name": "Environment Managers",
"status": "ok" if env_list else "warning",
"value": ", ".join(env_list) if env_list else "None found",
"hint": "conda or docker recommended for installation" if not env_list else None,
}
)
# Display results
_display_results(checks, verbose)
# Show SGLang installation instructions if not installed
if not sglang_info["installed"]:
from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions
console.print()
print_sglang_install_instructions()
# Show kt-kernel installation instructions if SGLang is installed but doesn't support kt-kernel
elif sglang_info["installed"] and not kt_kernel_support.get("supported", True):
from kt_kernel.cli.utils.sglang_checker import print_sglang_kt_kernel_instructions
console.print()
print_sglang_kt_kernel_instructions()
# Summary
console.print()
if issues_found:
print_warning(t("doctor_has_issues"))
else:
print_success(t("doctor_all_ok"))
console.print()
def _check_python_version(version: str) -> bool:
"""Check if Python version meets requirements."""
parts = version.split(".")
try:
major, minor = int(parts[0]), int(parts[1])
# Tuple comparison, so e.g. a future 4.0 still passes and 3.9 fails.
return (major, minor) >= (3, 10)
except (IndexError, ValueError):
return False
def _display_results(checks: list[dict], verbose: bool) -> None:
"""Display diagnostic results."""
table = Table(show_header=True, header_style="bold")
table.add_column("Check", style="bold")
table.add_column("Status", width=8)
table.add_column("Value")
if verbose:
table.add_column("Notes", style="dim")
for check in checks:
status = check["status"]
if status == "ok":
status_str = f"[green]{t('doctor_status_ok')}[/green]"
elif status == "warning":
status_str = f"[yellow]{t('doctor_status_warning')}[/yellow]"
else:
status_str = f"[red]{t('doctor_status_error')}[/red]"
if verbose:
table.add_row(
check["name"],
status_str,
check["value"],
check.get("hint", ""),
)
else:
table.add_row(
check["name"],
status_str,
check["value"],
)
# Show package details if verbose
if verbose and "packages" in check:
for pkg_name, pkg_version, pkg_status in check["packages"]:
if pkg_status == "ok":
pkg_status_str = "[green]✓[/green]"
elif pkg_status == "warning":
pkg_status_str = "[yellow]○[/yellow]"
else:
pkg_status_str = "[red]✗[/red]"
table.add_row(
f" └─ {pkg_name}",
pkg_status_str,
pkg_version,
"",
)
console.print(table)
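The version check in `_check_python_version` hinges on comparing version components as an integer tuple rather than as strings (string comparison would rank "3.9" above "3.10"). A minimal standalone sketch (hypothetical name `python_ok`):

```python
def python_ok(version: str, minimum: tuple = (3, 10)) -> bool:
    """Return True if a dotted version string meets the (major, minor) minimum."""
    try:
        # Compare the first two components as integers, not strings.
        parts = tuple(int(p) for p in version.split(".")[:2])
    except ValueError:
        return False
    return parts >= minimum
```

Tuple comparison handles both the "3.9 vs 3.10" ordering and a hypothetical major-version bump without special cases.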

View File

@@ -0,0 +1,409 @@
"""
Model command for kt-cli.
Manages models: download, list, and storage paths.
"""
import os
from pathlib import Path
from typing import Optional
import typer
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
confirm,
console,
print_error,
print_info,
print_success,
print_warning,
prompt_choice,
)
app = typer.Typer(
help="Manage models and storage paths",
invoke_without_command=True,
no_args_is_help=False,
)
@app.callback()
def callback(ctx: typer.Context) -> None:
"""
Model management commands.
Run without arguments to see available models.
"""
# If no subcommand is provided, show the model list
if ctx.invoked_subcommand is None:
show_model_list()
def show_model_list() -> None:
"""Display available models with their status and paths."""
from rich.table import Table
from kt_kernel.cli.utils.model_registry import get_registry
from kt_kernel.cli.i18n import get_lang
registry = get_registry()
settings = get_settings()
console.print()
console.print(f"[bold cyan]{t('model_supported_title')}[/bold cyan]\n")
# Get local models mapping
local_models = {m.name: p for m, p in registry.find_local_models()}
# Create table
table = Table(show_header=True, header_style="bold")
table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
table.add_column(t("model_column_status"), justify="center")
all_models = registry.list_all()
for model in all_models:
if model.name in local_models:
status = f"[green]✓ {t('model_status_local')}[/green]"
else:
status = "[dim]-[/dim]"
table.add_row(model.name, status)
console.print(table)
console.print()
# Usage instructions
console.print(f"[bold]{t('model_usage_title')}:[/bold]")
console.print(f"{t('model_usage_download')} [cyan]kt model download <model-name>[/cyan]")
console.print(f"{t('model_usage_list_local')} [cyan]kt model list --local[/cyan]")
console.print(f"{t('model_usage_search')} [cyan]kt model search <query>[/cyan]")
console.print()
# Show model storage paths
model_paths = settings.get_model_paths()
console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]")
for path in model_paths:
marker = "[green]✓[/green]" if path.exists() else "[dim]✗[/dim]"
console.print(f" {marker} {path}")
console.print()
@app.command(name="download")
def download(
model: Optional[str] = typer.Argument(
None,
help="Model name or HuggingFace repo (e.g., deepseek-v3, Qwen/Qwen3-30B)",
),
path: Optional[Path] = typer.Option(
None,
"--path",
"-p",
help="Custom download path",
),
list_models: bool = typer.Option(
False,
"--list",
"-l",
help="List available models",
),
resume: bool = typer.Option(
True,
"--resume/--no-resume",
help="Resume incomplete downloads",
),
yes: bool = typer.Option(
False,
"--yes",
"-y",
help="Skip confirmation prompts",
),
) -> None:
"""Download model weights from HuggingFace."""
import subprocess
from kt_kernel.cli.i18n import get_lang
from kt_kernel.cli.utils.console import print_model_table, print_step
from kt_kernel.cli.utils.model_registry import get_registry
settings = get_settings()
registry = get_registry()
console.print()
# List mode
if list_models or model is None:
print_step(t("download_list_title"))
console.print()
models = registry.list_all()
model_dicts = []
for m in models:
lang = get_lang()
desc = m.description_zh if lang == "zh" and m.description_zh else m.description
model_dicts.append(
{
"name": m.name,
"hf_repo": m.hf_repo,
"type": m.type,
"gpu_vram_gb": m.gpu_vram_gb,
"cpu_ram_gb": m.cpu_ram_gb,
}
)
print_model_table(model_dicts)
console.print()
if model is None:
console.print(f"[dim]{t('model_download_usage_hint')}[/dim]")
console.print()
return
# Search for model
print_step(t("download_searching", name=model))
# Check if it's a direct HuggingFace repo path
if "/" in model:
hf_repo = model
model_info = None
model_name = model.split("/")[-1]
else:
matches = registry.search(model)
if not matches:
print_error(t("run_model_not_found", name=model))
console.print()
console.print(t("model_download_list_hint"))
console.print(t("model_download_hf_hint"))
raise typer.Exit(1)
if len(matches) == 1:
model_info = matches[0]
else:
console.print()
print_info(t("download_multiple_found"))
choices = [f"{m.name} ({m.hf_repo})" for m in matches]
selected = prompt_choice(t("download_select"), choices)
idx = choices.index(selected)
model_info = matches[idx]
hf_repo = model_info.hf_repo
model_name = model_info.name
print_success(t("download_found", name=hf_repo))
# Determine download path
if path is None:
download_path = settings.models_dir / model_name.replace(" ", "-")
else:
download_path = path
console.print()
print_info(t("download_destination", path=str(download_path)))
# Check if already exists
if download_path.exists() and (download_path / "config.json").exists():
print_warning(t("download_already_exists", path=str(download_path)))
if not yes:
if not confirm(t("download_overwrite_prompt"), default=False):
raise typer.Abort()
# Confirm download
if not yes:
console.print()
if not confirm(t("prompt_continue")):
raise typer.Abort()
# Download using huggingface-cli
console.print()
print_step(t("download_starting"))
cmd = [
"huggingface-cli",
"download",
hf_repo,
"--local-dir",
str(download_path),
]
    if resume:
        # --resume-download is a no-op on recent huggingface_hub versions
        # (downloads always resume), but it keeps older versions compatible.
        cmd.append("--resume-download")
    # Use mirror if configured: huggingface-cli reads the HF_ENDPOINT
    # environment variable; there is no --endpoint command-line flag.
    env = os.environ.copy()
    mirror = settings.get("download.mirror", "")
    if mirror:
        env["HF_ENDPOINT"] = mirror
    try:
        subprocess.run(cmd, check=True, env=env)
console.print()
print_success(t("download_complete"))
console.print()
console.print(f" {t('model_saved_to', path=download_path)}")
console.print()
console.print(f" {t('model_start_with', name=model_name)}")
console.print()
except subprocess.CalledProcessError as e:
print_error(t("model_download_failed", error=str(e)))
raise typer.Exit(1)
except FileNotFoundError:
print_error(t("model_hf_cli_not_found"))
raise typer.Exit(1)
@app.command(name="list")
def list_models(
local_only: bool = typer.Option(False, "--local", help="Show only locally downloaded models"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed info including paths"),
) -> None:
"""List available models."""
from rich.table import Table
from kt_kernel.cli.utils.model_registry import get_registry
registry = get_registry()
console.print()
if local_only:
# Show only local models
local_models = registry.find_local_models()
if not local_models:
print_warning(t("model_no_local_models"))
console.print()
console.print(f" {t('model_download_hint')} [cyan]kt model download <model-name>[/cyan]")
console.print()
return
table = Table(title=t("model_local_models_title"), show_header=True, header_style="bold")
table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
if verbose:
table.add_column(t("model_column_local_path"), style="dim")
for model_info, model_path in local_models:
if verbose:
table.add_row(model_info.name, str(model_path))
else:
table.add_row(model_info.name)
console.print(table)
else:
# Show all registered models
all_models = registry.list_all()
local_models_dict = {m.name: p for m, p in registry.find_local_models()}
table = Table(title=t("model_available_models_title"), show_header=True, header_style="bold")
table.add_column(t("model_column_model"), style="cyan", no_wrap=True)
table.add_column(t("model_column_status"), justify="center")
if verbose:
table.add_column(t("model_column_local_path"), style="dim")
for model in all_models:
if model.name in local_models_dict:
status = f"[green]✓ {t('model_status_local')}[/green]"
local_path = str(local_models_dict[model.name])
else:
status = "[dim]-[/dim]"
local_path = f"[dim]{t('model_status_not_downloaded')}[/dim]"
if verbose:
table.add_row(model.name, status, local_path)
else:
table.add_row(model.name, status)
console.print(table)
console.print()
@app.command(name="path-list")
def path_list() -> None:
"""List all configured model storage paths."""
settings = get_settings()
model_paths = settings.get_model_paths()
console.print()
console.print(f"[bold]{t('model_storage_paths_title')}:[/bold]\n")
for i, path in enumerate(model_paths, 1):
marker = "[green]✓[/green]" if path.exists() else "[red]✗[/red]"
console.print(f" {marker} [{i}] {path}")
console.print()
@app.command(name="path-add")
def path_add(
path: str = typer.Argument(..., help="Path to add"),
) -> None:
"""Add a new model storage path."""
# Expand user home directory
path = os.path.expanduser(path)
# Check if path exists or can be created
path_obj = Path(path)
if not path_obj.exists():
console.print(f"[yellow]{t('model_path_not_exist', path=path)}[/yellow]")
if confirm(t("model_create_directory", path=path), default=True):
try:
path_obj.mkdir(parents=True, exist_ok=True)
console.print(f"[green]✓[/green] {t('model_created_directory', path=path)}")
except (OSError, PermissionError) as e:
print_error(t("model_create_dir_failed", error=str(e)))
raise typer.Exit(1)
else:
raise typer.Abort()
# Add to configuration
settings = get_settings()
settings.add_model_path(path)
print_success(t("model_path_added", path=path))
@app.command(name="path-remove")
def path_remove(
path: str = typer.Argument(..., help="Path to remove"),
) -> None:
"""Remove a model storage path from configuration."""
# Expand user home directory
path = os.path.expanduser(path)
settings = get_settings()
if settings.remove_model_path(path):
print_success(t("model_path_removed", path=path))
else:
print_error(t("model_path_not_found", path=path))
raise typer.Exit(1)
@app.command(name="search")
def search(
query: str = typer.Argument(..., help="Search query (model name or keyword)"),
) -> None:
"""Search for models in the registry."""
from rich.table import Table
from kt_kernel.cli.utils.model_registry import get_registry
registry = get_registry()
matches = registry.search(query)
console.print()
if not matches:
print_warning(t("model_search_no_results", query=query))
console.print()
return
table = Table(title=t("model_search_results_title", query=query), show_header=True)
table.add_column(t("model_column_name"), style="cyan")
table.add_column(t("model_column_hf_repo"), style="dim")
table.add_column(t("model_column_aliases"), style="yellow")
for model in matches:
aliases = ", ".join(model.aliases[:3])
if len(model.aliases) > 3:
aliases += f" +{len(model.aliases) - 3} more"
table.add_row(model.name, model.hf_repo, aliases)
console.print(table)
console.print()


@@ -0,0 +1,239 @@
"""
Quant command for kt-cli.
Quantizes model weights for CPU inference.
"""
import subprocess
import sys
from enum import Enum
from pathlib import Path
from typing import Optional
import typer
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
confirm,
console,
create_progress,
print_error,
print_info,
print_step,
print_success,
print_warning,
)
from kt_kernel.cli.utils.environment import detect_cpu_info
class QuantMethod(str, Enum):
"""Quantization method."""
INT4 = "int4"
INT8 = "int8"
def quant(
model: str = typer.Argument(
...,
help="Model name or path to quantize",
),
method: QuantMethod = typer.Option(
QuantMethod.INT4,
"--method",
"-m",
help="Quantization method",
),
output: Optional[Path] = typer.Option(
None,
"--output",
"-o",
help="Output path for quantized weights",
),
input_type: str = typer.Option(
"fp8",
"--input-type",
"-i",
help="Input weight type (fp8, fp16, bf16)",
),
cpu_threads: Optional[int] = typer.Option(
None,
"--cpu-threads",
help="Number of CPU threads for quantization",
),
numa_nodes: Optional[int] = typer.Option(
None,
"--numa-nodes",
help="Number of NUMA nodes",
),
no_merge: bool = typer.Option(
False,
"--no-merge",
help="Don't merge safetensor files",
),
yes: bool = typer.Option(
False,
"--yes",
"-y",
help="Skip confirmation prompts",
),
) -> None:
"""Quantize model weights for CPU inference."""
settings = get_settings()
console.print()
# Resolve input path
input_path = _resolve_input_path(model, settings)
if input_path is None:
print_error(t("quant_input_not_found", path=model))
raise typer.Exit(1)
print_info(t("quant_input_path", path=str(input_path)))
# Resolve output path
if output is None:
output = input_path.parent / f"{input_path.name}-{method.value.upper()}"
print_info(t("quant_output_path", path=str(output)))
print_info(t("quant_method", method=method.value.upper()))
# Detect CPU configuration
cpu = detect_cpu_info()
final_cpu_threads = cpu_threads or cpu.cores
final_numa_nodes = numa_nodes or cpu.numa_nodes
print_info(f"CPU threads: {final_cpu_threads}")
print_info(f"NUMA nodes: {final_numa_nodes}")
# Check if output exists
if output.exists():
print_warning(f"Output path already exists: {output}")
if not yes:
if not confirm("Overwrite?", default=False):
raise typer.Abort()
# Confirm
if not yes:
console.print()
console.print("[bold]Quantization Settings:[/bold]")
console.print(f" Input: {input_path}")
console.print(f" Output: {output}")
console.print(f" Method: {method.value.upper()}")
console.print(f" Input type: {input_type}")
console.print()
print_warning("Quantization may take 30-60 minutes depending on model size.")
console.print()
if not confirm(t("prompt_continue")):
raise typer.Abort()
# Find conversion script
kt_kernel_path = _find_kt_kernel_path()
if kt_kernel_path is None:
print_error("kt-kernel not found. Install with: kt install inference")
raise typer.Exit(1)
script_path = kt_kernel_path / "scripts" / "convert_cpu_weights.py"
if not script_path.exists():
print_error(f"Conversion script not found: {script_path}")
raise typer.Exit(1)
# Build command
cmd = [
sys.executable, str(script_path),
"--input-path", str(input_path),
"--input-type", input_type,
"--output", str(output),
"--quant-method", method.value,
"--cpuinfer-threads", str(final_cpu_threads),
"--threadpool-count", str(final_numa_nodes),
]
if no_merge:
cmd.append("--no-merge-safetensor")
# Run quantization
console.print()
print_step(t("quant_starting"))
console.print()
console.print(f"[dim]$ {' '.join(cmd)}[/dim]")
console.print()
try:
process = subprocess.run(cmd)
if process.returncode == 0:
console.print()
print_success(t("quant_complete"))
console.print()
console.print(f" Quantized weights saved to: {output}")
console.print()
console.print(" Use with:")
console.print(f" kt run {model} --weights-path {output}")
console.print()
else:
print_error(f"Quantization failed with exit code {process.returncode}")
raise typer.Exit(process.returncode)
except FileNotFoundError as e:
print_error(f"Failed to run quantization: {e}")
raise typer.Exit(1)
except KeyboardInterrupt:
console.print()
print_warning("Quantization interrupted.")
raise typer.Exit(130)
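# Illustrative invocation of the quant command above (model name, thread
# count, and output path are examples, not defaults):
#
#   kt quant MiniMax-M2.1 --method int4 --input-type fp8 \
#       --output /models/MiniMax-M2.1-INT4 --cpu-threads 64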
def _resolve_input_path(model: str, settings) -> Optional[Path]:
"""Resolve the input model path."""
# Check if it's already a path
path = Path(model)
if path.exists() and (path / "config.json").exists():
return path
# Search in models directory
from kt_kernel.cli.utils.model_registry import get_registry
registry = get_registry()
matches = registry.search(model)
if matches:
model_info = matches[0]
# Try to find in all configured model directories
model_paths = settings.get_model_paths()
for models_dir in model_paths:
possible_paths = [
models_dir / model_info.name,
models_dir / model_info.name.lower(),
models_dir / model_info.hf_repo.split("/")[-1],
]
for p in possible_paths:
if p.exists() and (p / "config.json").exists():
return p
return None
def _find_kt_kernel_path() -> Optional[Path]:
    """Find the kt-kernel installation path."""
    try:
        import kt_kernel

        # Only trust the installed package if it ships the scripts directory;
        # otherwise fall through to the common source-checkout locations below.
        candidate = Path(kt_kernel.__file__).parent.parent
        if (candidate / "scripts").exists():
            return candidate
    except ImportError:
        pass
# Check common locations
possible_paths = [
Path.home() / "Projects" / "ktransformers" / "kt-kernel",
Path.cwd().parent / "kt-kernel",
Path.cwd() / "kt-kernel",
]
for path in possible_paths:
if path.exists() and (path / "scripts").exists():
return path
return None


@@ -0,0 +1,831 @@
"""
Run command for kt-cli.
Starts the model inference server using SGLang + kt-kernel.
"""
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional
import typer
from kt_kernel.cli.config.settings import get_settings
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import (
confirm,
console,
print_api_info,
print_error,
print_info,
print_server_info,
print_step,
print_success,
print_warning,
prompt_choice,
)
from kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb
from kt_kernel.cli.utils.model_registry import MODEL_COMPUTE_FUNCTIONS, ModelInfo, get_registry
def run(
model: Optional[str] = typer.Argument(
None,
help="Model name or path (e.g., deepseek-v3, qwen3-30b). If not specified, shows interactive selection.",
),
host: str = typer.Option(
None,
"--host",
"-H",
help="Server host address",
),
port: int = typer.Option(
None,
"--port",
"-p",
help="Server port",
),
# CPU/GPU configuration
gpu_experts: Optional[int] = typer.Option(
None,
"--gpu-experts",
help="Number of GPU experts per layer",
),
cpu_threads: Optional[int] = typer.Option(
None,
"--cpu-threads",
help="Number of CPU inference threads (kt-cpuinfer, defaults to 80% of CPU cores)",
),
numa_nodes: Optional[int] = typer.Option(
None,
"--numa-nodes",
help="Number of NUMA nodes",
),
tensor_parallel_size: Optional[int] = typer.Option(
None,
"--tensor-parallel-size",
"--tp",
help="Tensor parallel size (number of GPUs)",
),
# Model paths
model_path: Optional[Path] = typer.Option(
None,
"--model-path",
help="Custom model path",
),
weights_path: Optional[Path] = typer.Option(
None,
"--weights-path",
help="Custom quantized weights path",
),
# KT-kernel options
kt_method: Optional[str] = typer.Option(
None,
"--kt-method",
help="KT quantization method (AMXINT4, RAWFP8, etc.)",
),
kt_gpu_prefill_token_threshold: Optional[int] = typer.Option(
None,
"--kt-gpu-prefill-threshold",
help="GPU prefill token threshold for kt-kernel",
),
# SGLang options
attention_backend: Optional[str] = typer.Option(
None,
"--attention-backend",
help="Attention backend (triton, flashinfer)",
),
max_total_tokens: Optional[int] = typer.Option(
None,
"--max-total-tokens",
help="Maximum total tokens",
),
max_running_requests: Optional[int] = typer.Option(
None,
"--max-running-requests",
help="Maximum running requests",
),
chunked_prefill_size: Optional[int] = typer.Option(
None,
"--chunked-prefill-size",
help="Chunked prefill size",
),
mem_fraction_static: Optional[float] = typer.Option(
None,
"--mem-fraction-static",
help="Memory fraction for static allocation",
),
watchdog_timeout: Optional[int] = typer.Option(
None,
"--watchdog-timeout",
help="Watchdog timeout in seconds",
),
served_model_name: Optional[str] = typer.Option(
None,
"--served-model-name",
help="Custom model name for API responses",
),
# Performance flags
disable_shared_experts_fusion: Optional[bool] = typer.Option(
None,
"--disable-shared-experts-fusion/--enable-shared-experts-fusion",
help="Disable/enable shared experts fusion",
),
# Other options
quantize: bool = typer.Option(
False,
"--quantize",
"-q",
help="Quantize model if weights not found",
),
advanced: bool = typer.Option(
False,
"--advanced",
help="Show advanced options",
),
dry_run: bool = typer.Option(
False,
"--dry-run",
help="Show command without executing",
),
) -> None:
"""Start model inference server."""
# Check if SGLang is installed before proceeding
from kt_kernel.cli.utils.sglang_checker import (
check_sglang_installation,
check_sglang_kt_kernel_support,
print_sglang_install_instructions,
print_sglang_kt_kernel_instructions,
)
sglang_info = check_sglang_installation()
if not sglang_info["installed"]:
console.print()
print_error(t("sglang_not_found"))
console.print()
print_sglang_install_instructions()
raise typer.Exit(1)
# Check if SGLang supports kt-kernel (has --kt-gpu-prefill-token-threshold parameter)
kt_kernel_support = check_sglang_kt_kernel_support()
if not kt_kernel_support["supported"]:
console.print()
print_error(t("sglang_kt_kernel_not_supported"))
console.print()
print_sglang_kt_kernel_instructions()
raise typer.Exit(1)
settings = get_settings()
registry = get_registry()
console.print()
# If no model specified, show interactive selection
if model is None:
model = _interactive_model_selection(registry, settings)
if model is None:
raise typer.Exit(0)
# Step 1: Detect hardware
print_step(t("run_detecting_hardware"))
gpus = detect_gpus()
cpu = detect_cpu_info()
ram = detect_ram_gb()
if gpus:
gpu_info = f"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)"
if len(gpus) > 1:
gpu_info += f" + {len(gpus) - 1} more"
print_info(t("run_gpu_info", name=gpus[0].name, vram=gpus[0].vram_gb))
else:
print_warning(t("doctor_gpu_not_found"))
gpu_info = "None"
print_info(t("run_cpu_info", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))
print_info(t("run_ram_info", total=int(ram)))
# Step 2: Resolve model
console.print()
print_step(t("run_checking_model"))
model_info = None
resolved_model_path = model_path
# Check if model is a path
if Path(model).exists():
resolved_model_path = Path(model)
print_info(t("run_model_path", path=str(resolved_model_path)))
# Try to infer model type from path to use default configurations
# Check directory name against known models
dir_name = resolved_model_path.name.lower()
for registered_model in registry.list_all():
# Check if directory name matches model name or aliases
if dir_name == registered_model.name.lower():
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break
for alias in registered_model.aliases:
if dir_name == alias.lower() or alias.lower() in dir_name:
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break
if model_info:
break
# Also check HuggingFace repo format (org--model)
if not model_info:
for registered_model in registry.list_all():
repo_slug = registered_model.hf_repo.replace("/", "--").lower()
if repo_slug in dir_name or dir_name in repo_slug:
model_info = registered_model
print_info(f"Detected model type: {registered_model.name}")
break
if not model_info:
print_warning("Could not detect model type from path. Using default parameters.")
console.print(" [dim]Tip: Use model name (e.g., 'kt run m2') to apply optimized configurations[/dim]")
else:
# Search in registry
matches = registry.search(model)
if not matches:
print_error(t("run_model_not_found", name=model))
console.print()
console.print("Available models:")
for m in registry.list_all()[:5]:
console.print(f" - {m.name} ({', '.join(m.aliases[:2])})")
raise typer.Exit(1)
if len(matches) == 1:
model_info = matches[0]
else:
# Multiple matches - prompt user
console.print()
print_info(t("run_multiple_matches"))
choices = [f"{m.name} ({m.hf_repo})" for m in matches]
selected = prompt_choice(t("run_select_model"), choices)
idx = choices.index(selected)
model_info = matches[idx]
# Find model path
if model_path is None:
resolved_model_path = _find_model_path(model_info, settings)
if resolved_model_path is None:
print_error(t("run_model_not_found", name=model_info.name))
console.print()
                console.print(
                    f" Download with: kt model download {model_info.aliases[0] if model_info.aliases else model_info.name}"
                )
raise typer.Exit(1)
print_info(t("run_model_path", path=str(resolved_model_path)))
# Step 3: Check quantized weights (only if explicitly requested)
resolved_weights_path = None
# Only use quantized weights if explicitly specified by user
if weights_path is not None:
# User explicitly specified weights path
resolved_weights_path = weights_path
if not resolved_weights_path.exists():
print_error(t("run_weights_not_found"))
console.print(f" Path: {resolved_weights_path}")
raise typer.Exit(1)
print_info(f"Using quantized weights: {resolved_weights_path}")
elif quantize:
# User requested quantization
console.print()
print_step(t("run_quantizing"))
# TODO: Implement quantization
print_warning("Quantization not yet implemented. Please run 'kt quant' manually.")
raise typer.Exit(1)
else:
# Default: use original precision model without quantization
console.print()
print_info("Using original precision model (no quantization)")
# Step 4: Build command
# Resolve all parameters (CLI > model defaults > config > auto-detect)
final_host = host or settings.get("server.host", "0.0.0.0")
final_port = port or settings.get("server.port", 30000)
# Get defaults from model info if available
model_defaults = model_info.default_params if model_info else {}
# Determine tensor parallel size first (needed for GPU expert calculation)
# Priority: CLI > model defaults > config > auto-detect (with model constraints)
# Check if explicitly specified by user or configuration
explicitly_specified = (
tensor_parallel_size # CLI argument (highest priority)
or model_defaults.get("tensor-parallel-size") # Model defaults
or settings.get("inference.tensor_parallel_size") # Config file
)
if explicitly_specified:
# Use explicitly specified value
requested_tensor_parallel_size = explicitly_specified
else:
# Auto-detect from GPUs, considering model's max constraint
detected_gpu_count = len(gpus) if gpus else 1
if model_info and model_info.max_tensor_parallel_size is not None:
# Automatically limit to model's maximum to use as many GPUs as possible
requested_tensor_parallel_size = min(detected_gpu_count, model_info.max_tensor_parallel_size)
else:
requested_tensor_parallel_size = detected_gpu_count
# Apply model's max_tensor_parallel_size constraint if explicitly specified value exceeds it
final_tensor_parallel_size = requested_tensor_parallel_size
if model_info and model_info.max_tensor_parallel_size is not None:
if requested_tensor_parallel_size > model_info.max_tensor_parallel_size:
console.print()
print_warning(
f"Model {model_info.name} only supports up to {model_info.max_tensor_parallel_size}-way "
f"tensor parallelism, but {requested_tensor_parallel_size} was requested. "
f"Reducing to {model_info.max_tensor_parallel_size}."
)
final_tensor_parallel_size = model_info.max_tensor_parallel_size
    # CPU/GPU configuration with smart defaults.
    # Note: these `or` chains treat 0/None/"" as unset, so an explicit 0
    # falls through to the next default in the chain.
    # kt-cpuinfer: default to 80% of total CPU threads (cores * NUMA nodes)
    total_threads = cpu.cores * cpu.numa_nodes
    final_cpu_threads = (
        cpu_threads
        or model_defaults.get("kt-cpuinfer")
        or settings.get("inference.cpu_threads")
        or int(total_threads * 0.8)
    )
# kt-threadpool-count: default to NUMA node count
final_numa_nodes = (
numa_nodes
or model_defaults.get("kt-threadpool-count")
or settings.get("inference.numa_nodes")
or cpu.numa_nodes
)
# kt-num-gpu-experts: use model-specific computation if available and not explicitly set
if gpu_experts is not None:
# User explicitly set it
final_gpu_experts = gpu_experts
elif model_info and model_info.name in MODEL_COMPUTE_FUNCTIONS and gpus:
# Use model-specific computation function (only if GPUs detected)
vram_per_gpu = gpus[0].vram_gb
compute_func = MODEL_COMPUTE_FUNCTIONS[model_info.name]
final_gpu_experts = compute_func(final_tensor_parallel_size, vram_per_gpu)
console.print()
print_info(
f"Auto-computed kt-num-gpu-experts: {final_gpu_experts} (TP={final_tensor_parallel_size}, VRAM={vram_per_gpu}GB per GPU)"
)
else:
# Fall back to defaults
final_gpu_experts = model_defaults.get("kt-num-gpu-experts") or settings.get("inference.gpu_experts", 1)
# KT-kernel options
final_kt_method = kt_method or model_defaults.get("kt-method") or settings.get("inference.kt_method", "AMXINT4")
final_kt_gpu_prefill_threshold = (
kt_gpu_prefill_token_threshold
or model_defaults.get("kt-gpu-prefill-token-threshold")
or settings.get("inference.kt_gpu_prefill_token_threshold", 4096)
)
# SGLang options
final_attention_backend = (
attention_backend
or model_defaults.get("attention-backend")
or settings.get("inference.attention_backend", "triton")
)
final_max_total_tokens = (
max_total_tokens or model_defaults.get("max-total-tokens") or settings.get("inference.max_total_tokens", 40000)
)
final_max_running_requests = (
max_running_requests
or model_defaults.get("max-running-requests")
or settings.get("inference.max_running_requests", 32)
)
final_chunked_prefill_size = (
chunked_prefill_size
or model_defaults.get("chunked-prefill-size")
or settings.get("inference.chunked_prefill_size", 4096)
)
final_mem_fraction_static = (
mem_fraction_static
or model_defaults.get("mem-fraction-static")
or settings.get("inference.mem_fraction_static", 0.98)
)
final_watchdog_timeout = (
watchdog_timeout or model_defaults.get("watchdog-timeout") or settings.get("inference.watchdog_timeout", 3000)
)
final_served_model_name = (
served_model_name or model_defaults.get("served-model-name") or settings.get("inference.served_model_name", "")
)
# Performance flags
if disable_shared_experts_fusion is not None:
final_disable_shared_experts_fusion = disable_shared_experts_fusion
elif "disable-shared-experts-fusion" in model_defaults:
final_disable_shared_experts_fusion = model_defaults["disable-shared-experts-fusion"]
else:
final_disable_shared_experts_fusion = settings.get("inference.disable_shared_experts_fusion", False)
# Pass all model default params to handle any extra parameters
extra_params = model_defaults if model_info else {}
cmd = _build_sglang_command(
model_path=resolved_model_path,
weights_path=resolved_weights_path,
model_info=model_info,
host=final_host,
port=final_port,
gpu_experts=final_gpu_experts,
cpu_threads=final_cpu_threads,
numa_nodes=final_numa_nodes,
tensor_parallel_size=final_tensor_parallel_size,
kt_method=final_kt_method,
kt_gpu_prefill_threshold=final_kt_gpu_prefill_threshold,
attention_backend=final_attention_backend,
max_total_tokens=final_max_total_tokens,
max_running_requests=final_max_running_requests,
chunked_prefill_size=final_chunked_prefill_size,
mem_fraction_static=final_mem_fraction_static,
watchdog_timeout=final_watchdog_timeout,
served_model_name=final_served_model_name,
disable_shared_experts_fusion=final_disable_shared_experts_fusion,
settings=settings,
extra_model_params=extra_params,
)
# Prepare environment variables
env = os.environ.copy()
# Add environment variables from advanced.env
env.update(settings.get_env_vars())
# Add environment variables from inference.env
inference_env = settings.get("inference.env", {})
if isinstance(inference_env, dict):
env.update({k: str(v) for k, v in inference_env.items()})
# Step 5: Show configuration summary
console.print()
print_step("Configuration")
# Model info
if model_info:
console.print(f" Model: [bold]{model_info.name}[/bold]")
else:
console.print(f" Model: [bold]{resolved_model_path.name}[/bold]")
console.print(f" Path: [dim]{resolved_model_path}[/dim]")
# Key parameters
console.print()
console.print(f" GPU Experts: [cyan]{final_gpu_experts}[/cyan] per layer")
console.print(f" CPU Threads (kt-cpuinfer): [cyan]{final_cpu_threads}[/cyan]")
console.print(f" NUMA Nodes (kt-threadpool-count): [cyan]{final_numa_nodes}[/cyan]")
console.print(f" Tensor Parallel: [cyan]{final_tensor_parallel_size}[/cyan]")
console.print(f" Method: [cyan]{final_kt_method}[/cyan]")
console.print(f" Attention: [cyan]{final_attention_backend}[/cyan]")
# Weights info
if resolved_weights_path:
console.print()
console.print(f" Quantized weights: [yellow]{resolved_weights_path}[/yellow]")
console.print()
console.print(f" Server: [green]http://{final_host}:{final_port}[/green]")
console.print()
# Step 6: Show or execute
if dry_run:
console.print()
console.print("[bold]Command:[/bold]")
console.print()
console.print(f" [dim]{' '.join(cmd)}[/dim]")
console.print()
return
# Execute with prepared environment variables
# Don't print "Server started" or API info here - let sglang's logs speak for themselves
# The actual startup takes time and these messages are misleading
# Print the command being executed
console.print()
console.print("[bold]Launching server with command:[/bold]")
console.print()
console.print(f" [dim]{' '.join(cmd)}[/dim]")
console.print()
try:
# Execute directly without intercepting output or signals
# This allows direct output to terminal and Ctrl+C to work naturally
process = subprocess.run(cmd, env=env)
sys.exit(process.returncode)
except FileNotFoundError:
from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions
print_error(t("sglang_not_found"))
console.print()
print_sglang_install_instructions()
raise typer.Exit(1)
except Exception as e:
print_error(f"Failed to start server: {e}")
raise typer.Exit(1)
def _find_model_path(model_info: ModelInfo, settings) -> Optional[Path]:
"""Find the model path on disk by searching all configured model paths."""
model_paths = settings.get_model_paths()
# Search in all configured model directories
for models_dir in model_paths:
# Check common path patterns
possible_paths = [
models_dir / model_info.name,
models_dir / model_info.name.lower(),
models_dir / model_info.name.replace(" ", "-"),
models_dir / model_info.hf_repo.split("/")[-1],
models_dir / model_info.hf_repo.replace("/", "--"),
]
# Add alias-based paths
for alias in model_info.aliases:
possible_paths.append(models_dir / alias)
possible_paths.append(models_dir / alias.lower())
for path in possible_paths:
if path.exists() and (path / "config.json").exists():
return path
return None
def _find_weights_path(model_info: ModelInfo, settings) -> Optional[Path]:
"""Find the quantized weights path on disk by searching all configured paths."""
model_paths = settings.get_model_paths()
weights_dir = settings.weights_dir
# Check common patterns
base_names = [
model_info.name,
model_info.name.lower(),
model_info.hf_repo.split("/")[-1],
]
suffixes = ["-INT4", "-int4", "_INT4", "_int4", "-quant", "-quantized"]
# Prepare search directories
search_dirs = [weights_dir] if weights_dir else []
search_dirs.extend(model_paths)
for base in base_names:
for suffix in suffixes:
for dir_path in search_dirs:
if dir_path:
path = dir_path / f"{base}{suffix}"
if path.exists():
return path
return None
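# Illustrative shape of the command assembled below (paths and values are
# examples, not defaults):
#
#   python -m sglang.launch_server --host 0.0.0.0 --port 30000 \
#       --model /models/MiniMax-M2.1 \
#       --kt-weight-path /models/MiniMax-M2.1-INT4 --kt-cpuinfer 64 \
#       --kt-threadpool-count 2 --kt-num-gpu-experts 1 --kt-method AMXINT4 \
#       --kt-gpu-prefill-token-threshold 4096 \
#       --attention-backend triton --trust-remote-code ...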
def _build_sglang_command(
model_path: Path,
weights_path: Optional[Path],
model_info: Optional[ModelInfo],
host: str,
port: int,
gpu_experts: int,
cpu_threads: int,
numa_nodes: int,
tensor_parallel_size: int,
kt_method: str,
kt_gpu_prefill_threshold: int,
attention_backend: str,
max_total_tokens: int,
max_running_requests: int,
chunked_prefill_size: int,
mem_fraction_static: float,
watchdog_timeout: int,
served_model_name: str,
disable_shared_experts_fusion: bool,
settings,
    extra_model_params: Optional[dict] = None,  # model-default params not covered by explicit arguments
) -> list[str]:
"""Build the SGLang launch command."""
cmd = [
sys.executable,
"-m",
"sglang.launch_server",
"--host",
host,
"--port",
str(port),
"--model",
str(model_path),
]
# Add kt-kernel options
# kt-kernel is needed for:
# 1. Quantized models (when weights_path is provided)
    # 2. MoE models with CPU offloading (when kt-cpuinfer > 0 or kt-num-gpu-experts > 1)
use_kt_kernel = False
# Check if we should use kt-kernel
if weights_path:
# Quantized model - always use kt-kernel
use_kt_kernel = True
elif cpu_threads > 0 or gpu_experts > 1:
# CPU offloading configured - use kt-kernel
use_kt_kernel = True
elif model_info and model_info.type == "moe":
# MoE model - likely needs kt-kernel for expert offloading
use_kt_kernel = True
if use_kt_kernel:
# Add kt-weight-path: use quantized weights if available, otherwise use model path
weight_path_to_use = weights_path if weights_path else model_path
# Add kt-kernel configuration
cmd.extend(
[
"--kt-weight-path",
str(weight_path_to_use),
"--kt-cpuinfer",
str(cpu_threads),
"--kt-threadpool-count",
str(numa_nodes),
"--kt-num-gpu-experts",
str(gpu_experts),
"--kt-method",
kt_method,
"--kt-gpu-prefill-token-threshold",
str(kt_gpu_prefill_threshold),
]
)
# Add SGLang options
cmd.extend(
[
"--attention-backend",
attention_backend,
"--trust-remote-code",
"--mem-fraction-static",
str(mem_fraction_static),
"--chunked-prefill-size",
str(chunked_prefill_size),
"--max-running-requests",
str(max_running_requests),
"--max-total-tokens",
str(max_total_tokens),
"--watchdog-timeout",
str(watchdog_timeout),
"--enable-mixed-chunk",
"--tensor-parallel-size",
str(tensor_parallel_size),
"--enable-p2p-check",
]
)
# Add served model name if specified
if served_model_name:
cmd.extend(["--served-model-name", served_model_name])
# Add performance flags
if disable_shared_experts_fusion:
cmd.append("--disable-shared-experts-fusion")
# Add any extra parameters from model defaults that weren't explicitly handled
if extra_model_params:
# List of parameters already handled above
handled_params = {
"kt-num-gpu-experts",
"kt-cpuinfer",
"kt-threadpool-count",
"kt-method",
"kt-gpu-prefill-token-threshold",
"attention-backend",
"tensor-parallel-size",
"max-total-tokens",
"max-running-requests",
"chunked-prefill-size",
"mem-fraction-static",
"watchdog-timeout",
"served-model-name",
"disable-shared-experts-fusion",
}
for key, value in extra_model_params.items():
    if key in handled_params:
        continue
    # Add unhandled parameters dynamically
    if isinstance(value, bool):
        # Boolean values become bare flags; a False value omits the flag entirely
        if value:
            cmd.append(f"--{key}")
    else:
        cmd.extend([f"--{key}", str(value)])
# Add extra args from settings
extra_args = settings.get("advanced.sglang_args", [])
if extra_args:
cmd.extend(extra_args)
return cmd
def _interactive_model_selection(registry, settings) -> Optional[str]:
"""Show interactive model selection interface.
Returns:
Selected model name or None if cancelled.
"""
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt
from kt_kernel.cli.i18n import get_lang
lang = get_lang()
# Find local models first
local_models = registry.find_local_models()
# Get all registered models
all_models = registry.list_all()
console.print()
console.print(
Panel.fit(
t("run_select_model_title"),
border_style="cyan",
)
)
console.print()
# Build choices list
choices = []
choice_map = {} # index -> model name
# Section 1: Local models (downloaded)
if local_models:
console.print(f"[bold green]{t('run_local_models')}[/bold green]")
console.print()
for i, (model_info, path) in enumerate(local_models, 1):
desc = model_info.description_zh if lang == "zh" else model_info.description
short_desc = desc[:50] + "..." if len(desc) > 50 else desc
console.print(f" [cyan][{i}][/cyan] [bold]{model_info.name}[/bold]")
console.print(f" [dim]{short_desc}[/dim]")
console.print(f" [dim]{path}[/dim]")
choices.append(str(i))
choice_map[str(i)] = model_info.name
console.print()
# Section 2: All registered models (for reference)
start_idx = len(local_models) + 1
console.print(f"[bold yellow]{t('run_registered_models')}[/bold yellow]")
console.print()
# Filter out already shown local models
local_model_names = {m.name for m, _ in local_models}
# Use an explicit counter: enumerate() would burn indices on skipped
# (already-local) models, leaving gaps that desynchronize the displayed
# numbers from cancel_idx computed below as len(choices) + 1
idx = start_idx
for model_info in all_models:
    if model_info.name in local_model_names:
        continue
    desc = model_info.description_zh if lang == "zh" else model_info.description
    short_desc = desc[:50] + "..." if len(desc) > 50 else desc
    console.print(f" [cyan][{idx}][/cyan] [bold]{model_info.name}[/bold]")
    console.print(f" [dim]{short_desc}[/dim]")
    console.print(f" [dim]{model_info.hf_repo}[/dim]")
    choices.append(str(idx))
    choice_map[str(idx)] = model_info.name
    idx += 1
console.print()
# Add cancel option
cancel_idx = str(len(choices) + 1)
console.print(f" [cyan][{cancel_idx}][/cyan] [dim]{t('cancel')}[/dim]")
choices.append(cancel_idx)
console.print()
# Prompt for selection
try:
selection = Prompt.ask(
t("run_select_model_prompt"),
choices=choices,
default="1" if choices else cancel_idx,
)
except KeyboardInterrupt:
console.print()
return None
if selection == cancel_idx:
return None
return choice_map.get(selection)
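The extra-parameter pass-through in `_build_sglang_command` follows the usual CLI convention: boolean values become bare flags (omitted entirely when False), while everything else becomes a `--key value` pair. A standalone sketch of that convention (the helper name is hypothetical, not part of kt-cli):

```python
def params_to_flags(params: dict, handled: set = frozenset()) -> list[str]:
    """Mirror the extra_model_params handling: booleans become bare
    flags (dropped when False); other values become --key value pairs."""
    cmd = []
    for key, value in params.items():
        if key in handled:
            continue
        if isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
        else:
            cmd.extend([f"--{key}", str(value)])
    return cmd


flags = params_to_flags(
    {"enable-ep-moe": True, "disable-radix-cache": False, "context-length": 65536},
    handled={"kt-method"},
)
# → ["--enable-ep-moe", "--context-length", "65536"]
```

Note that a False boolean produces nothing at all, matching SGLang's store-true style flags where the absence of the flag is the "off" state.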


@@ -0,0 +1,52 @@
"""
SFT command for kt-cli.
Fine-tuning with LlamaFactory integration.
"""
import typer
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console
app = typer.Typer(help="Fine-tuning with LlamaFactory (coming soon)")
@app.callback(invoke_without_command=True)
def callback(ctx: typer.Context) -> None:
"""Fine-tuning commands (coming soon)."""
if ctx.invoked_subcommand is None:
console.print()
console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
console.print()
console.print("[dim]kt sft train - Train a model[/dim]")
console.print("[dim]kt sft chat - Chat with a trained model[/dim]")
console.print("[dim]kt sft export - Export a trained model[/dim]")
console.print()
@app.command(name="train")
def train() -> None:
"""Train a model using LlamaFactory (coming soon)."""
console.print()
console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
console.print()
raise typer.Exit(0)
@app.command(name="chat")
def chat() -> None:
"""Chat with a trained model using LlamaFactory (coming soon)."""
console.print()
console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
console.print()
raise typer.Exit(0)
@app.command(name="export")
def export() -> None:
"""Export a trained model using LlamaFactory (coming soon)."""
console.print()
console.print(f"[yellow]{t('feature_coming_soon')}[/yellow]")
console.print()
raise typer.Exit(0)


@@ -0,0 +1,118 @@
"""
Version command for kt-cli.
Displays version information for kt-cli and related packages.
"""
import platform
from typing import Optional
import typer
from kt_kernel.cli import __version__
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console, print_version_table
from kt_kernel.cli.utils.environment import detect_cuda_version, get_installed_package_version
def _get_sglang_info() -> str:
"""Get sglang version and installation source information."""
try:
import sglang
version = getattr(sglang, "__version__", None)
if not version:
version = get_installed_package_version("sglang")
if not version:
return t("version_not_installed")
# Try to detect installation source
from pathlib import Path
import subprocess
if hasattr(sglang, "__file__") and sglang.__file__:
location = Path(sglang.__file__).parent.parent
git_dir = location / ".git"
if git_dir.exists():
# Installed from git (editable install)
try:
# Get remote URL
result = subprocess.run(
["git", "remote", "get-url", "origin"],
cwd=location,
capture_output=True,
text=True,
timeout=2,
)
if result.returncode == 0:
remote_url = result.stdout.strip()
# Simplify GitHub URLs
if "github.com" in remote_url:
repo_name = remote_url.split("/")[-1].replace(".git", "")
owner = remote_url.split("/")[-2]
return f"{version} [dim](GitHub: {owner}/{repo_name})[/dim]"
return f"{version} [dim](Git: {remote_url})[/dim]"
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
pass
# Default: installed from PyPI
return f"{version} [dim](PyPI)[/dim]"
except ImportError:
return t("version_not_installed")
def version(
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed version info"),
) -> None:
"""Show version information."""
console.print(f"\n[bold]{t('version_info')}[/bold] v{__version__}\n")
# Basic info
versions = {
t("version_python"): platform.python_version(),
t("version_platform"): f"{platform.system()} {platform.release()}",
}
# CUDA version
cuda_version = detect_cuda_version()
versions[t("version_cuda")] = cuda_version or t("version_cuda_not_found")
print_version_table(versions)
# Always show key packages with installation source
console.print("\n[bold]Packages:[/bold]\n")
sglang_info = _get_sglang_info()
key_packages = {
t("version_kt_kernel"): get_installed_package_version("kt-kernel") or t("version_not_installed"),
t("version_sglang"): sglang_info,
}
print_version_table(key_packages)
# Show SGLang installation hint if not installed
if sglang_info == t("version_not_installed"):
from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions
console.print()
print_sglang_install_instructions()
if verbose:
console.print("\n[bold]Additional Packages:[/bold]\n")
package_versions = {
t("version_ktransformers"): get_installed_package_version("ktransformers") or t("version_not_installed"),
t("version_llamafactory"): get_installed_package_version("llamafactory") or t("version_not_installed"),
"typer": get_installed_package_version("typer") or t("version_not_installed"),
"rich": get_installed_package_version("rich") or t("version_not_installed"),
"torch": get_installed_package_version("torch") or t("version_not_installed"),
"transformers": get_installed_package_version("transformers") or t("version_not_installed"),
}
print_version_table(package_versions)
console.print()
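The owner/repo simplification applied to GitHub remotes in `_get_sglang_info` can be exercised in isolation (the function name here is hypothetical; note it assumes HTTPS-style remotes, since SSH remotes like `git@github.com:owner/repo.git` do not split on `/` the same way):

```python
def simplify_github_url(remote_url: str) -> str:
    """Mirror the owner/repo extraction used for git-installed sglang."""
    repo_name = remote_url.split("/")[-1].replace(".git", "")
    owner = remote_url.split("/")[-2]
    return f"{owner}/{repo_name}"


assert simplify_github_url("https://github.com/kvcache-ai/sglang.git") == "kvcache-ai/sglang"
assert simplify_github_url("https://github.com/kvcache-ai/sglang") == "kvcache-ai/sglang"
```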


@@ -0,0 +1 @@
"""Shell completion scripts for kt-cli."""


@@ -0,0 +1,153 @@
#compdef kt
# Zsh completion for kt command
# This is a static completion script that doesn't require Python startup
_kt() {
local -a commands
commands=(
'version:Show version information'
'run:Start model inference server'
'chat:Interactive chat with running model'
'quant:Quantize model weights'
'bench:Run full benchmark'
'microbench:Run micro-benchmark'
'doctor:Diagnose environment issues'
'model:Manage models and storage paths'
'config:Manage configuration'
'sft:Fine-tuning with LlamaFactory'
)
local -a run_opts
run_opts=(
'--host[Server host]:host:'
'--port[Server port]:port:'
'--gpu-experts[Number of GPU experts]:count:'
'--cpu-threads[Number of CPU threads]:count:'
'--tensor-parallel-size[Tensor parallel size]:size:'
'--kt-method[KT method]:method:(AMXINT4 FP8 RAWINT4)'
'--attention-backend[Attention backend]:backend:(triton flashinfer)'
'--max-total-tokens[Maximum total tokens]:tokens:'
'--dry-run[Show command without executing]'
'--help[Show help message]'
)
local -a chat_opts
chat_opts=(
'--host[Server host]:host:'
'--port[Server port]:port:'
'--model[Model name]:model:'
'--temperature[Sampling temperature]:temp:'
'--max-tokens[Maximum tokens]:tokens:'
'--system[System prompt]:prompt:'
'--save-history[Save conversation history]'
'--no-save-history[Do not save history]'
'--history-file[History file path]:path:_files'
'--stream[Enable streaming output]'
'--no-stream[Disable streaming output]'
'--help[Show help message]'
)
local -a model_cmds
model_cmds=(
'download:Download a model from HuggingFace'
'list:List available models'
'path-list:List all model storage paths'
'path-add:Add a new model storage path'
'path-remove:Remove a model storage path'
'search:Search for models in the registry'
)
local -a config_cmds
config_cmds=(
'show:Show all configuration'
'get:Get configuration value'
'set:Set configuration value'
'reset:Reset to defaults'
'path:Show configuration file path'
'init:Re-run first-time setup wizard'
)
local -a sft_cmds
sft_cmds=(
'train:Train model'
'chat:Chat with model'
'export:Export model'
)
_arguments -C \
'1: :->command' \
'*::arg:->args'
case $state in
command)
_describe 'kt commands' commands
_arguments \
'--help[Show help message]' \
'--version[Show version]'
;;
args)
case $words[1] in
run)
_arguments $run_opts \
'1:model:'
;;
chat)
_arguments $chat_opts
;;
quant)
_arguments \
'--method[Quantization method]:method:' \
'--output[Output directory]:path:_files -/' \
'--help[Show help message]' \
'1:model:_files -/'
;;
bench|microbench)
_arguments \
'--model[Model name or path]:model:' \
'--config[Config file path]:path:_files' \
'--help[Show help message]'
;;
doctor)
_arguments \
'--verbose[Verbose output]' \
'--help[Show help message]'
;;
model)
_arguments \
'1: :->model_cmd' \
'*::arg:->model_args'
case $state in
model_cmd)
_describe 'model commands' model_cmds
;;
esac
;;
config)
_arguments \
'1: :->config_cmd' \
'*::arg:->config_args'
case $state in
config_cmd)
_describe 'config commands' config_cmds
;;
esac
;;
sft)
_arguments \
'1: :->sft_cmd' \
'*::arg:->sft_args'
case $state in
sft_cmd)
_describe 'sft commands' sft_cmds
;;
esac
;;
esac
;;
esac
}
_kt "$@"


@@ -0,0 +1,73 @@
#!/bin/bash
# Bash completion for kt command
# This is a static completion script that doesn't require Python startup
_kt_completion() {
local cur prev opts
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
# Main commands
local commands="version run chat quant bench microbench doctor model config sft"
# Global options
local global_opts="--help --version"
# Handle subcommands
case "${COMP_CWORD}" in
1)
# First argument: suggest commands and global options
COMPREPLY=( $(compgen -W "${commands} ${global_opts}" -- ${cur}) )
return 0
;;
*)
# Handle specific command options
case "${COMP_WORDS[1]}" in
run)
local run_opts="--host --port --gpu-experts --cpu-threads --tensor-parallel-size --kt-method --attention-backend --max-total-tokens --dry-run --help"
COMPREPLY=( $(compgen -W "${run_opts}" -- ${cur}) )
;;
chat)
local chat_opts="--host --port --model --temperature --max-tokens --system --save-history --no-save-history --history-file --stream --no-stream --help"
COMPREPLY=( $(compgen -W "${chat_opts}" -- ${cur}) )
;;
quant)
local quant_opts="--method --output --help"
COMPREPLY=( $(compgen -W "${quant_opts}" -- ${cur}) )
;;
bench|microbench)
local bench_opts="--model --config --help"
COMPREPLY=( $(compgen -W "${bench_opts}" -- ${cur}) )
;;
doctor)
local doctor_opts="--verbose --help"
COMPREPLY=( $(compgen -W "${doctor_opts}" -- ${cur}) )
;;
model)
local model_cmds="download list path-list path-add path-remove search"
local model_opts="--help"
COMPREPLY=( $(compgen -W "${model_cmds} ${model_opts}" -- ${cur}) )
;;
config)
local config_cmds="show get set reset path init model-path-list model-path-add model-path-remove"
local config_opts="--help"
COMPREPLY=( $(compgen -W "${config_cmds} ${config_opts}" -- ${cur}) )
;;
sft)
local sft_cmds="train chat export"
local sft_opts="--help"
COMPREPLY=( $(compgen -W "${sft_cmds} ${sft_opts}" -- ${cur}) )
;;
version)
COMPREPLY=( $(compgen -W "--help" -- ${cur}) )
;;
*)
COMPREPLY=()
;;
esac
;;
esac
}
complete -F _kt_completion kt
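The script can be tried in a running bash session by sourcing it (`source kt-completion.bash`, filename hypothetical). The filtering it performs reduces to bash's `compgen` builtin, which every `COMPREPLY` assignment above relies on:

```shell
# compgen -W filters a word list against the current prefix,
# printing matches one per line in list order.
commands="version run chat quant bench microbench doctor model config sft"
compgen -W "${commands}" -- "c"
# prints the candidates starting with "c": chat, config
```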


@@ -0,0 +1,74 @@
# Fish completion for kt command
# This is a static completion script that doesn't require Python startup
# Main commands
complete -c kt -f -n "__fish_use_subcommand" -a "version" -d "Show version information"
complete -c kt -f -n "__fish_use_subcommand" -a "run" -d "Start model inference server"
complete -c kt -f -n "__fish_use_subcommand" -a "chat" -d "Interactive chat with running model"
complete -c kt -f -n "__fish_use_subcommand" -a "quant" -d "Quantize model weights"
complete -c kt -f -n "__fish_use_subcommand" -a "bench" -d "Run full benchmark"
complete -c kt -f -n "__fish_use_subcommand" -a "microbench" -d "Run micro-benchmark"
complete -c kt -f -n "__fish_use_subcommand" -a "doctor" -d "Diagnose environment issues"
complete -c kt -f -n "__fish_use_subcommand" -a "model" -d "Manage models and storage paths"
complete -c kt -f -n "__fish_use_subcommand" -a "config" -d "Manage configuration"
complete -c kt -f -n "__fish_use_subcommand" -a "sft" -d "Fine-tuning with LlamaFactory"
# Global options
complete -c kt -l help -d "Show help message"
complete -c kt -l version -d "Show version"
# Run command options
complete -c kt -f -n "__fish_seen_subcommand_from run" -l host -d "Server host"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l port -d "Server port"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l gpu-experts -d "Number of GPU experts"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l cpu-threads -d "Number of CPU threads"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l tensor-parallel-size -d "Tensor parallel size"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l kt-method -d "KT method"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l attention-backend -d "Attention backend"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l max-total-tokens -d "Maximum total tokens"
complete -c kt -f -n "__fish_seen_subcommand_from run" -l dry-run -d "Show command without executing"
# Chat command options
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l host -d "Server host"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l port -d "Server port"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l model -d "Model name"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l temperature -d "Sampling temperature"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l max-tokens -d "Maximum tokens"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l system -d "System prompt"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l save-history -d "Save conversation history"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-save-history -d "Do not save history"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l history-file -d "History file path"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l stream -d "Enable streaming output"
complete -c kt -f -n "__fish_seen_subcommand_from chat" -l no-stream -d "Disable streaming output"
# Quant command options
complete -c kt -f -n "__fish_seen_subcommand_from quant" -l method -d "Quantization method"
complete -c kt -f -n "__fish_seen_subcommand_from quant" -l output -d "Output directory"
# Bench command options
complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l model -d "Model name or path"
complete -c kt -f -n "__fish_seen_subcommand_from bench microbench" -l config -d "Config file path"
# Doctor command options
complete -c kt -f -n "__fish_seen_subcommand_from doctor" -l verbose -d "Verbose output"
# Model subcommands
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "download" -d "Download a model from HuggingFace"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "list" -d "List available models"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-list" -d "List all model storage paths"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-add" -d "Add a new model storage path"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "path-remove" -d "Remove a model storage path"
complete -c kt -f -n "__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search" -a "search" -d "Search for models in the registry"
# Config subcommands
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "show" -d "Show all configuration"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "get" -d "Get configuration value"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "set" -d "Set configuration value"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "reset" -d "Reset to defaults"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "path" -d "Show configuration file path"
complete -c kt -f -n "__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init" -a "init" -d "Re-run first-time setup wizard"
# SFT subcommands
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "train" -d "Train model"
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "chat" -d "Chat with model"
complete -c kt -f -n "__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export" -a "export" -d "Export model"

View File

@@ -0,0 +1,7 @@
"""
Configuration management for kt-cli.
"""
from kt_kernel.cli.config.settings import Settings, get_settings
__all__ = ["Settings", "get_settings"]


@@ -0,0 +1,311 @@
"""
Configuration management for kt-cli.
Handles reading and writing configuration from ~/.ktransformers/config.yaml
"""
import os
from pathlib import Path
from typing import Any, Optional
import yaml
# Default configuration directory
DEFAULT_CONFIG_DIR = Path.home() / ".ktransformers"
DEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / "config.yaml"
DEFAULT_MODELS_DIR = DEFAULT_CONFIG_DIR / "models"
DEFAULT_CACHE_DIR = DEFAULT_CONFIG_DIR / "cache"
# Default configuration values
DEFAULT_CONFIG = {
"general": {
"language": "auto", # auto, en, zh
"color": True,
"verbose": False,
},
"paths": {
"models": str(DEFAULT_MODELS_DIR),
"cache": str(DEFAULT_CACHE_DIR),
"weights": "", # Custom quantized weights path
},
"server": {
"host": "0.0.0.0",
"port": 30000,
},
"inference": {
# Inference parameters are model-specific and should not have defaults
# They will be auto-detected or use model-specific optimizations
# Environment variables (general optimizations)
"env": {
"PYTORCH_ALLOC_CONF": "expandable_segments:True",
"SGLANG_ENABLE_JIT_DEEPGEMM": "0",
},
},
"download": {
"mirror": "", # HuggingFace mirror URL
"resume": True,
"verify": True,
},
"advanced": {
# Environment variables to set when running
"env": {},
# Extra arguments to pass to sglang
"sglang_args": [],
# Extra arguments to pass to llamafactory
"llamafactory_args": [],
},
"dependencies": {
# SGLang installation source configuration
"sglang": {
"source": "github", # "pypi" or "github"
"repo": "https://github.com/kvcache-ai/sglang",
"branch": "main",
},
},
}
class Settings:
"""Configuration manager for kt-cli."""
def __init__(self, config_path: Optional[Path] = None):
"""Initialize settings manager.
Args:
config_path: Path to config file. Defaults to ~/.ktransformers/config.yaml
"""
self.config_path = config_path or DEFAULT_CONFIG_FILE
self.config_dir = self.config_path.parent
self._config: dict[str, Any] = {}
self._load()
def _ensure_dirs(self) -> None:
"""Ensure configuration directories exist."""
self.config_dir.mkdir(parents=True, exist_ok=True)
# Ensure all model paths exist
model_paths = self.get_model_paths()
for path in model_paths:
path.mkdir(parents=True, exist_ok=True)
Path(self.get("paths.cache", DEFAULT_CACHE_DIR)).mkdir(parents=True, exist_ok=True)
def _load(self) -> None:
"""Load configuration from file."""
self._config = self._deep_copy(DEFAULT_CONFIG)
if self.config_path.exists():
try:
with open(self.config_path, "r", encoding="utf-8") as f:
user_config = yaml.safe_load(f) or {}
self._deep_merge(self._config, user_config)
except (yaml.YAMLError, OSError) as e:
# Log warning but continue with defaults
print(f"Warning: Failed to load config: {e}")
self._ensure_dirs()
def _save(self) -> None:
"""Save configuration to file."""
self._ensure_dirs()
try:
with open(self.config_path, "w", encoding="utf-8") as f:
yaml.dump(self._config, f, default_flow_style=False, allow_unicode=True)
except OSError as e:
    raise RuntimeError(f"Failed to save config: {e}") from e
def _deep_copy(self, obj: Any) -> Any:
"""Create a deep copy of a nested dict."""
if isinstance(obj, dict):
return {k: self._deep_copy(v) for k, v in obj.items()}
if isinstance(obj, list):
return [self._deep_copy(item) for item in obj]
return obj
def _deep_merge(self, base: dict, override: dict) -> None:
"""Deep merge override into base."""
for key, value in override.items():
if key in base and isinstance(base[key], dict) and isinstance(value, dict):
self._deep_merge(base[key], value)
else:
base[key] = value
def get(self, key: str, default: Any = None) -> Any:
"""Get a configuration value by dot-separated key.
Args:
key: Dot-separated key path (e.g., "server.port")
default: Default value if key not found
Returns:
Configuration value or default
"""
parts = key.split(".")
value = self._config
for part in parts:
if isinstance(value, dict) and part in value:
value = value[part]
else:
return default
return value
def set(self, key: str, value: Any) -> None:
"""Set a configuration value by dot-separated key.
Args:
key: Dot-separated key path (e.g., "server.port")
value: Value to set
"""
parts = key.split(".")
config = self._config
# Navigate to parent
for part in parts[:-1]:
if part not in config:
config[part] = {}
config = config[part]
# Set value
config[parts[-1]] = value
self._save()
def delete(self, key: str) -> bool:
"""Delete a configuration value.
Args:
key: Dot-separated key path
Returns:
True if key was deleted, False if not found
"""
parts = key.split(".")
config = self._config
# Navigate to parent
for part in parts[:-1]:
if part not in config:
return False
config = config[part]
# Delete key
if parts[-1] in config:
del config[parts[-1]]
self._save()
return True
return False
def reset(self) -> None:
"""Reset configuration to defaults."""
self._config = self._deep_copy(DEFAULT_CONFIG)
self._save()
def get_all(self) -> dict[str, Any]:
"""Get all configuration values."""
return self._deep_copy(self._config)
def get_env_vars(self) -> dict[str, str]:
"""Get environment variables to set."""
env_vars = {}
# Get from advanced.env
advanced_env = self.get("advanced.env", {})
if isinstance(advanced_env, dict):
env_vars.update({k: str(v) for k, v in advanced_env.items()})
return env_vars
@property
def models_dir(self) -> Path:
"""Get the primary models directory path (for backward compatibility)."""
paths = self.get_model_paths()
return paths[0] if paths else Path(DEFAULT_MODELS_DIR)
def get_model_paths(self) -> list[Path]:
"""Get all model directory paths.
Returns a list of Path objects. Supports both:
- Single path: paths.models = "/path/to/models"
- Multiple paths: paths.models = ["/path/1", "/path/2"]
"""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
# Handle both string and list
if isinstance(models_config, str):
return [Path(models_config)]
elif isinstance(models_config, list):
return [Path(p) for p in models_config]
else:
return [Path(DEFAULT_MODELS_DIR)]
def add_model_path(self, path: str) -> None:
"""Add a new model path to the configuration."""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
# Convert to list if it's a string
if isinstance(models_config, str):
paths = [models_config]
elif isinstance(models_config, list):
paths = list(models_config)
else:
paths = []
# Add new path if not already present
if path not in paths:
paths.append(path)
self.set("paths.models", paths)
def remove_model_path(self, path: str) -> bool:
"""Remove a model path from the configuration.
Returns True if path was removed, False if not found.
"""
models_config = self.get("paths.models", DEFAULT_MODELS_DIR)
if isinstance(models_config, str):
    # A single string is the only configured path; never remove the last path
    return False
elif isinstance(models_config, list):
if path in models_config:
paths = list(models_config)
paths.remove(path)
# Don't allow removing all paths
if not paths:
return False
self.set("paths.models", paths if len(paths) > 1 else paths[0])
return True
return False
@property
def cache_dir(self) -> Path:
"""Get the cache directory path."""
return Path(self.get("paths.cache", DEFAULT_CACHE_DIR))
@property
def weights_dir(self) -> Optional[Path]:
"""Get the custom weights directory path."""
weights = self.get("paths.weights", "")
return Path(weights) if weights else None
# Global settings instance
_settings: Optional[Settings] = None
def get_settings() -> Settings:
"""Get the global settings instance."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def reset_settings() -> None:
"""Reset the global settings instance."""
global _settings
_settings = None
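The dot-separated key traversal that `Settings.get`/`Settings.set` perform can be exercised standalone; a minimal sketch (helper names hypothetical, not importing kt_kernel) mirroring the same logic:

```python
def dot_get(config: dict, key: str, default=None):
    """Walk a nested dict by dot-separated key, as Settings.get does."""
    value = config
    for part in key.split("."):
        if isinstance(value, dict) and part in value:
            value = value[part]
        else:
            return default
    return value


def dot_set(config: dict, key: str, value) -> None:
    """Create intermediate dicts as needed, as Settings.set does."""
    parts = key.split(".")
    node = config
    for part in parts[:-1]:
        node = node.setdefault(part, {})
    node[parts[-1]] = value


cfg = {"server": {"host": "0.0.0.0", "port": 30000}}
assert dot_get(cfg, "server.port") == 30000
assert dot_get(cfg, "inference.env.KT_LANG", "n/a") == "n/a"
dot_set(cfg, "advanced.env.KT_LANG", "en")
assert dot_get(cfg, "advanced.env.KT_LANG") == "en"
```

The sketch omits persistence; the real class additionally calls `_save()` after every mutation.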


@@ -0,0 +1,655 @@
"""
Internationalization (i18n) module for kt-cli.
Supports English and Chinese languages, with automatic detection based on
system locale or KT_LANG environment variable.
"""
import os
from typing import Any
# Message definitions for all supported languages
MESSAGES: dict[str, dict[str, str]] = {
"en": {
# General
"welcome": "Welcome to KTransformers!",
"goodbye": "Goodbye!",
"error": "Error",
"warning": "Warning",
"success": "Success",
"info": "Info",
"yes": "Yes",
"no": "No",
"cancel": "Cancel",
"confirm": "Confirm",
"done": "Done",
"failed": "Failed",
"skip": "Skip",
"back": "Back",
"next": "Next",
"retry": "Retry",
"abort": "Abort",
# Version command
"version_info": "KTransformers CLI",
"version_python": "Python",
"version_platform": "Platform",
"version_cuda": "CUDA",
"version_cuda_not_found": "Not found",
"version_kt_kernel": "kt-kernel",
"version_ktransformers": "ktransformers",
"version_sglang": "sglang",
"version_llamafactory": "llamafactory",
"version_not_installed": "Not installed",
# Install command
"install_detecting_env": "Detecting environment managers...",
"install_found": "Found {name} (version {version})",
"install_not_found": "Not found: {name}",
"install_checking_env": "Checking existing environments...",
"install_env_exists": "Found existing 'kt' environment",
"install_env_not_exists": "No 'kt' environment found",
"install_no_env_manager": "No virtual environment manager detected",
"install_select_method": "Please select installation method:",
"install_method_conda": "Create new conda environment 'kt' (Recommended)",
"install_method_venv": "Create new venv environment",
"install_method_uv": "Create new uv environment (Fast)",
"install_method_docker": "Use Docker container",
"install_method_system": "Install to system Python (Not recommended)",
"install_select_mode": "Please select installation mode:",
"install_mode_inference": "Inference - Install kt-kernel + SGLang",
"install_mode_sft": "Training - Install kt-sft + LlamaFactory",
"install_mode_full": "Full - Install all components",
"install_creating_env": "Creating {type} environment '{name}'...",
"install_env_created": "Environment created successfully",
"install_installing_deps": "Installing dependencies...",
"install_checking_deps": "Checking dependency versions...",
"install_dep_ok": "OK",
"install_dep_outdated": "Needs update",
"install_dep_missing": "Missing",
"install_installing_pytorch": "Installing PyTorch...",
"install_installing_from_requirements": "Installing from requirements file...",
"install_deps_outdated": "Found {count} package(s) that need updating. Continue?",
"install_updating": "Updating packages...",
"install_complete": "Installation complete!",
"install_activate_hint": "Activate environment: {command}",
"install_start_hint": "Get started: kt run --help",
"install_docker_pulling": "Pulling Docker image...",
"install_docker_complete": "Docker image ready!",
"install_docker_run_hint": "Run with: docker run --gpus all -p 30000:30000 {image} kt run {model}",
"install_in_venv": "Running in virtual environment: {name}",
"install_continue_without_venv": "Continue installing to system Python?",
"install_already_installed": "All dependencies are already installed!",
"install_confirm": "Install {count} package(s)?",
# Install - System dependencies
"install_checking_system_deps": "Checking system dependencies...",
"install_dep_name": "Dependency",
"install_dep_status": "Status",
"install_deps_all_installed": "All system dependencies are installed",
"install_deps_install_prompt": "Install missing dependencies?",
"install_installing_system_deps": "Installing system dependencies...",
"install_installing_dep": "Installing {name}",
"install_dep_no_install_cmd": "No install command available for {name} on {os}",
"install_dep_install_failed": "Failed to install {name}",
"install_deps_skipped": "Skipping dependency installation",
"install_deps_failed": "Failed to install system dependencies",
# Install - CPU detection
"install_auto_detect_cpu": "Auto-detecting CPU capabilities...",
"install_cpu_features": "Detected CPU features: {features}",
"install_cpu_no_features": "No advanced CPU features detected",
# Install - Build configuration
"install_build_config": "Build Configuration:",
"install_native_warning": "Note: Binary optimized for THIS CPU only (not portable)",
"install_building_from_source": "Building kt-kernel from source...",
"install_build_failed": "Build failed",
"install_build_success": "Build completed successfully",
# Install - Verification
"install_verifying": "Verifying installation...",
"install_verify_success": "kt-kernel {version} ({variant} variant) installed successfully",
"install_verify_failed": "Verification failed: {error}",
# Install - Docker
"install_docker_guide_title": "Docker Installation",
"install_docker_guide_desc": "For Docker installation, please refer to the official guide:",
# Config command
"config_show_title": "Current Configuration",
"config_set_success": "Configuration updated: {key} = {value}",
"config_get_value": "{key} = {value}",
"config_get_not_found": "Configuration key '{key}' not found",
"config_reset_confirm": "This will reset all configurations to default. Continue?",
"config_reset_success": "Configuration reset to default",
"config_file_location": "Configuration file: {path}",
# Doctor command
"doctor_title": "KTransformers Environment Diagnostics",
"doctor_checking": "Running diagnostics...",
"doctor_check_python": "Python version",
"doctor_check_cuda": "CUDA availability",
"doctor_check_gpu": "GPU detection",
"doctor_check_cpu": "CPU",
"doctor_check_cpu_isa": "CPU Instructions",
"doctor_check_numa": "NUMA Topology",
"doctor_check_memory": "System memory",
"doctor_check_disk": "Disk space",
"doctor_check_packages": "Required packages",
"doctor_check_env": "Environment variables",
"doctor_status_ok": "OK",
"doctor_status_warning": "Warning",
"doctor_status_error": "Error",
"doctor_gpu_found": "Found {count} GPU(s): {names}",
"doctor_gpu_not_found": "No GPU detected",
"doctor_cpu_info": "{name} ({cores} cores / {threads} threads)",
"doctor_cpu_isa_info": "{isa_list}",
"doctor_cpu_isa_missing": "Missing recommended: {missing}",
"doctor_numa_info": "{nodes} node(s)",
"doctor_numa_detail": "{node}: CPUs {cpus}",
"doctor_memory_info": "{available} available / {total} total",
"doctor_memory_freq": "{available} available / {total} total ({freq}MHz {type})",
"doctor_disk_info": "{available} available at {path}",
"doctor_all_ok": "All checks passed! Your environment is ready.",
"doctor_has_issues": "Some issues were found. Please review the warnings/errors above.",
# Run command
"run_detecting_hardware": "Detecting hardware configuration...",
"run_gpu_info": "GPU: {name} ({vram}GB VRAM)",
"run_cpu_info": "CPU: {name} ({cores} cores, {numa} NUMA nodes)",
"run_ram_info": "RAM: {total}GB",
"run_checking_model": "Checking model status...",
"run_model_path": "Model path: {path}",
"run_weights_not_found": "Quantized weights not found",
"run_quant_prompt": "Quantize model now? (This may take a while)",
"run_quantizing": "Quantizing model...",
"run_starting_server": "Starting server...",
"run_server_mode": "Mode: SGLang + kt-kernel",
"run_server_port": "Port: {port}",
"run_gpu_experts": "GPU experts: {count}/layer",
"run_cpu_threads": "CPU threads: {count}",
"run_server_started": "Server started!",
"run_api_url": "API URL: http://{host}:{port}",
"run_docs_url": "Docs URL: http://{host}:{port}/docs",
"run_stop_hint": "Press Ctrl+C to stop the server",
"run_model_not_found": "Model '{name}' not found. Run 'kt download' first.",
"run_multiple_matches": "Multiple models found. Please select:",
"run_select_model": "Select model",
"run_select_model_title": "Select a model to run",
"run_select_model_prompt": "Enter number",
"run_local_models": "Local Models (Downloaded)",
"run_registered_models": "Registered Models",
# Download command
"download_list_title": "Available Models",
"download_searching": "Searching for model '{name}'...",
"download_found": "Found: {name}",
"download_multiple_found": "Multiple matches found:",
"download_select": "Select model to download:",
"download_destination": "Destination: {path}",
"download_starting": "Starting download...",
"download_progress": "Downloading {name}...",
"download_complete": "Download complete!",
"download_already_exists": "Model already exists at {path}",
"download_overwrite_prompt": "Overwrite existing files?",
# Quant command
"quant_input_path": "Input path: {path}",
"quant_output_path": "Output path: {path}",
"quant_method": "Quantization method: {method}",
"quant_starting": "Starting quantization...",
"quant_progress": "Quantizing...",
"quant_complete": "Quantization complete!",
"quant_input_not_found": "Input model not found at {path}",
# SFT command
"sft_mode_train": "Training mode",
"sft_mode_chat": "Chat mode",
"sft_mode_export": "Export mode",
"sft_config_path": "Config file: {path}",
"sft_starting": "Starting {mode}...",
"sft_complete": "{mode} complete!",
"sft_config_not_found": "Config file not found: {path}",
# Bench command
"bench_starting": "Starting benchmark...",
"bench_type": "Benchmark type: {type}",
"bench_complete": "Benchmark complete!",
"bench_results_title": "Benchmark Results",
# Common prompts
"prompt_continue": "Continue?",
"prompt_select": "Please select:",
"prompt_enter_value": "Enter value:",
"prompt_confirm_action": "Confirm this action?",
# First-run setup - Model path selection
"setup_model_path_title": "Model Storage Location",
"setup_model_path_desc": "LLM models are large (50-200GB+). Please select a storage location with sufficient space:",
"setup_scanning_disks": "Scanning available storage locations...",
"setup_disk_option": "{path} ({available} available / {total} total)",
"setup_disk_option_recommended": "{path} ({available} available / {total} total) [Recommended]",
"setup_custom_path": "Enter custom path",
"setup_enter_custom_path": "Enter the path for model storage",
"setup_path_not_exist": "Path does not exist. Create it?",
"setup_path_no_write": "No write permission for this path. Please choose another.",
"setup_path_low_space": "Warning: Less than 100GB available. Large models may not fit.",
"setup_model_path_set": "Model storage path set to: {path}",
"setup_no_large_disk": "No large storage locations found. Using default path.",
"setup_scanning_models": "Scanning for existing models...",
"setup_found_models": "Found {count} model(s):",
"setup_model_info": "{name} ({size}, {type})",
"setup_no_models_found": "No existing models found in this location.",
"setup_location_has_models": "{count} model(s) found",
"setup_installing_completion": "Installing shell completion for {shell}...",
"setup_completion_installed": "Shell completion installed! Restart terminal to enable.",
"setup_completion_failed": "Failed to install shell completion. Run 'kt --install-completion' manually.",
# Auto completion
"completion_installed_title": "Tab Completion",
"completion_installed_for": "Shell completion installed for {shell}",
"completion_activate_now": "To enable completion in this terminal session, run:",
"completion_next_session": "Completion will be automatically enabled in new terminal sessions.",
# SGLang
"sglang_not_found": "SGLang not found",
"sglang_pypi_warning": "SGLang from PyPI may not be compatible with kt-kernel",
"sglang_pypi_hint": 'SGLang from PyPI may not be compatible. Install from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_install_hint": 'Install SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_recommend_source": 'Recommend reinstalling from source: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_kt_kernel_not_supported": "SGLang does not support kt-kernel (missing --kt-gpu-prefill-token-threshold parameter)",
"sglang_checking_kt_kernel_support": "Checking SGLang kt-kernel support...",
"sglang_kt_kernel_supported": "SGLang kt-kernel support verified",
# Chat
"chat_proxy_detected": "Proxy detected in environment",
"chat_proxy_confirm": "Use proxy for connection?",
"chat_proxy_disabled": "Proxy disabled for this session",
# Model command
"model_supported_title": "KTransformers Supported Models",
"model_column_model": "Model",
"model_column_status": "Status",
"model_column_local_path": "Local Path",
"model_status_local": "Local",
"model_status_not_downloaded": "Not downloaded",
"model_usage_title": "Usage",
"model_usage_download": "Download a model:",
"model_usage_list_local": "List local models:",
"model_usage_search": "Search models:",
"model_storage_paths_title": "Model Storage Paths",
"model_local_models_title": "Locally Downloaded Models",
"model_available_models_title": "Available Models",
"model_no_local_models": "No locally downloaded models found",
"model_download_hint": "Download a model with:",
"model_download_usage_hint": "Usage: kt model download <model-name>",
"model_download_list_hint": "Use 'kt model download --list' to see available models.",
"model_download_hf_hint": "Or specify a HuggingFace repo directly: kt model download org/model-name",
"model_saved_to": "Model saved to: {path}",
"model_start_with": "Start with: kt run {name}",
"model_download_failed": "Download failed: {error}",
"model_hf_cli_not_found": "huggingface-cli not found. Install with: pip install huggingface-hub",
"model_path_not_exist": "Path does not exist: {path}",
"model_create_directory": "Create directory {path}?",
"model_created_directory": "Created directory: {path}",
"model_create_dir_failed": "Failed to create directory: {error}",
"model_path_added": "Added model path: {path}",
"model_path_removed": "Removed model path: {path}",
"model_path_not_found": "Path not found in configuration or cannot remove last path: {path}",
"model_search_no_results": "No models found matching '{query}'",
"model_search_results_title": "Search Results for '{query}'",
"model_column_name": "Name",
"model_column_hf_repo": "HuggingFace Repo",
"model_column_aliases": "Aliases",
# Coming soon
"feature_coming_soon": "This feature is coming soon...",
},
"zh": {
# General
"welcome": "欢迎使用 KTransformers",
"goodbye": "再见!",
"error": "错误",
"warning": "警告",
"success": "成功",
"info": "信息",
"yes": "",
"no": "",
"cancel": "取消",
"confirm": "确认",
"done": "完成",
"failed": "失败",
"skip": "跳过",
"back": "返回",
"next": "下一步",
"retry": "重试",
"abort": "中止",
# Version command
"version_info": "KTransformers CLI",
"version_python": "Python",
"version_platform": "平台",
"version_cuda": "CUDA",
"version_cuda_not_found": "未找到",
"version_kt_kernel": "kt-kernel",
"version_ktransformers": "ktransformers",
"version_sglang": "sglang",
"version_llamafactory": "llamafactory",
"version_not_installed": "未安装",
# Install command
"install_detecting_env": "检测环境管理工具...",
"install_found": "发现 {name} (版本 {version})",
"install_not_found": "未找到: {name}",
"install_checking_env": "检查现有环境...",
"install_env_exists": "发现现有 'kt' 环境",
"install_env_not_exists": "未发现 'kt' 环境",
"install_no_env_manager": "未检测到虚拟环境管理工具",
"install_select_method": "请选择安装方式:",
"install_method_conda": "创建新的 conda 环境 'kt' (推荐)",
"install_method_venv": "创建新的 venv 环境",
"install_method_uv": "创建新的 uv 环境 (快速)",
"install_method_docker": "使用 Docker 容器",
"install_method_system": "安装到系统 Python (不推荐)",
"install_select_mode": "请选择安装模式:",
"install_mode_inference": "推理模式 - 安装 kt-kernel + SGLang",
"install_mode_sft": "训练模式 - 安装 kt-sft + LlamaFactory",
"install_mode_full": "完整安装 - 安装所有组件",
"install_creating_env": "正在创建 {type} 环境 '{name}'...",
"install_env_created": "环境创建成功",
"install_installing_deps": "正在安装依赖...",
"install_checking_deps": "检查依赖版本...",
"install_dep_ok": "正常",
"install_dep_outdated": "需更新",
"install_dep_missing": "缺失",
"install_installing_pytorch": "正在安装 PyTorch...",
"install_installing_from_requirements": "从依赖文件安装...",
"install_deps_outdated": "发现 {count} 个包需要更新,是否继续?",
"install_updating": "正在更新包...",
"install_complete": "安装完成!",
"install_activate_hint": "激活环境: {command}",
"install_start_hint": "开始使用: kt run --help",
"install_docker_pulling": "正在拉取 Docker 镜像...",
"install_docker_complete": "Docker 镜像已就绪!",
"install_docker_run_hint": "运行: docker run --gpus all -p 30000:30000 {image} kt run {model}",
"install_in_venv": "当前在虚拟环境中: {name}",
"install_continue_without_venv": "继续安装到系统 Python",
"install_already_installed": "所有依赖已安装!",
"install_confirm": "安装 {count} 个包?",
# Install - System dependencies
"install_checking_system_deps": "检查系统依赖...",
"install_dep_name": "依赖项",
"install_dep_status": "状态",
"install_deps_all_installed": "所有系统依赖已安装",
"install_deps_install_prompt": "是否安装缺失的依赖?",
"install_installing_system_deps": "正在安装系统依赖...",
"install_installing_dep": "正在安装 {name}",
"install_dep_no_install_cmd": "{os} 系统上没有 {name} 的安装命令",
"install_dep_install_failed": "安装 {name} 失败",
"install_deps_skipped": "跳过依赖安装",
"install_deps_failed": "系统依赖安装失败",
# Install - CPU detection
"install_auto_detect_cpu": "正在自动检测 CPU 能力...",
"install_cpu_features": "检测到的 CPU 特性: {features}",
"install_cpu_no_features": "未检测到高级 CPU 特性",
# Install - Build configuration
"install_build_config": "构建配置:",
"install_native_warning": "注意: 二进制文件仅针对当前 CPU 优化(不可移植)",
"install_building_from_source": "正在从源码构建 kt-kernel...",
"install_build_failed": "构建失败",
"install_build_success": "构建成功",
# Install - Verification
"install_verifying": "正在验证安装...",
"install_verify_success": "kt-kernel {version} ({variant} 变体) 安装成功",
"install_verify_failed": "验证失败: {error}",
# Install - Docker
"install_docker_guide_title": "Docker 安装",
"install_docker_guide_desc": "有关 Docker 安装,请参阅官方指南:",
# Config command
"config_show_title": "当前配置",
"config_set_success": "配置已更新: {key} = {value}",
"config_get_value": "{key} = {value}",
"config_get_not_found": "未找到配置项 '{key}'",
"config_reset_confirm": "这将重置所有配置为默认值。是否继续?",
"config_reset_success": "配置已重置为默认值",
"config_file_location": "配置文件: {path}",
# Doctor command
"doctor_title": "KTransformers 环境诊断",
"doctor_checking": "正在运行诊断...",
"doctor_check_python": "Python 版本",
"doctor_check_cuda": "CUDA 可用性",
"doctor_check_gpu": "GPU 检测",
"doctor_check_cpu": "CPU",
"doctor_check_cpu_isa": "CPU 指令集",
"doctor_check_numa": "NUMA 拓扑",
"doctor_check_memory": "系统内存",
"doctor_check_disk": "磁盘空间",
"doctor_check_packages": "必需的包",
"doctor_check_env": "环境变量",
"doctor_status_ok": "正常",
"doctor_status_warning": "警告",
"doctor_status_error": "错误",
"doctor_gpu_found": "发现 {count} 个 GPU: {names}",
"doctor_gpu_not_found": "未检测到 GPU",
"doctor_cpu_info": "{name} ({cores} 核心 / {threads} 线程)",
"doctor_cpu_isa_info": "{isa_list}",
"doctor_cpu_isa_missing": "缺少推荐指令集: {missing}",
"doctor_numa_info": "{nodes} 个节点",
"doctor_numa_detail": "{node}: CPU {cpus}",
"doctor_memory_info": "{available} 可用 / {total} 总计",
"doctor_memory_freq": "{available} 可用 / {total} 总计 ({freq}MHz {type})",
"doctor_disk_info": "{path}{available} 可用空间",
"doctor_all_ok": "所有检查通过!您的环境已就绪。",
"doctor_has_issues": "发现一些问题,请查看上方的警告/错误信息。",
# Run command
"run_detecting_hardware": "检测硬件配置...",
"run_gpu_info": "GPU: {name} ({vram}GB 显存)",
"run_cpu_info": "CPU: {name} ({cores} 核心, {numa} NUMA 节点)",
"run_ram_info": "内存: {total}GB",
"run_checking_model": "检查模型状态...",
"run_model_path": "模型路径: {path}",
"run_weights_not_found": "未找到量化权重",
"run_quant_prompt": "是否现在量化模型?(这可能需要一些时间)",
"run_quantizing": "正在量化模型...",
"run_starting_server": "正在启动服务器...",
"run_server_mode": "模式: SGLang + kt-kernel",
"run_server_port": "端口: {port}",
"run_gpu_experts": "GPU 专家: {count}/层",
"run_cpu_threads": "CPU 线程: {count}",
"run_server_started": "服务器已启动!",
"run_api_url": "API 地址: http://{host}:{port}",
"run_docs_url": "文档地址: http://{host}:{port}/docs",
"run_stop_hint": "按 Ctrl+C 停止服务器",
"run_model_not_found": "未找到模型 '{name}'。请先运行 'kt download'",
"run_multiple_matches": "找到多个匹配的模型,请选择:",
"run_select_model": "选择模型",
"run_select_model_title": "选择要运行的模型",
"run_select_model_prompt": "输入编号",
"run_local_models": "本地模型 (已下载)",
"run_registered_models": "注册模型",
# Download command
"download_list_title": "可用模型",
"download_searching": "正在搜索模型 '{name}'...",
"download_found": "找到: {name}",
"download_multiple_found": "找到多个匹配:",
"download_select": "选择要下载的模型:",
"download_destination": "目标路径: {path}",
"download_starting": "开始下载...",
"download_progress": "正在下载 {name}...",
"download_complete": "下载完成!",
"download_already_exists": "模型已存在于 {path}",
"download_overwrite_prompt": "是否覆盖现有文件?",
# Quant command
"quant_input_path": "输入路径: {path}",
"quant_output_path": "输出路径: {path}",
"quant_method": "量化方法: {method}",
"quant_starting": "开始量化...",
"quant_progress": "正在量化...",
"quant_complete": "量化完成!",
"quant_input_not_found": "未找到输入模型: {path}",
# SFT command
"sft_mode_train": "训练模式",
"sft_mode_chat": "聊天模式",
"sft_mode_export": "导出模式",
"sft_config_path": "配置文件: {path}",
"sft_starting": "正在启动 {mode}...",
"sft_complete": "{mode} 完成!",
"sft_config_not_found": "未找到配置文件: {path}",
# Bench command
"bench_starting": "开始基准测试...",
"bench_type": "测试类型: {type}",
"bench_complete": "基准测试完成!",
"bench_results_title": "基准测试结果",
# Common prompts
"prompt_continue": "是否继续?",
"prompt_select": "请选择:",
"prompt_enter_value": "请输入:",
"prompt_confirm_action": "确认此操作?",
# First-run setup - Model path selection
"setup_model_path_title": "模型存储位置",
"setup_model_path_desc": "大语言模型体积较大50-200GB+)。请选择一个有足够空间的存储位置:",
"setup_scanning_disks": "正在扫描可用存储位置...",
"setup_disk_option": "{path} (可用 {available} / 总共 {total})",
"setup_disk_option_recommended": "{path} (可用 {available} / 总共 {total}) [推荐]",
"setup_custom_path": "输入自定义路径",
"setup_enter_custom_path": "请输入模型存储路径",
"setup_path_not_exist": "路径不存在,是否创建?",
"setup_path_no_write": "没有该路径的写入权限,请选择其他路径。",
"setup_path_low_space": "警告:可用空间不足 100GB可能无法存储大型模型。",
"setup_model_path_set": "模型存储路径已设置为: {path}",
"setup_no_large_disk": "未发现大容量存储位置,使用默认路径。",
"setup_scanning_models": "正在扫描已有模型...",
"setup_found_models": "发现 {count} 个模型:",
"setup_model_info": "{name} ({size}, {type})",
"setup_no_models_found": "该位置未发现已有模型。",
"setup_location_has_models": "发现 {count} 个模型",
"setup_installing_completion": "正在为 {shell} 安装命令补全...",
"setup_completion_installed": "命令补全已安装!重启终端后生效。",
"setup_completion_failed": "命令补全安装失败。请手动运行 'kt --install-completion'",
# Auto completion
"completion_installed_title": "命令补全",
"completion_installed_for": "已为 {shell} 安装命令补全",
"completion_activate_now": "在当前终端会话中启用补全,请运行:",
"completion_next_session": "新的终端会话将自动启用补全。",
# SGLang
"sglang_not_found": "未找到 SGLang",
"sglang_pypi_warning": "PyPI 版本的 SGLang 可能与 kt-kernel 不兼容",
"sglang_pypi_hint": 'PyPI 版本可能不兼容。从源码安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_install_hint": '安装 SGLang: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_recommend_source": '建议从源码重新安装: git clone https://github.com/kvcache-ai/sglang && cd sglang && pip install -e "python[all]"',
"sglang_kt_kernel_not_supported": "SGLang 不支持 kt-kernel (缺少 --kt-gpu-prefill-token-threshold 参数)",
"sglang_checking_kt_kernel_support": "正在检查 SGLang kt-kernel 支持...",
"sglang_kt_kernel_supported": "SGLang kt-kernel 支持已验证",
# Chat
"chat_proxy_detected": "检测到环境中存在代理设置",
"chat_proxy_confirm": "是否使用代理连接?",
"chat_proxy_disabled": "已在本次会话中禁用代理",
# Model command
"model_supported_title": "KTransformers 支持的模型",
"model_column_model": "模型",
"model_column_status": "状态",
"model_column_local_path": "本地路径",
"model_status_local": "本地",
"model_status_not_downloaded": "未下载",
"model_usage_title": "使用方法",
"model_usage_download": "下载模型:",
"model_usage_list_local": "列出本地模型:",
"model_usage_search": "搜索模型:",
"model_storage_paths_title": "模型存储路径",
"model_local_models_title": "本地已下载的模型",
"model_available_models_title": "可用模型",
"model_no_local_models": "未找到本地已下载的模型",
"model_download_hint": "下载模型:",
"model_download_usage_hint": "用法: kt model download <模型名称>",
"model_download_list_hint": "使用 'kt model download --list' 查看可用模型。",
"model_download_hf_hint": "或直接指定 HuggingFace 仓库: kt model download org/model-name",
"model_saved_to": "模型已保存到: {path}",
"model_start_with": "启动命令: kt run {name}",
"model_download_failed": "下载失败: {error}",
"model_hf_cli_not_found": "未找到 huggingface-cli。请安装: pip install huggingface-hub",
"model_path_not_exist": "路径不存在: {path}",
"model_create_directory": "创建目录 {path}",
"model_created_directory": "已创建目录: {path}",
"model_create_dir_failed": "创建目录失败: {error}",
"model_path_added": "已添加模型路径: {path}",
"model_path_removed": "已移除模型路径: {path}",
"model_path_not_found": "路径未找到或无法移除最后一个路径: {path}",
"model_search_no_results": "未找到匹配 '{query}' 的模型",
"model_search_results_title": "'{query}' 的搜索结果",
"model_column_name": "名称",
"model_column_hf_repo": "HuggingFace 仓库",
"model_column_aliases": "别名",
# Coming soon
"feature_coming_soon": "此功能即将推出...",
},
}
# Cache for language detection to avoid repeated I/O
_lang_cache: str | None = None
def get_lang() -> str:
"""
Detect the current language setting.
Priority:
1. KT_LANG environment variable
2. Config file general.language setting
3. LANG environment variable (if config is "auto")
4. Default to English
Returns:
Language code: "zh" for Chinese, "en" for English
"""
global _lang_cache
# 1. Check KT_LANG environment variable (highest priority)
kt_lang = os.environ.get("KT_LANG", "").lower()
if kt_lang:
return "zh" if kt_lang.startswith("zh") else "en"
# 2. Return cached value if available (avoids I/O on every call)
if _lang_cache is not None:
return _lang_cache
# 3. Check config file setting (with caching)
# Import here to avoid circular imports
from kt_kernel.cli.config.settings import get_settings
try:
settings = get_settings()
config_lang = settings.get("general.language", "auto")
if config_lang and config_lang != "auto":
lang = "zh" if config_lang.lower().startswith("zh") else "en"
_lang_cache = lang
return lang
except Exception:
# If settings fail to load, continue with system detection
pass
# 4. Check system LANG environment variable
system_lang = os.environ.get("LANG", "").lower()
lang = "zh" if system_lang.startswith("zh") else "en"
_lang_cache = lang
return lang
def t(msg_key: str, **kwargs: Any) -> str:
"""
Translate a message key to the current language.
Args:
msg_key: Message key to translate
**kwargs: Format arguments for the message
Returns:
Translated and formatted message string
Example:
>>> t("welcome")
"Welcome to KTransformers!" # or "欢迎使用 KTransformers" in Chinese
>>> t("install_found", name="conda", version="24.1.0")
"Found conda (version 24.1.0)"
"""
lang = get_lang()
messages = MESSAGES.get(lang, MESSAGES["en"])
message = messages.get(msg_key, MESSAGES["en"].get(msg_key, msg_key))
if kwargs:
try:
return message.format(**kwargs)
except KeyError:
return message
return message
def set_lang(lang: str) -> None:
"""
Set the language for the current session.
Args:
lang: Language code ("en" or "zh")
"""
global _lang_cache
os.environ["KT_LANG"] = lang
_lang_cache = lang # Update cache when language is explicitly set
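The lookup-and-fallback path implemented by `get_lang()` and `t()` above can be exercised in isolation. The sketch below re-implements just that logic against a toy message table (the names mirror the module for readability; this is an illustration, not the real `kt_kernel.cli.i18n` API):

```python
import os

# Toy subset of the MESSAGES table defined in the module above.
MESSAGES = {
    "en": {
        "install_found": "Found {name} (version {version})",
        "welcome": "Welcome to KTransformers!",
    },
    "zh": {
        "install_found": "发现 {name} (版本 {version})",
    },
}


def get_lang() -> str:
    # KT_LANG has the highest priority, mirroring step 1 of get_lang() above.
    kt_lang = os.environ.get("KT_LANG", "").lower()
    if kt_lang:
        return "zh" if kt_lang.startswith("zh") else "en"
    return "en"


def t(msg_key: str, **kwargs) -> str:
    messages = MESSAGES.get(get_lang(), MESSAGES["en"])
    # Keys missing from the active locale fall back to English, then to the raw key.
    message = messages.get(msg_key, MESSAGES["en"].get(msg_key, msg_key))
    if kwargs:
        try:
            return message.format(**kwargs)
        except KeyError:
            # Bad format arguments degrade to the unformatted template.
            return message
    return message


os.environ["KT_LANG"] = "zh"
print(t("install_found", name="conda", version="24.1.0"))  # zh entry is used
print(t("welcome"))  # no zh entry -> English fallback
print(t("no_such_key"))  # unknown key -> key itself
```

The three-level fallback means an incomplete locale table never raises; untranslated strings simply surface in English, and unknown keys surface as themselves.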


@@ -0,0 +1,436 @@
"""
Main entry point for kt-cli.
KTransformers CLI - A unified command-line interface for KTransformers.
"""
import sys
import typer
from kt_kernel.cli import __version__
from kt_kernel.cli.commands import bench, chat, config, doctor, model, quant, run, sft, version
from kt_kernel.cli.i18n import t, set_lang, get_lang
def _get_app_help() -> str:
"""Get app help text based on current language."""
lang = get_lang()
if lang == "zh":
return "KTransformers CLI - KTransformers 统一命令行界面"
return "KTransformers CLI - A unified command-line interface for KTransformers."
def _get_help(key: str) -> str:
"""Get help text based on current language."""
help_texts = {
"version": {"en": "Show version information", "zh": "显示版本信息"},
"run": {"en": "Start model inference server", "zh": "启动模型推理服务器"},
"chat": {"en": "Interactive chat with running model", "zh": "与运行中的模型进行交互式聊天"},
"quant": {"en": "Quantize model weights", "zh": "量化模型权重"},
"bench": {"en": "Run full benchmark", "zh": "运行完整基准测试"},
"microbench": {"en": "Run micro-benchmark", "zh": "运行微基准测试"},
"doctor": {"en": "Diagnose environment issues", "zh": "诊断环境问题"},
"model": {"en": "Manage models and storage paths", "zh": "管理模型和存储路径"},
"config": {"en": "Manage configuration", "zh": "管理配置"},
"sft": {"en": "Fine-tuning with LlamaFactory", "zh": "使用 LlamaFactory 进行微调"},
}
lang = get_lang()
return help_texts.get(key, {}).get(lang, help_texts.get(key, {}).get("en", key))
# Create main app with dynamic help
app = typer.Typer(
name="kt",
help="KTransformers CLI - A unified command-line interface for KTransformers.",
no_args_is_help=True,
add_completion=False, # Use static completion scripts instead of dynamic completion
rich_markup_mode="rich",
)
def _update_help_texts() -> None:
"""Update all help texts based on current language setting."""
# Update main app help
app.info.help = _get_app_help()
# Update command help texts
for cmd_info in app.registered_commands:
# cmd_info is a CommandInfo object
if hasattr(cmd_info, "name") and cmd_info.name:
cmd_info.help = _get_help(cmd_info.name)
# Update sub-app help texts
for group_info in app.registered_groups:
if hasattr(group_info, "name") and group_info.name:
group_info.help = _get_help(group_info.name)
# Register commands
app.command(name="version", help="Show version information")(version.version)
app.command(name="run", help="Start model inference server")(run.run)
app.command(name="chat", help="Interactive chat with running model")(chat.chat)
app.command(name="quant", help="Quantize model weights")(quant.quant)
app.command(name="bench", help="Run full benchmark")(bench.bench)
app.command(name="microbench", help="Run micro-benchmark")(bench.microbench)
app.command(name="doctor", help="Diagnose environment issues")(doctor.doctor)
# Register sub-apps
app.add_typer(model.app, name="model", help="Manage models and storage paths")
app.add_typer(config.app, name="config", help="Manage configuration")
app.add_typer(sft.app, name="sft", help="Fine-tuning with LlamaFactory")
def check_first_run() -> None:
"""Check if this is the first run and prompt for language setup."""
import os
# Skip if not running in interactive terminal
if not sys.stdin.isatty():
return
from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE
# Only check if config file exists - don't create it yet
if not DEFAULT_CONFIG_FILE.exists():
# First run - show welcome and language selection
from kt_kernel.cli.config.settings import get_settings
settings = get_settings()
_show_first_run_setup(settings)
else:
# Config exists - check if initialized
from kt_kernel.cli.config.settings import get_settings
settings = get_settings()
if not settings.get("general._initialized"):
_show_first_run_setup(settings)
def _show_first_run_setup(settings) -> None:
"""Show first-run setup wizard."""
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt, Confirm
from rich.spinner import Spinner
from rich.live import Live
from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb, scan_models_in_location
console = Console()
# Welcome message
console.print()
console.print(
Panel.fit(
"[bold cyan]Welcome to KTransformers CLI! / 欢迎使用 KTransformers CLI![/bold cyan]\n\n"
"Let's set up your preferences.\n"
"让我们设置您的偏好。",
title="kt-cli",
border_style="cyan",
)
)
console.print()
# Language selection
console.print("[bold]Select your preferred language / 选择您的首选语言:[/bold]")
console.print()
console.print(" [cyan][1][/cyan] English")
console.print(" [cyan][2][/cyan] 中文 (Chinese)")
console.print()
while True:
choice = Prompt.ask("Enter choice / 输入选择", choices=["1", "2"], default="1")
if choice == "1":
lang = "en"
break
elif choice == "2":
lang = "zh"
break
# Save language setting
settings.set("general.language", lang)
set_lang(lang)
# Confirmation message
console.print()
if lang == "zh":
console.print("[green]✓[/green] 语言已设置为中文")
else:
console.print("[green]✓[/green] Language set to English")
# Model storage path selection
console.print()
console.print(f"[bold]{t('setup_model_path_title')}[/bold]")
console.print()
console.print(f"[dim]{t('setup_model_path_desc')}[/dim]")
console.print()
# Scan for storage locations
console.print(f"[dim]{t('setup_scanning_disks')}[/dim]")
locations = scan_storage_locations(min_size_gb=50.0)
console.print()
if locations:
# Scan for models in each location
console.print(f"[dim]{t('setup_scanning_models')}[/dim]")
location_models: dict[str, list] = {}
for loc in locations[:5]:
models = scan_models_in_location(loc, max_depth=2)
if models:
location_models[loc.path] = models
console.print()
# Show options
for i, loc in enumerate(locations[:5], 1): # Show top 5 options
available = format_size_gb(loc.available_gb)
total = format_size_gb(loc.total_gb)
# Build the option string
if i == 1:
option_str = t("setup_disk_option_recommended", path=loc.path, available=available, total=total)
else:
option_str = t("setup_disk_option", path=loc.path, available=available, total=total)
# Add model count if any
if loc.path in location_models:
model_count = len(location_models[loc.path])
option_str += f" [green]✓ {t('setup_location_has_models', count=model_count)}[/green]"
console.print(f" [cyan][{i}][/cyan] {option_str}")
# Show first few models found in this location
if loc.path in location_models:
for model in location_models[loc.path][:3]: # Show up to 3 models
size_str = format_size_gb(model.size_gb)
console.print(f" [dim]• {model.name} ({size_str})[/dim]")
if len(location_models[loc.path]) > 3:
remaining = len(location_models[loc.path]) - 3
console.print(f" [dim] ... +{remaining} more[/dim]")
# Custom path option
custom_idx = min(len(locations), 5) + 1
console.print(f" [cyan][{custom_idx}][/cyan] {t('setup_custom_path')}")
console.print()
valid_choices = [str(i) for i in range(1, custom_idx + 1)]
path_choice = Prompt.ask(t("prompt_select"), choices=valid_choices, default="1")
if path_choice == str(custom_idx):
# Custom path
selected_path = _prompt_custom_path(console, settings)
else:
selected_path = locations[int(path_choice) - 1].path
else:
# No large storage found, ask for custom path
console.print(f"[yellow]{t('setup_no_large_disk')}[/yellow]")
console.print()
selected_path = _prompt_custom_path(console, settings)
# Ensure the path exists
import os
from pathlib import Path
if not os.path.exists(selected_path):
if Confirm.ask(t("setup_path_not_exist"), default=True):
try:
Path(selected_path).mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
console.print(f"[red]{t('error')}: {e}[/red]")
# Fall back to default
selected_path = str(Path.home() / ".ktransformers" / "models")
Path(selected_path).mkdir(parents=True, exist_ok=True)
# Check available space and warn if low
from kt_kernel.cli.utils.environment import detect_disk_space_gb
available_gb, _ = detect_disk_space_gb(
selected_path if os.path.exists(selected_path) else str(Path(selected_path).parent)
)
if available_gb < 100:
console.print(f"[yellow]{t('setup_path_low_space')}[/yellow]")
# Save the path
settings.set("paths.models", selected_path)
settings.set("general._initialized", True)
console.print()
console.print(f"[green]✓[/green] {t('setup_model_path_set', path=selected_path)}")
console.print()
# Tips
if lang == "zh":
console.print("[dim]提示: 运行 'kt config show' 查看所有配置[/dim]")
else:
console.print("[dim]Tip: Run 'kt config show' to view all settings[/dim]")
console.print()
def _prompt_custom_path(console, settings) -> str:
"""Prompt user to enter a custom path."""
from rich.prompt import Prompt
from pathlib import Path
import os
default_path = str(Path.home() / ".ktransformers" / "models")
while True:
custom_path = Prompt.ask(t("setup_enter_custom_path"), default=default_path)
# Expand user home
custom_path = os.path.expanduser(custom_path)
# Check if path exists or parent is writable
if os.path.exists(custom_path):
if os.access(custom_path, os.W_OK):
return custom_path
else:
console.print(f"[red]{t('setup_path_no_write')}[/red]")
else:
# Check if we can create it (parent writable)
parent = str(Path(custom_path).parent)
while not os.path.exists(parent) and parent != "/":
parent = str(Path(parent).parent)
if os.access(parent, os.W_OK):
return custom_path
else:
console.print(f"[red]{t('setup_path_no_write')}[/red]")
def _install_shell_completion() -> None:
"""Install shell completion scripts to user directories.
Uses standard locations that are auto-loaded by shell completion systems:
- Bash: ~/.local/share/bash-completion/completions/kt (auto-loaded by bash-completion 2.0+)
- Zsh: ~/.zfunc/_kt (requires fpath setup, but commonly used)
- Fish: ~/.config/fish/completions/kt.fish (auto-loaded)
"""
import os
import shutil
from pathlib import Path
from kt_kernel.cli.config.settings import get_settings
settings = get_settings()
# Check if already installed
if settings.get("general._completion_installed", False):
return
# Detect current shell
shell = os.environ.get("SHELL", "")
if "zsh" in shell:
shell_name = "zsh"
elif "fish" in shell:
shell_name = "fish"
else:
shell_name = "bash"
try:
cli_dir = Path(__file__).parent
completions_dir = cli_dir / "completions"
home = Path.home()
installed = False
if shell_name == "bash":
# Use XDG standard location for bash-completion (auto-loaded)
src_file = completions_dir / "kt-completion.bash"
dest_dir = home / ".local" / "share" / "bash-completion" / "completions"
dest_file = dest_dir / "kt"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
elif shell_name == "zsh":
src_file = completions_dir / "_kt"
dest_dir = home / ".zfunc"
dest_file = dest_dir / "_kt"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
elif shell_name == "fish":
# Fish auto-loads from this directory
src_file = completions_dir / "kt.fish"
dest_dir = home / ".config" / "fish" / "completions"
dest_file = dest_dir / "kt.fish"
if src_file.exists():
dest_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_file, dest_file)
installed = True
        # Mark as installed only if a completion script was actually copied
        if installed:
            settings.set("general._completion_installed", True)
# For bash/zsh, completion will work in new terminals automatically
# (bash-completion 2.0+ auto-loads from ~/.local/share/bash-completion/completions/)
except (OSError, IOError):
# Silently ignore errors - completion is not critical
pass
def _apply_saved_language() -> None:
"""Apply the saved language setting.
Priority:
1. KT_LANG environment variable (if already set, don't override)
2. Config file setting
3. System locale (auto)
"""
import os
# Don't override if KT_LANG is already set by user
if os.environ.get("KT_LANG"):
return
from kt_kernel.cli.config.settings import get_settings
settings = get_settings()
lang = settings.get("general.language", "auto")
if lang != "auto":
set_lang(lang)
def main():
"""Main entry point."""
# Apply saved language setting first (before anything else for correct help display)
_apply_saved_language()
# Update help texts based on language
_update_help_texts()
# Check for first run (but not for certain commands)
# Skip first-run check for: --help, config commands, version
args = sys.argv[1:] if len(sys.argv) > 1 else []
skip_commands = ["--help", "-h", "config", "version", "--version"]
should_check_first_run = True
for arg in args:
if arg in skip_commands:
should_check_first_run = False
break
# Auto-install shell completion on first run
if should_check_first_run:
_install_shell_completion()
# Check first run before running commands
if should_check_first_run and args:
check_first_run()
app()
if __name__ == "__main__":
main()
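The skip-command gating in `main()` reduces to a membership test over `sys.argv`; a minimal standalone sketch of that logic (the function name is illustrative, not part of the CLI):

```python
# Sketch of the first-run gating in main(): help/config/version
# invocations must not trigger the interactive first-run setup.
SKIP_COMMANDS = {"--help", "-h", "config", "version", "--version"}

def should_check_first_run(args: list[str]) -> bool:
    """Return False if any argument is a skip command."""
    return not any(arg in SKIP_COMMANDS for arg in args)

if __name__ == "__main__":
    print(should_check_first_run(["run", "MiniMax-M2.1"]))  # True
    print(should_check_first_run(["config", "show"]))       # False
```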


@@ -0,0 +1,6 @@
# Inference dependencies for KTransformers
# NOTE: sglang is installed separately from source (see install.py)
transformers>=4.45.0
safetensors>=0.4.0
huggingface-hub>=0.20.0


@@ -0,0 +1,7 @@
# SFT (Supervised Fine-Tuning) dependencies for KTransformers
llamafactory>=0.9.0
peft>=0.12.0
transformers>=4.45.0
datasets>=2.14.0
accelerate>=0.30.0


@@ -0,0 +1,3 @@
"""
Utility modules for kt-cli.
"""


@@ -0,0 +1,249 @@
"""
Console utilities for kt-cli.
Provides Rich-based console output helpers for consistent formatting.
"""
from typing import Optional
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from rich.prompt import Confirm, Prompt
from rich.table import Table
from rich.theme import Theme
from kt_kernel.cli.i18n import t
# Custom theme for kt-cli
KT_THEME = Theme(
{
"info": "cyan",
"warning": "yellow",
"error": "bold red",
"success": "bold green",
"highlight": "bold magenta",
"muted": "dim",
}
)
# Global console instance
console = Console(theme=KT_THEME)
def print_info(message: str, **kwargs) -> None:
"""Print an info message."""
    console.print(f"[info]ℹ[/info] {message}", **kwargs)
def print_success(message: str, **kwargs) -> None:
"""Print a success message."""
console.print(f"[success]✓[/success] {message}", **kwargs)
def print_warning(message: str, **kwargs) -> None:
"""Print a warning message."""
console.print(f"[warning]⚠[/warning] {message}", **kwargs)
def print_error(message: str, **kwargs) -> None:
"""Print an error message."""
console.print(f"[error]✗[/error] {message}", **kwargs)
def print_step(message: str, **kwargs) -> None:
"""Print a step indicator."""
console.print(f"[highlight]→[/highlight] {message}", **kwargs)
def print_header(title: str, subtitle: Optional[str] = None) -> None:
"""Print a header panel."""
content = f"[bold]{title}[/bold]"
if subtitle:
content += f"\n[muted]{subtitle}[/muted]"
console.print(Panel(content, expand=False))
def print_version_table(versions: dict[str, Optional[str]]) -> None:
"""Print a version information table."""
table = Table(show_header=False, box=None, padding=(0, 2))
table.add_column("Component", style="bold")
table.add_column("Version")
for name, version in versions.items():
if version:
table.add_row(name, f"[success]{version}[/success]")
else:
table.add_row(name, f"[muted]{t('version_not_installed')}[/muted]")
console.print(table)
def print_dependency_table(deps: list[dict]) -> None:
"""Print a dependency status table."""
table = Table(title=t("install_checking_deps"))
table.add_column(t("version_info"), style="bold")
table.add_column("Current")
table.add_column("Required")
table.add_column("Status")
for dep in deps:
status = dep.get("status", "ok")
if status == "ok":
status_str = f"[success]{t('install_dep_ok')}[/success]"
elif status == "outdated":
status_str = f"[warning]{t('install_dep_outdated')}[/warning]"
else:
status_str = f"[error]{t('install_dep_missing')}[/error]"
table.add_row(
dep["name"],
dep.get("installed", "-"),
dep.get("required", "-"),
status_str,
)
console.print(table)
def confirm(message: str, default: bool = True) -> bool:
"""Ask for confirmation."""
return Confirm.ask(message, default=default, console=console)
def prompt_choice(message: str, choices: list[str], default: Optional[str] = None) -> str:
"""Prompt for a choice from a list."""
# Display numbered choices
console.print(f"\n[bold]{message}[/bold]")
for i, choice in enumerate(choices, 1):
console.print(f" [highlight][{i}][/highlight] {choice}")
while True:
response = Prompt.ask(
"\n" + t("prompt_select"),
console=console,
default=str(choices.index(default) + 1) if default else None,
)
try:
idx = int(response) - 1
if 0 <= idx < len(choices):
return choices[idx]
except ValueError:
# Check if response matches a choice directly
if response in choices:
return response
print_error(f"Please enter a number between 1 and {len(choices)}")
def prompt_text(message: str, default: Optional[str] = None) -> str:
"""Prompt for text input."""
return Prompt.ask(message, console=console, default=default)
def create_progress() -> Progress:
"""Create a progress bar for general tasks."""
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
console=console,
)
def create_download_progress() -> Progress:
"""Create a progress bar for downloads."""
return Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
DownloadColumn(),
TransferSpeedColumn(),
TimeRemainingColumn(),
console=console,
)
def print_model_table(models: list[dict]) -> None:
"""Print a table of models."""
table = Table(title=t("download_list_title"))
table.add_column("Name", style="bold")
table.add_column("Repository")
table.add_column("Type")
table.add_column("Requirements")
for model in models:
reqs = []
if model.get("gpu_vram_gb"):
reqs.append(f"GPU: {model['gpu_vram_gb']}GB")
if model.get("cpu_ram_gb"):
reqs.append(f"RAM: {model['cpu_ram_gb']}GB")
table.add_row(
model.get("name", ""),
model.get("hf_repo", ""),
model.get("type", ""),
", ".join(reqs) if reqs else "-",
)
console.print(table)
def print_hardware_info(gpu_info: str, cpu_info: str, ram_info: str) -> None:
"""Print hardware information."""
table = Table(show_header=False, box=None)
table.add_column("Icon", width=3)
table.add_column("Info")
table.add_row("🖥️", gpu_info)
table.add_row("💻", cpu_info)
table.add_row("🧠", ram_info)
console.print(Panel(table, title="Hardware", expand=False))
def print_server_info(
mode: str, host: str, port: int, gpu_experts: int, cpu_threads: int
) -> None:
"""Print server startup information."""
table = Table(show_header=False, box=None)
table.add_column("Key", style="bold")
table.add_column("Value")
table.add_row(t("run_server_mode").split(":")[0], mode)
table.add_row("Host", host)
table.add_row("Port", str(port))
table.add_row(t("run_gpu_experts").split(":")[0], f"{gpu_experts}/layer")
table.add_row(t("run_cpu_threads").split(":")[0], str(cpu_threads))
console.print(Panel(table, title=t("run_server_started"), expand=False, border_style="green"))
def print_api_info(host: str, port: int) -> None:
"""Print API endpoint information."""
api_url = f"http://{host}:{port}"
docs_url = f"http://{host}:{port}/docs"
console.print()
console.print(f" {t('run_api_url', host=host, port=port)}")
console.print(f" {t('run_docs_url', host=host, port=port)}")
console.print()
console.print(f" [muted]Test command:[/muted]")
console.print(
f" [dim]curl {api_url}/v1/chat/completions -H 'Content-Type: application/json' "
f"-d '{{\"model\": \"default\", \"messages\": [{{\"role\": \"user\", \"content\": \"Hello\"}}]}}'[/dim]"
)
console.print()
console.print(f" [muted]{t('run_stop_hint')}[/muted]")
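All of the status helpers above share one prefix-symbol pattern; the sketch below reproduces it without Rich so the shape is visible (the info glyph `ℹ` is an assumption — the `print_info` format string in this file appears to have lost its symbol):

```python
# Rich-free sketch of the prefix-symbol pattern behind print_info /
# print_success / print_warning / print_error above.
SYMBOLS = {"info": "ℹ", "success": "✓", "warning": "⚠", "error": "✗"}

def format_status(level: str, message: str) -> str:
    """Prefix a message with its status symbol."""
    return f"{SYMBOLS[level]} {message}"

print(format_status("success", "Model downloaded"))  # ✓ Model downloaded
```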

File diff suppressed because it is too large


@@ -0,0 +1,374 @@
"""
Model registry for kt-cli.
Provides a registry of supported models with fuzzy matching capabilities.
"""
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import yaml
from kt_kernel.cli.config.settings import get_settings
@dataclass
class ModelInfo:
"""Information about a supported model."""
name: str
hf_repo: str
aliases: list[str] = field(default_factory=list)
type: str = "moe" # moe, dense
gpu_vram_gb: float = 0
cpu_ram_gb: float = 0
default_params: dict = field(default_factory=dict)
description: str = ""
description_zh: str = ""
max_tensor_parallel_size: Optional[int] = None # Maximum tensor parallel size for this model
# Built-in model registry
BUILTIN_MODELS: list[ModelInfo] = [
ModelInfo(
name="DeepSeek-V3-0324",
hf_repo="deepseek-ai/DeepSeek-V3-0324",
aliases=["deepseek-v3-0324", "deepseek-v3", "dsv3", "deepseek3", "v3-0324"],
type="moe",
default_params={
"kt-num-gpu-experts": 1,
"attention-backend": "triton",
"disable-shared-experts-fusion": True,
"kt-method": "AMXINT4",
},
description="DeepSeek V3-0324 685B MoE model (March 2025, improved benchmarks)",
description_zh="DeepSeek V3-0324 685B MoE 模型2025年3月改进的基准测试",
),
ModelInfo(
name="DeepSeek-V3.2",
hf_repo="deepseek-ai/DeepSeek-V3.2",
aliases=["deepseek-v3.2", "dsv3.2", "deepseek3.2", "v3.2"],
type="moe",
default_params={
"kt-method": "FP8",
"kt-gpu-prefill-token-threshold": 4096,
"attention-backend": "flashinfer",
"fp8-gemm-backend": "triton",
"max-total-tokens": 100000,
"max-running-requests": 16,
"chunked-prefill-size": 32768,
"mem-fraction-static": 0.80,
"watchdog-timeout": 3000,
"served-model-name": "DeepSeek-V3.2",
"disable-shared-experts-fusion": True,
},
description="DeepSeek V3.2 671B MoE model (latest)",
description_zh="DeepSeek V3.2 671B MoE 模型(最新)",
),
ModelInfo(
name="DeepSeek-R1-0528",
hf_repo="deepseek-ai/DeepSeek-R1-0528",
aliases=["deepseek-r1-0528", "deepseek-r1", "dsr1", "r1", "r1-0528"],
type="moe",
default_params={
"kt-num-gpu-experts": 1,
"attention-backend": "triton",
"disable-shared-experts-fusion": True,
"kt-method": "AMXINT4",
},
description="DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)",
description_zh="DeepSeek R1-0528 推理模型2025年5月改进的推理深度",
),
ModelInfo(
name="Kimi-K2-Thinking",
hf_repo="moonshotai/Kimi-K2-Thinking",
aliases=["kimi-k2-thinking", "kimi-thinking", "k2-thinking", "kimi", "k2"],
type="moe",
default_params={
"kt-method": "RAWINT4",
"kt-gpu-prefill-token-threshold": 400,
"attention-backend": "flashinfer",
"max-total-tokens": 100000,
"max-running-requests": 16,
"chunked-prefill-size": 32768,
"mem-fraction-static": 0.80,
"watchdog-timeout": 3000,
"served-model-name": "Kimi-K2-Thinking",
"disable-shared-experts-fusion": True,
},
description="Moonshot Kimi K2 Thinking MoE model",
description_zh="月之暗面 Kimi K2 Thinking MoE 模型",
),
ModelInfo(
name="MiniMax-M2",
hf_repo="MiniMaxAI/MiniMax-M2",
aliases=["minimax-m2", "m2"],
type="moe",
default_params={
"kt-method": "FP8",
"kt-gpu-prefill-token-threshold": 4096,
"attention-backend": "flashinfer",
"fp8-gemm-backend": "triton",
"max-total-tokens": 100000,
"max-running-requests": 16,
"chunked-prefill-size": 32768,
"mem-fraction-static": 0.80,
"watchdog-timeout": 3000,
"served-model-name": "MiniMax-M2",
"disable-shared-experts-fusion": True,
"tool-call-parser": "minimax-m2",
"reasoning-parser": "minimax-append-think",
},
description="MiniMax M2 MoE model",
description_zh="MiniMax M2 MoE 模型",
max_tensor_parallel_size=4, # M2 only supports up to 4-way tensor parallelism
),
ModelInfo(
name="MiniMax-M2.1",
hf_repo="MiniMaxAI/MiniMax-M2.1",
aliases=["minimax-m2.1", "m2.1"],
type="moe",
default_params={
"kt-method": "FP8",
"kt-gpu-prefill-token-threshold": 4096,
"attention-backend": "flashinfer",
"fp8-gemm-backend": "triton",
"max-total-tokens": 100000,
"max-running-requests": 16,
"chunked-prefill-size": 32768,
"mem-fraction-static": 0.80,
"watchdog-timeout": 3000,
"served-model-name": "MiniMax-M2.1",
"disable-shared-experts-fusion": True,
"tool-call-parser": "minimax-m2",
"reasoning-parser": "minimax-append-think",
},
description="MiniMax M2.1 MoE model (enhanced multi-language programming)",
description_zh="MiniMax M2.1 MoE 模型(增强多语言编程能力)",
max_tensor_parallel_size=4, # M2.1 only supports up to 4-way tensor parallelism
),
]
class ModelRegistry:
"""Registry of supported models with fuzzy matching."""
def __init__(self):
"""Initialize the model registry."""
self._models: dict[str, ModelInfo] = {}
self._aliases: dict[str, str] = {}
self._load_builtin_models()
self._load_user_models()
def _load_builtin_models(self) -> None:
"""Load built-in models."""
for model in BUILTIN_MODELS:
self._register(model)
def _load_user_models(self) -> None:
"""Load user-defined models from config."""
settings = get_settings()
registry_file = settings.config_dir / "registry.yaml"
if registry_file.exists():
try:
with open(registry_file, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
for name, info in data.get("models", {}).items():
model = ModelInfo(
name=name,
hf_repo=info.get("hf_repo", ""),
aliases=info.get("aliases", []),
type=info.get("type", "moe"),
gpu_vram_gb=info.get("gpu_vram_gb", 0),
cpu_ram_gb=info.get("cpu_ram_gb", 0),
default_params=info.get("default_params", {}),
description=info.get("description", ""),
description_zh=info.get("description_zh", ""),
max_tensor_parallel_size=info.get("max_tensor_parallel_size"),
)
self._register(model)
except (yaml.YAMLError, OSError):
pass
def _register(self, model: ModelInfo) -> None:
"""Register a model."""
self._models[model.name.lower()] = model
# Register aliases
for alias in model.aliases:
self._aliases[alias.lower()] = model.name.lower()
def get(self, name: str) -> Optional[ModelInfo]:
"""Get a model by exact name or alias."""
name_lower = name.lower()
# Check direct match
if name_lower in self._models:
return self._models[name_lower]
# Check aliases
if name_lower in self._aliases:
return self._models[self._aliases[name_lower]]
return None
def search(self, query: str, limit: int = 10) -> list[ModelInfo]:
"""Search for models using fuzzy matching.
Args:
query: Search query
limit: Maximum number of results
Returns:
List of matching models, sorted by relevance
"""
query_lower = query.lower()
results: list[tuple[float, ModelInfo]] = []
for model in self._models.values():
score = self._match_score(query_lower, model)
if score > 0:
results.append((score, model))
# Sort by score descending
results.sort(key=lambda x: x[0], reverse=True)
return [model for _, model in results[:limit]]
def _match_score(self, query: str, model: ModelInfo) -> float:
"""Calculate match score for a model.
Returns a score between 0 and 1, where 1 is an exact match.
"""
# Check exact match
if query == model.name.lower():
return 1.0
# Check alias exact match
for alias in model.aliases:
if query == alias.lower():
return 0.95
# Check if query is contained in name
if query in model.name.lower():
return 0.8
# Check if query is contained in aliases
for alias in model.aliases:
if query in alias.lower():
return 0.7
# Check if query is contained in hf_repo
if query in model.hf_repo.lower():
return 0.6
# Fuzzy matching - check if all query parts are present
query_parts = re.split(r"[-_.\s]", query)
name_lower = model.name.lower()
matches = sum(1 for part in query_parts if part and part in name_lower)
if matches > 0:
return 0.5 * (matches / len(query_parts))
return 0.0
def list_all(self) -> list[ModelInfo]:
"""List all registered models."""
return list(self._models.values())
def find_local_models(self) -> list[tuple[ModelInfo, Path]]:
"""Find models that are downloaded locally in any configured model path.
Returns:
List of (ModelInfo, path) tuples for local models
"""
settings = get_settings()
model_paths = settings.get_model_paths()
results = []
for model in self._models.values():
found = False
# Search in all configured model directories
for models_dir in model_paths:
if not models_dir.exists():
continue
# Check common path patterns
possible_paths = [
models_dir / model.name,
models_dir / model.name.lower(),
models_dir / model.hf_repo.split("/")[-1],
models_dir / model.hf_repo.replace("/", "--"),
]
for path in possible_paths:
if path.exists() and (path / "config.json").exists():
results.append((model, path))
found = True
break
if found:
break
return results
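The scoring ladder in `_match_score` (exact name 1.0, alias 0.95, name substring 0.8, alias substring 0.7, token overlap scaled by 0.5) can be exercised in isolation; a trimmed standalone copy, omitting only the `hf_repo` substring rule:

```python
import re

def match_score(query: str, name: str, aliases: list[str]) -> float:
    """Trimmed copy of ModelRegistry._match_score for experimentation."""
    query = query.lower()
    if query == name.lower():
        return 1.0
    if any(query == a.lower() for a in aliases):
        return 0.95
    if query in name.lower():
        return 0.8
    if any(query in a.lower() for a in aliases):
        return 0.7
    parts = re.split(r"[-_.\s]", query)
    hits = sum(1 for p in parts if p and p in name.lower())
    return 0.5 * hits / len(parts) if hits else 0.0

print(match_score("kimi", "Kimi-K2-Thinking", ["kimi", "k2"]))  # 0.95
print(match_score("minimax", "MiniMax-M2.1", []))               # 0.8
```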
# Global registry instance
_registry: Optional[ModelRegistry] = None
def get_registry() -> ModelRegistry:
"""Get the global model registry instance."""
global _registry
if _registry is None:
_registry = ModelRegistry()
return _registry
# ============================================================================
# Model-specific parameter computation functions
# ============================================================================
def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """Compute kt-num-gpu-experts for DeepSeek V3/V3.2/R1."""
    per_gpu_gb = 16
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))
    return total_vram // 3
def compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
"""Compute kt-num-gpu-experts for Kimi K2 Thinking."""
per_gpu_gb = 16
if vram_per_gpu_gb < per_gpu_gb:
        return 0
total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))
return total_vram * 2 // 3
def compute_minimax_m2_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
"""Compute kt-num-gpu-experts for MiniMax M2/M2.1."""
per_gpu_gb = 16
if vram_per_gpu_gb < per_gpu_gb:
        return 0
total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))
    return total_vram  # roughly one expert per spare GB of VRAM
# Model name to computation function mapping
MODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {
"DeepSeek-V3-0324": compute_deepseek_v3_gpu_experts,
"DeepSeek-V3.2": compute_deepseek_v3_gpu_experts, # Same as V3-0324
"DeepSeek-R1-0528": compute_deepseek_v3_gpu_experts, # Same as V3-0324
"Kimi-K2-Thinking": compute_kimi_k2_thinking_gpu_experts,
"MiniMax-M2": compute_minimax_m2_gpu_experts,
"MiniMax-M2.1": compute_minimax_m2_gpu_experts, # Same as M2
}
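Each sizing helper above reserves 16 GB of VRAM per GPU, then converts the remaining capacity into an expert count at a per-model rate (judging from the divisors: ~3 GB per expert for DeepSeek, ~1.5 for Kimi, ~1 for MiniMax). Restating the DeepSeek helper with sample numbers:

```python
def compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:
    """DeepSeek sizing rule from above: ~3 GB of spare VRAM per GPU expert."""
    per_gpu_gb = 16  # per-GPU reserve (purpose assumed: non-expert weights and cache)
    if vram_per_gpu_gb < per_gpu_gb:
        return 0
    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))
    return total_vram // 3

# A single RTX 5090 (32 GB): (32 - 16) // 3 = 5 experts per layer on GPU
print(compute_deepseek_v3_gpu_experts(1, 32))  # 5
# Eight 80 GB GPUs: 8 * (80 - 16) // 3 = 170
print(compute_deepseek_v3_gpu_experts(8, 80))  # 170
```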


@@ -0,0 +1,407 @@
"""
SGLang installation checker and installation instructions provider.
This module provides utilities to:
- Check if SGLang is installed and get its metadata
- Provide installation instructions when SGLang is not found
"""
import subprocess
import sys
from pathlib import Path
from typing import Optional
from kt_kernel.cli.i18n import t
from kt_kernel.cli.utils.console import console
def check_sglang_installation() -> dict:
"""Check if SGLang is installed and get its metadata.
Returns:
dict with keys:
- installed: bool
- version: str or None
- location: str or None (installation path)
- editable: bool (whether installed in editable mode)
- git_info: dict or None (git remote and branch if available)
- from_source: bool (whether installed from source repository)
"""
try:
# Try to import sglang
import sglang
version = getattr(sglang, "__version__", None)
# Use pip show to get detailed package information
location = None
editable = False
git_info = None
from_source = False
try:
# Get pip show output
result = subprocess.run(
[sys.executable, "-m", "pip", "show", "sglang"],
capture_output=True,
text=True,
timeout=10,
)
if result.returncode == 0:
pip_info = {}
for line in result.stdout.split("\n"):
if ":" in line:
key, value = line.split(":", 1)
pip_info[key.strip()] = value.strip()
location = pip_info.get("Location")
editable_location = pip_info.get("Editable project location")
if editable_location:
editable = True
location = editable_location
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
# Fallback to module location
if hasattr(sglang, "__file__") and sglang.__file__:
location = str(Path(sglang.__file__).parent.parent)
# Check if it's installed from source (has .git directory)
if location:
git_root = None
check_path = Path(location)
# Check current directory and up to 2 parent directories
for _ in range(3):
git_dir = check_path / ".git"
if git_dir.exists():
git_root = check_path
from_source = True
break
if check_path.parent == check_path: # Reached root
break
check_path = check_path.parent
if from_source and git_root:
# Try to get git remote and branch info
try:
# Get remote URL
result = subprocess.run(
["git", "remote", "get-url", "origin"],
cwd=git_root,
capture_output=True,
text=True,
timeout=5,
)
remote_url = result.stdout.strip() if result.returncode == 0 else None
# Extract org/repo from URL
remote_short = None
if remote_url:
# Handle both https and git@ URLs
if "github.com" in remote_url:
parts = remote_url.rstrip("/").replace(".git", "").split("github.com")[-1]
remote_short = parts.lstrip("/").lstrip(":")
# Get current branch
result = subprocess.run(
["git", "branch", "--show-current"],
cwd=git_root,
capture_output=True,
text=True,
timeout=5,
)
branch = result.stdout.strip() if result.returncode == 0 else None
if remote_url or branch:
git_info = {
"remote": remote_short or remote_url,
"branch": branch,
}
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
pass
return {
"installed": True,
"version": version,
"location": location,
"editable": editable,
"git_info": git_info,
"from_source": from_source,
}
except ImportError:
return {
"installed": False,
"version": None,
"location": None,
"editable": False,
"git_info": None,
"from_source": False,
}
def get_sglang_install_instructions(lang: Optional[str] = None) -> str:
"""Get SGLang installation instructions.
Args:
lang: Language code ('en' or 'zh'). If None, uses current language setting.
Returns:
Formatted installation instructions string.
"""
from kt_kernel.cli.i18n import get_lang
if lang is None:
lang = get_lang()
if lang == "zh":
return """
[bold yellow]SGLang 未安装[/bold yellow]
请按照以下步骤安装 SGLang:
[bold]1. 克隆仓库:[/bold]
git clone https://github.com/kvcache-ai/sglang.git
cd sglang
[bold]2. 安装 (二选一):[/bold]
[cyan]方式 A - pip 安装 (推荐):[/cyan]
pip install -e "python[all]"
[cyan]方式 B - uv 安装 (更快):[/cyan]
pip install uv
uv pip install -e "python[all]"
[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]
"""
else:
return """
[bold yellow]SGLang is not installed[/bold yellow]
Please follow these steps to install SGLang:
[bold]1. Clone the repository:[/bold]
git clone https://github.com/kvcache-ai/sglang.git
cd sglang
[bold]2. Install (choose one):[/bold]
[cyan]Option A - pip install (recommended):[/cyan]
pip install -e "python[all]"
[cyan]Option B - uv install (faster):[/cyan]
pip install uv
uv pip install -e "python[all]"
[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""
def print_sglang_install_instructions() -> None:
"""Print SGLang installation instructions to console."""
instructions = get_sglang_install_instructions()
console.print(instructions)
def check_sglang_and_warn() -> bool:
"""Check if SGLang is installed, print warning if not.
Returns:
True if SGLang is installed, False otherwise.
"""
info = check_sglang_installation()
if not info["installed"]:
print_sglang_install_instructions()
return False
# Check if installed from PyPI (not recommended)
if info["installed"] and not info["from_source"]:
from kt_kernel.cli.utils.console import print_warning
print_warning(t("sglang_pypi_warning"))
console.print()
console.print("[dim]" + t("sglang_recommend_source") + "[/dim]")
console.print()
return True
def _get_sglang_kt_kernel_cache_path() -> Path:
"""Get the path to the sglang kt-kernel support cache file."""
cache_dir = Path.home() / ".ktransformers" / "cache"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir / "sglang_kt_kernel_supported"
def _is_sglang_kt_kernel_cache_valid() -> bool:
"""Check if the sglang kt-kernel support cache is valid.
The cache is considered valid if:
1. The cache file exists
2. The cache file contains 'true' (indicating previous check passed)
Returns:
True if cache is valid and indicates support, False otherwise.
"""
cache_path = _get_sglang_kt_kernel_cache_path()
if cache_path.exists():
try:
content = cache_path.read_text().strip().lower()
return content == "true"
except (OSError, IOError):
pass
return False
def _save_sglang_kt_kernel_cache(supported: bool) -> None:
"""Save the sglang kt-kernel support check result to cache."""
cache_path = _get_sglang_kt_kernel_cache_path()
try:
cache_path.write_text("true" if supported else "false")
except (OSError, IOError):
pass # Ignore cache write errors
def clear_sglang_kt_kernel_cache() -> None:
"""Clear the sglang kt-kernel support cache, forcing a re-check on next run."""
cache_path = _get_sglang_kt_kernel_cache_path()
try:
if cache_path.exists():
cache_path.unlink()
except (OSError, IOError):
pass
def check_sglang_kt_kernel_support(use_cache: bool = True, silent: bool = False) -> dict:
"""Check if SGLang supports kt-kernel parameters (--kt-gpu-prefill-token-threshold).
This function runs `python -m sglang.launch_server --help` and checks if the
output contains the `--kt-gpu-prefill-token-threshold` parameter. This parameter
is only available in the kvcache-ai/sglang fork, not in the official sglang.
The result is cached after the first successful check to avoid repeated checks.
Args:
use_cache: If True, use cached result if available. Default is True.
silent: If True, don't print checking message. Default is False.
Returns:
dict with keys:
- supported: bool - True if kt-kernel parameters are supported
- help_output: str or None - The help output from sglang.launch_server
- error: str or None - Error message if check failed
- from_cache: bool - True if result was from cache
"""
from kt_kernel.cli.utils.console import print_step
# Check cache first
if use_cache and _is_sglang_kt_kernel_cache_valid():
return {
"supported": True,
"help_output": None,
"error": None,
"from_cache": True,
}
# Print checking message
if not silent:
print_step(t("sglang_checking_kt_kernel_support"))
try:
result = subprocess.run(
[sys.executable, "-m", "sglang.launch_server", "--help"],
capture_output=True,
text=True,
timeout=30,
)
help_output = result.stdout + result.stderr
# Check if --kt-gpu-prefill-token-threshold is in the help output
supported = "--kt-gpu-prefill-token-threshold" in help_output
# Save to cache if supported
if supported:
_save_sglang_kt_kernel_cache(True)
return {
"supported": supported,
"help_output": help_output,
"error": None,
"from_cache": False,
}
except subprocess.TimeoutExpired:
return {
"supported": False,
"help_output": None,
"error": "Timeout while checking sglang.launch_server --help",
"from_cache": False,
}
except FileNotFoundError:
return {
"supported": False,
"help_output": None,
"error": "Python interpreter not found",
"from_cache": False,
}
except Exception as e:
return {
"supported": False,
"help_output": None,
"error": str(e),
"from_cache": False,
}
def print_sglang_kt_kernel_instructions() -> None:
"""Print instructions for installing the kvcache-ai fork of SGLang with kt-kernel support."""
from kt_kernel.cli.i18n import get_lang
lang = get_lang()
if lang == "zh":
instructions = """
[bold red]SGLang 不支持 kt-kernel[/bold red]
您当前安装的 SGLang 不包含 kt-kernel 支持。
kt-kernel 需要使用 kvcache-ai 维护的 SGLang 分支。
[bold]请按以下步骤重新安装 SGLang:[/bold]
[cyan]1. 卸载当前的 SGLang:[/cyan]
pip uninstall sglang -y
[cyan]2. 克隆 kvcache-ai 的 SGLang 仓库:[/cyan]
git clone https://github.com/kvcache-ai/sglang.git
cd sglang
[cyan]3. 安装 SGLang:[/cyan]
pip install -e "python[all]"
[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]
"""
else:
instructions = """
[bold red]SGLang does not support kt-kernel[/bold red]
Your current SGLang installation does not include kt-kernel support.
kt-kernel requires the kvcache-ai maintained fork of SGLang.
[bold]Please reinstall SGLang with the following steps:[/bold]
[cyan]1. Uninstall current SGLang:[/cyan]
pip uninstall sglang -y
[cyan]2. Clone the kvcache-ai SGLang repository:[/cyan]
git clone https://github.com/kvcache-ai/sglang.git
cd sglang
[cyan]3. Install SGLang:[/cyan]
pip install -e "python[all]"
[dim]Note: Make sure to run these commands in the correct Python environment[/dim]
"""
console.print(instructions)
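`check_sglang_kt_kernel_support` is a probe-then-cache pattern: grep the launcher's `--help` output for the fork-only `--kt-gpu-prefill-token-threshold` flag, and persist only a positive result so later runs skip the subprocess. A reduced sketch with the subprocess probe stubbed out:

```python
from pathlib import Path
import tempfile

KT_FLAG = "--kt-gpu-prefill-token-threshold"

def probe_support(help_output: str) -> bool:
    # Stand-in for running `python -m sglang.launch_server --help`.
    return KT_FLAG in help_output

def cached_check(cache_file: Path, help_output: str) -> bool:
    """Serve a cached positive result, otherwise probe and cache."""
    if cache_file.exists() and cache_file.read_text().strip().lower() == "true":
        return True
    supported = probe_support(help_output)
    if supported:
        cache_file.write_text("true")  # only positives are cached, as above
    return supported

with tempfile.TemporaryDirectory() as d:
    cache = Path(d) / "sglang_kt_kernel_supported"
    print(cached_check(cache, f"usage: ... {KT_FLAG} N"))  # True (probed)
    print(cached_check(cache, ""))                         # True (cached)
```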


@@ -17,7 +17,7 @@ from typing import List, Optional
from .experts_base import BaseMoEWrapper, KExpertsCPUBuffer
# Import backend implementations
-from .utils.amx import AMXMoEWrapper, RAWAMXMoEWrapper
+from .utils.amx import AMXMoEWrapper, NativeMoEWrapper
from .utils.llamafile import LlamafileMoEWrapper
from .utils.moe_kernel import GeneralMoEWrapper
@@ -77,7 +77,7 @@ class KTMoEWrapper:
chunked_prefill_size: Maximum prefill chunk size
cpu_save: Whether to save weights to CPU memory
max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
Returns:
An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)
@@ -85,8 +85,8 @@ class KTMoEWrapper:
# Select backend based on method
if method in ["AMXINT4", "AMXINT8"]:
backend_cls = AMXMoEWrapper
elif method == "RAWINT4":
backend_cls = RAWAMXMoEWrapper
elif method in ["RAWINT4", "FP8"]:
backend_cls = NativeMoEWrapper
elif method == "LLAMAFILE":
backend_cls = LlamafileMoEWrapper
elif method in ["MOE_INT4", "MOE_INT8"]:
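The dispatch above can be sketched as a standalone factory; the wrapper classes here are placeholders standing in for the real implementations in `kt_kernel`:

```python
# Placeholder backend classes; the real ones live in kt_kernel.
class AMXMoEWrapper: ...
class NativeMoEWrapper: ...
class LlamafileMoEWrapper: ...
class GeneralMoEWrapper: ...

def select_backend(method: str):
    """Mirror the method-string dispatch in KTMoEWrapper."""
    if method in ("AMXINT4", "AMXINT8"):
        return AMXMoEWrapper
    elif method in ("RAWINT4", "FP8"):   # FP8 now shares the native path
        return NativeMoEWrapper
    elif method == "LLAMAFILE":
        return LlamafileMoEWrapper
    elif method in ("MOE_INT4", "MOE_INT8"):
        return GeneralMoEWrapper
    raise ValueError(f"Unknown MoE backend method: {method}")
```

The key change in this commit is that `"FP8"` joins `"RAWINT4"` on the renamed `NativeMoEWrapper` branch.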

View File

@@ -4,13 +4,13 @@
Utilities for kt_kernel package.
"""
from .amx import AMXMoEWrapper, RAWAMXMoEWrapper
from .amx import AMXMoEWrapper, NativeMoEWrapper
from .llamafile import LlamafileMoEWrapper
from .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader
__all__ = [
"AMXMoEWrapper",
"RAWAMXMoEWrapper",
"NativeMoEWrapper",
"LlamafileMoEWrapper",
"SafeTensorLoader",
"CompressedSafeTensorLoader",

View File

@@ -4,16 +4,16 @@ import ctypes
# Use relative imports for package structure
from ..experts_base import BaseMoEWrapper
from .loader import SafeTensorLoader, CompressedSafeTensorLoader
from .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader
from kt_kernel_ext.moe import MOEConfig
try:
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE
from kt_kernel_ext.moe import AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE
_HAS_AMX_SUPPORT = True
except (ImportError, AttributeError):
_HAS_AMX_SUPPORT = False
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE = None, None, None
AMXInt4_MOE, AMXInt8_MOE, AMXInt4_KGroup_MOE, AMXFP8_MOE = None, None, None, None
from typing import Optional
@@ -303,10 +303,10 @@ class AMXMoEWrapper(BaseMoEWrapper):
del self.down_scales
class RAWAMXMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4 experts stored in compressed SafeTensor format."""
class NativeMoEWrapper(BaseMoEWrapper):
"""Wrapper for RAWINT4/FP8 experts loaded natively from SafeTensor files."""
_compressed_loader_instance = None
_native_loader_instance = None
def __init__(
self,
@@ -324,8 +324,12 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
max_deferred_experts_per_token: Optional[int] = None,
method: str = "RAWINT4",
):
if not _HAS_AMX_SUPPORT or AMXInt4_KGroup_MOE is None:
if not _HAS_AMX_SUPPORT:
raise RuntimeError("AMX backend is not available.")
if method == "RAWINT4" and AMXInt4_KGroup_MOE is None:
raise RuntimeError("AMX backend with RAWINT4 support is not available.")
if method == "FP8" and AMXFP8_MOE is None:
raise RuntimeError("AMX backend with FP8 support is not available.")
super().__init__(
layer_idx=layer_idx,
@@ -343,9 +347,14 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
method=method,
)
if RAWAMXMoEWrapper._compressed_loader_instance is None:
RAWAMXMoEWrapper._compressed_loader_instance = CompressedSafeTensorLoader(weight_path)
self.loader = RAWAMXMoEWrapper._compressed_loader_instance
if NativeMoEWrapper._native_loader_instance is None:
if method == "RAWINT4":
NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)
elif method == "FP8":
NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)
else:
raise NotImplementedError(f"Unsupported method for NativeMoEWrapper: {method}")
self.loader = NativeMoEWrapper._native_loader_instance
self.gate_weights = None
self.up_weights = None
@@ -378,9 +387,17 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
self.down_weights = weights["down"]
# Convert scales to bf16 individually
self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
# self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights["gate_scale"]]
# self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights["up_scale"]]
# self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights["down_scale"]]
self.gate_scales = weights["gate_scale"]
self.up_scales = weights["up_scale"]
self.down_scales = weights["down_scale"]
if self.method == "RAWINT4":
assert self.gate_scales[0].dtype == torch.bfloat16, "Expected bf16 scales for RAWINT4"
elif self.method == "FP8":
assert self.gate_scales[0].dtype == torch.float32, "Expected float32 scales for FP8"
t2 = time.time()
# Build pointer lists: [numa_id][expert_id] -> pointer
@@ -404,18 +421,6 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.pool = self.cpu_infer.backend_
moe_config.max_len = self.chunked_prefill_size
# Infer group_size from scale shape (column-major layout)
# For gate/up projection: in_features = hidden_size
# So: group_size = hidden_size / scale.shape[1]
scale_shape = self.gate_scales[0].shape
group_size = self.hidden_size // scale_shape[1]
print(f"[RAWAMXMoEWrapper Layer {self.layer_idx}] Inferred group_size: {group_size}")
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
# Use gate_projs instead of gate_proj for per-expert pointers
moe_config.gate_projs = gate_ptrs
moe_config.up_projs = up_ptrs
@@ -424,7 +429,21 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
moe_config.up_scales = up_scale_ptrs
moe_config.down_scales = down_scale_ptrs
self.moe = AMXInt4_KGroup_MOE(moe_config)
# Infer group_size from scale shape (column-major layout)
# For gate/up projection: in_features = hidden_size
# So: group_size = hidden_size / scale.shape[1]
if self.method == "RAWINT4":
group_size = self.hidden_size // self.gate_scales[0].shape[1]
moe_config.quant_config.bits = 4
moe_config.quant_config.group_size = group_size
moe_config.quant_config.zero_point = False
self.moe = AMXInt4_KGroup_MOE(moe_config)
elif self.method == "FP8":
moe_config.quant_config.bits = 8
moe_config.quant_config.group_size = 128
moe_config.quant_config.zero_point = False
self.moe = AMXFP8_MOE(moe_config)
t4 = time.time()
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
@@ -440,7 +459,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
t6 = time.time()
print(
f"[RAWAMXMoEWrapper Layer {self.layer_idx}] "
f"[NativeMoEWrapper Layer {self.layer_idx}] "
f"load_experts: {(t1-t0)*1000:.1f}ms, "
f"prepare_tensors: {(t2-t1)*1000:.1f}ms, "
f"build_ptrs: {(t3-t2)*1000:.1f}ms, "
@@ -453,7 +472,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
def submit_write_weight_scale_to_buffer(
self,
gpu_tp_count: int,
gpu_experts_num: int,
expert_id: int,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,
@@ -477,7 +496,7 @@ class RAWAMXMoEWrapper(BaseMoEWrapper):
self.cpu_infer.submit(
self.moe.write_weight_scale_to_buffer_task(
gpu_tp_count,
gpu_experts_num,
expert_id,
w13_weight_ptrs,
w13_scale_ptrs,
w2_weight_ptrs,

View File
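The two quantization paths configured above differ mainly in how scales are blocked: RAWINT4 infers `group_size` from the scale width (`hidden_size // scale.shape[1]`), while FP8 uses fixed 128×128 tiles with float32 `weight_scale_inv`. A pure-Python sketch of both rules (shapes are illustrative):

```python
def infer_group_size(hidden_size: int, scale_cols: int) -> int:
    # Mirrors: group_size = self.hidden_size // self.gate_scales[0].shape[1]
    assert hidden_size % scale_cols == 0, "scale width must divide hidden_size"
    return hidden_size // scale_cols

def fp8_scale_index(row: int, col: int, block: int = 128):
    # FP8 weight_scale_inv holds one scale per (block x block) tile, so the
    # scale that applies to weight[row][col] sits at [row // block][col // block].
    return row // block, col // block
```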

@@ -219,4 +219,4 @@ class LlamafileMoEWrapper(BaseMoEWrapper):
self.cpu_infer.sync()
# Drop original weights after loading
self.weights_to_keep = None
self.weights_to_keep = None

View File

@@ -237,6 +237,117 @@ class SafeTensorLoader:
return name in self.tensor_file_map
class FP8SafeTensorLoader(SafeTensorLoader):
"""Loader for FP8 expert weights with auto-detection of naming formats.
Supported formats:
- DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight
- Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight
The format is auto-detected during initialization.
"""
# Known MoE naming formats: (experts_path_template, gate_name, up_name, down_name)
MOE_FORMATS = {
"deepseek": ("{base}.mlp.experts", "gate_proj", "up_proj", "down_proj"),
"mixtral": ("{base}.block_sparse_moe.experts", "w1", "w3", "w2"),
}
def __init__(self, file_path: str):
super().__init__(file_path)
self._detected_format = None
self._detect_format()
def _detect_format(self):
"""Auto-detect the MoE naming format by checking tensor keys."""
# Sample some tensor names to detect format
sample_keys = list(self.tensor_file_map.keys())[:1000]
for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():
# Check if any key matches this format pattern
# Look for pattern like: model.layers.0.{experts_path}.0.{gate_name}.weight
for key in sample_keys:
if ".experts." in key and f".{gate}.weight" in key:
# Verify the path template matches
if "block_sparse_moe.experts" in key and fmt_name == "mixtral":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
elif "mlp.experts" in key and "block_sparse_moe" not in key and fmt_name == "deepseek":
self._detected_format = fmt_name
print(f"[FP8SafeTensorLoader] Detected format: {fmt_name}")
return
# Default to deepseek if no format detected
self._detected_format = "deepseek"
print("[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek")
def _get_experts_prefix(self, base_key: str) -> str:
"""Get the experts prefix based on detected format."""
path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]
return path_tpl.format(base=base_key)
def _get_proj_names(self):
"""Get projection names (gate, up, down) based on detected format."""
_, gate, up, down = self.MOE_FORMATS[self._detected_format]
return gate, up, down
def load_tensor(self, key: str, device: str = "cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
if device == "cpu":
return tensor
return tensor.to(device)
def load_experts(self, base_key: str, device: str = "cpu"):
"""Load FP8 expert weights and their block-wise scale_inv tensors."""
experts_prefix = self._get_experts_prefix(base_key)
gate_name, up_name, down_name = self._get_proj_names()
expert_count = 0
while self.has_tensor(f"{experts_prefix}.{expert_count}.{gate_name}.weight"):
expert_count += 1
if expert_count == 0:
raise ValueError(f"No experts found for key {experts_prefix}")
gate_weights = [None] * expert_count
up_weights = [None] * expert_count
down_weights = [None] * expert_count
gate_scales = [None] * expert_count
up_scales = [None] * expert_count
down_scales = [None] * expert_count
for exp_id in range(expert_count):
gate_w_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight"
up_w_key = f"{experts_prefix}.{exp_id}.{up_name}.weight"
down_w_key = f"{experts_prefix}.{exp_id}.{down_name}.weight"
gate_s_key = f"{experts_prefix}.{exp_id}.{gate_name}.weight_scale_inv"
up_s_key = f"{experts_prefix}.{exp_id}.{up_name}.weight_scale_inv"
down_s_key = f"{experts_prefix}.{exp_id}.{down_name}.weight_scale_inv"
gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()
up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()
down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()
gate_scales[exp_id] = self.load_tensor(gate_s_key, device).contiguous()
up_scales[exp_id] = self.load_tensor(up_s_key, device).contiguous()
down_scales[exp_id] = self.load_tensor(down_s_key, device).contiguous()
return {
"gate": gate_weights,
"up": up_weights,
"down": down_weights,
"gate_scale": gate_scales,
"up_scale": up_scales,
"down_scale": down_scales,
}
class CompressedSafeTensorLoader(SafeTensorLoader):
"""Loader for compressed SafeTensor layouts (RAWINT4 weights)."""
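The format auto-detection in `FP8SafeTensorLoader` keys off substrings of the tensor names; the same decision can be sketched as a standalone function over a list of keys (key strings are illustrative):

```python
def detect_moe_format(keys):
    """Return 'mixtral' or 'deepseek' based on expert tensor naming."""
    for key in keys:
        if ".experts." not in key:
            continue
        if "block_sparse_moe.experts" in key and ".w1.weight" in key:
            return "mixtral"   # MiniMax/Mixtral style: w1/w3/w2
        if "mlp.experts" in key and ".gate_proj.weight" in key:
            return "deepseek"  # DeepSeek style: gate/up/down_proj
    return "deepseek"          # loader defaults to deepseek when unsure
```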

View File

@@ -285,9 +285,9 @@ class CMakeBuild(build_ext):
# Variant configurations: (name, CPUINFER_CPU_INSTRUCT, CPUINFER_ENABLE_AMX)
variants = [
("amx", "AVX512", "ON"), # AVX512 + AMX
("amx", "AVX512", "ON"), # AVX512 + AMX
("avx512", "AVX512", "OFF"), # AVX512 only
("avx2", "AVX2", "OFF"), # AVX2 only
("avx2", "AVX2", "OFF"), # AVX2 only
]
for variant_name, cpu_instruct, enable_amx in variants:
@@ -384,6 +384,7 @@ class CMakeBuild(build_ext):
build_temp: Temporary build directory for CMake
cfg: Build type (Release/Debug/etc.)
"""
# Auto-detect CUDA toolkit if user did not explicitly set CPUINFER_USE_CUDA
def detect_cuda_toolkit() -> bool:
# Respect CUDA_HOME
@@ -614,10 +615,26 @@ setup(
author="kvcache-ai",
license="Apache-2.0",
python_requires=">=3.8",
packages=["kt_kernel", "kt_kernel.utils"],
packages=[
"kt_kernel",
"kt_kernel.utils",
"kt_kernel.cli",
"kt_kernel.cli.commands",
"kt_kernel.cli.config",
"kt_kernel.cli.utils",
],
package_dir={
"kt_kernel": "python",
"kt_kernel.utils": "python/utils",
"kt_kernel.cli": "python/cli",
"kt_kernel.cli.commands": "python/cli/commands",
"kt_kernel.cli.config": "python/cli/config",
"kt_kernel.cli.utils": "python/cli/utils",
},
entry_points={
"console_scripts": [
"kt=kt_kernel.cli.main:main",
],
},
ext_modules=[CMakeExtension("kt_kernel.kt_kernel_ext", str(REPO_ROOT))],
cmdclass={"build_ext": CMakeBuild},
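With the `console_scripts` entry point added above, `pip install` generates a `kt` executable that invokes `kt_kernel.cli.main:main`. A minimal `main` compatible with that contract — the body below is a hypothetical sketch, not the actual CLI:

```python
import sys

def main(argv=None):
    """Entry point wired to the `kt` console script (sketch)."""
    args = sys.argv[1:] if argv is None else argv
    if not args or args[0] in ("-h", "--help"):
        print("usage: kt <command> [options]")
        return 0
    # Real command dispatch lives in kt_kernel.cli.commands.
    print(f"unknown command: {args[0]}", file=sys.stderr)
    return 1
```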

View File

@@ -17,6 +17,7 @@ register_cpu_ci(est_time=30, suite="default")
# Check if kt_kernel_ext is available
try:
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
HAS_KT_KERNEL = True
except ImportError:
@@ -51,7 +52,7 @@ def test_basic_module_attributes():
pytest.skip("kt_kernel_ext not built or available")
# Check for key attributes/functions
assert hasattr(kt_kernel_ext, 'CPUInfer'), "kt_kernel_ext should have CPUInfer class"
assert hasattr(kt_kernel_ext, "CPUInfer"), "kt_kernel_ext should have CPUInfer class"
def run_all_tests():

View File

@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
HAS_DEPS = True
except ImportError as e:
@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(
tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
)
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
@@ -96,9 +95,7 @@ def test_moe_amx_int4_accuracy():
pytest.skip(f"Dependencies not available: {import_error}")
global physical_to_logical_map
physical_to_logical_map = torch.tensor(
data=range(expert_num), device="cpu", dtype=torch.int64
).contiguous()
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(60)
@@ -133,9 +130,7 @@ def test_moe_amx_int4_accuracy():
)
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -176,14 +171,10 @@ def test_moe_amx_int4_accuracy():
CPUInfer.sync()
# Run torch reference
t_output = moe_torch(
input_data, expert_ids, weights, gate_proj, up_proj, down_proj
)
t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)
# Calculate relative difference
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
torch.abs(t_output)
)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print(f"Iteration {i}, diff = {diff:.6f}")
# INT4 should have diff < 0.35
@@ -205,6 +196,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
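The accuracy criterion in these tests is a single relative-error scalar: mean absolute difference over mean reference magnitude, thresholded per dtype (0.35 for INT4, 0.05 for INT8). A pure-Python equivalent of the torch expression:

```python
def relative_diff(output, reference):
    # Mirrors: torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
    abs_err = sum(abs(o - r) for o, r in zip(output, reference)) / len(reference)
    magnitude = sum(abs(r) for r in reference) / len(reference)
    return abs_err / magnitude
```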

View File

@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
HAS_DEPS = True
except ImportError as e:
@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(
tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
)
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
@@ -96,9 +95,7 @@ def test_moe_amx_int4_1_accuracy():
pytest.skip(f"Dependencies not available: {import_error}")
global physical_to_logical_map
physical_to_logical_map = torch.tensor(
data=range(expert_num), device="cpu", dtype=torch.int64
).contiguous()
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(60)
@@ -133,9 +130,7 @@ def test_moe_amx_int4_1_accuracy():
)
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -176,14 +171,10 @@ def test_moe_amx_int4_1_accuracy():
CPUInfer.sync()
# Run torch reference
t_output = moe_torch(
input_data, expert_ids, weights, gate_proj, up_proj, down_proj
)
t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)
# Calculate relative difference
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
torch.abs(t_output)
)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print(f"Iteration {i}, diff = {diff:.6f}")
# INT4_1 should have diff < 0.35
@@ -205,6 +196,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
HAS_DEPS = True
except ImportError as e:
@@ -69,9 +70,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(
tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
)
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
@@ -97,9 +96,7 @@ def test_moe_amx_int4_1k_accuracy():
pytest.skip(f"Dependencies not available: {import_error}")
global physical_to_logical_map
physical_to_logical_map = torch.tensor(
data=range(expert_num), device="cpu", dtype=torch.int64
).contiguous()
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(60)
@@ -134,9 +131,7 @@ def test_moe_amx_int4_1k_accuracy():
)
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -180,14 +175,10 @@ def test_moe_amx_int4_1k_accuracy():
CPUInfer.sync()
# Run torch reference
t_output = moe_torch(
input_data, expert_ids, weights, gate_proj, up_proj, down_proj
)
t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)
# Calculate relative difference
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
torch.abs(t_output)
)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print(f"Iteration {i}, diff = {diff:.6f}")
# INT4_1K should have diff < 0.35
@@ -209,6 +200,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -20,6 +20,7 @@ register_cpu_ci(est_time=120, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
HAS_DEPS = True
except ImportError as e:
@@ -68,9 +69,7 @@ def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(
tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i]
)
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
@@ -96,9 +95,7 @@ def test_moe_amx_int8_accuracy():
pytest.skip(f"Dependencies not available: {import_error}")
global physical_to_logical_map
physical_to_logical_map = torch.tensor(
data=range(expert_num), device="cpu", dtype=torch.int64
).contiguous()
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
CPUInfer = kt_kernel_ext.CPUInfer(60)
@@ -133,9 +130,7 @@ def test_moe_amx_int8_accuracy():
)
# Create MOE config
config = kt_kernel_ext.moe.MOEConfig(
expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0
)
config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)
config.max_len = max_len
config.gate_proj = gate_proj.data_ptr()
config.up_proj = up_proj.data_ptr()
@@ -174,14 +169,10 @@ def test_moe_amx_int8_accuracy():
CPUInfer.sync()
# Run torch reference
t_output = moe_torch(
input_data, expert_ids, weights, gate_proj, up_proj, down_proj
)
t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)
# Calculate relative difference
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(
torch.abs(t_output)
)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print(f"Iteration {i}, diff = {diff:.6f}")
# INT8 should have diff < 0.05
@@ -203,6 +194,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -24,8 +24,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
from tqdm import tqdm
HAS_DEPS = True
except ImportError as e:
HAS_DEPS = False
@@ -306,6 +308,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -24,6 +24,7 @@ register_cpu_ci(est_time=300, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
from tqdm import tqdm

View File

@@ -25,8 +25,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
from tqdm import tqdm
HAS_DEPS = True
except ImportError as e:
HAS_DEPS = False
@@ -156,11 +158,7 @@ def test_moe_amx_int4_1k_benchmark():
CPUInfer = kt_kernel_ext.CPUInfer(worker_config)
# Physical to logical map for weight loading
physical_to_logical_map = torch.tensor(
data=range(expert_num),
device="cpu",
dtype=torch.int64
).contiguous()
physical_to_logical_map = torch.tensor(data=range(expert_num), device="cpu", dtype=torch.int64).contiguous()
# Initialize MOE layers
moes = []
@@ -322,6 +320,7 @@ def run_all_tests():
except Exception as e:
print(f"\nTest failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -24,8 +24,10 @@ register_cpu_ci(est_time=300, suite="default")
try:
import torch
import kt_kernel # Import kt_kernel first to register kt_kernel_ext
kt_kernel_ext = kt_kernel.kt_kernel_ext # Access the extension module
from tqdm import tqdm
HAS_DEPS = True
except ImportError as e:
HAS_DEPS = False
@@ -51,7 +53,6 @@ worker_config_dict = {
CPUINFER_PARAM = 60
def get_git_commit():
"""Get current git commit information."""
result = {}
@@ -307,6 +308,7 @@ def run_all_tests():
except Exception as e:
print(f"\n✗ Test failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -3,4 +3,4 @@ KTransformers version information.
Shared across kt-kernel and kt-sft modules.
"""
__version__ = "0.4.4"
__version__ = "0.4.5"