From 4f646657581dd4142b98a0252d74bde98942a4b0 Mon Sep 17 00:00:00 2001 From: Oql <1692110604@qq.com> Date: Wed, 4 Feb 2026 16:27:10 +0800 Subject: [PATCH] [docs]: add Qwen3 Coder Next Tutorial (#1833) --- doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md | 222 ++++++++++++++++++ kt-kernel/README.md | 83 ++++++- 2 files changed, 292 insertions(+), 13 deletions(-) create mode 100644 doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md diff --git a/doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md b/doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md new file mode 100644 index 0000000..5bf2c39 --- /dev/null +++ b/doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md @@ -0,0 +1,222 @@ +# Running Qwen3-Coder-Next with SGLang and KT-Kernel + +This tutorial demonstrates how to run Qwen3-Coder-Next (80B-A3B) model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. Qwen3-Coder-Next is a Mixture-of-Experts code generation model. KT-Kernel supports both BF16 and FP8 precision backends, allowing you to choose between maximum quality and reduced memory footprint. + +## Table of Contents + +- [Table of Contents](#table-of-contents) +- [Hardware Requirements](#hardware-requirements) +- [Prerequisites](#prerequisites) +- [Step 1: Download Model Weights](#step-1-download-model-weights) +- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server) + - [Key Parameters](#key-parameters) +- [Step 3: Send Inference Requests](#step-3-send-inference-requests) + - [Option A: Interactive Chat with KT CLI](#option-a-interactive-chat-with-kt-cli) + - [Option B: OpenAI-Compatible API](#option-b-openai-compatible-api) +- [Performance](#performance) +- [Troubleshooting](#troubleshooting) + - [OOM (Out of Memory) Issues](#oom-out-of-memory-issues) +- [Additional Resources](#additional-resources) + +## Hardware Requirements + +**Recommended Configuration:** +- **GPU**: 1 x NVIDIA RTX 4090 24 GB +- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC) +- **RAM**: At least 100GB system memory for FP8 model weights +- **Storage**: >85 GB for FP8 model weights (80.4 GB) + +## Prerequisites + +Before starting, ensure you have: + +1. **SGLang installed** + + Note: Currently, please clone our custom SGLang repository: + + ```bash + git clone https://github.com/kvcache-ai/sglang.git + cd sglang + pip install -e "python[all]" + ``` + + You can follow [SGLang integration steps](https://docs.sglang.io/get_started/install.html) + +2. **KT-Kernel installed** + + Please follow [kt-kernel](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md) + + After installation, verify the CLI is working: + + ```bash + kt version + ``` + +3. **CUDA toolkit** - CUDA 12.0+ recommended (12.8+ for best FP8 support) +4. **Hugging Face CLI** - For downloading models: + ```bash + pip install -U huggingface-hub + ``` + +## Step 1: Download Model Weights + +Download the Qwen3-Coder-Next weights from Hugging Face. + +```bash +# FP8 +hf download Qwen/Qwen3-Coder-Next-FP8 \ + --local-dir /path/to/Qwen3-Coder-Next-FP8 + +# BF16 +hf download Qwen/Qwen3-Coder-Next \ + --local-dir /path/to/Qwen3-Coder-Next +``` + +**Note:** Replace `/path/to/` with your actual storage path throughout this tutorial. + +## Step 2: Launch SGLang Server + +Start the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference. + +```bash +# FP8 Precision +python -m sglang.launch_server \ + --host 0.0.0.0 \ + --port 30000 \ + --model /path/to/Qwen3-Coder-Next-FP8 \ + --kt-weight-path /path/to/Qwen3-Coder-Next-FP8 \ + --kt-cpuinfer 96 \ + --kt-threadpool-count 2 \ + --kt-num-gpu-experts 100 \ + --kt-method FP8 \ + --kt-gpu-prefill-token-threshold 2048 \ + --attention-backend triton \ + --trust-remote-code \ + --mem-fraction-static 0.80 \ + --chunked-prefill-size 16384 \ + --max-running-requests 4 \ + --max-total-tokens 256000 \ + --served-model-name Qwen3-Coder-Next \ + --enable-mixed-chunk \ + --tensor-parallel-size 1 \ + --enable-p2p-check \ + --disable-shared-experts-fusion \ + --fp8-gemm-backend cutlass \ + --tool-call-parser qwen3_coder \ + --kt-enable-dynamic-expert-update + +# BF16 Precision +python -m sglang.launch_server \ + --host 0.0.0.0 \ + --port 30000 \ + --model /path/to/Qwen3-Coder-Next \ + --kt-weight-path /path/to/Qwen3-Coder-Next \ + --kt-cpuinfer 96 \ + --kt-threadpool-count 2 \ + --kt-num-gpu-experts 60 \ + --kt-method BF16 \ + --kt-gpu-prefill-token-threshold 2048 \ + --attention-backend triton \ + --trust-remote-code \ + --mem-fraction-static 0.80 \ + --chunked-prefill-size 16384 \ + --max-running-requests 4 \ + --max-total-tokens 256000 \ + --served-model-name Qwen3-Coder-Next \ + --enable-mixed-chunk \ + --tensor-parallel-size 1 \ + --enable-p2p-check \ + --disable-shared-experts-fusion \ + --tool-call-parser qwen3_coder \ + --kt-enable-dynamic-expert-update +``` + +See [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines. + +### Key Parameters + +| Parameter | Description | +|-----------|-------------| +| `--kt-method FP8 / BF16` | Inference precision mode. FP8 halves weight memory; BF16 uses full precision. | +| `--kt-cpuinfer` | Number of CPU inference threads. | +| `--kt-threadpool-count` | Number of thread pools. Set to NUMA node count. | +| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. | +| `--kt-gpu-prefill-token-threshold` | Token threshold for layerwise prefill strategy. | +| `--kt-enable-dynamic-expert-update` | Enable dynamic expert placement on GPU based on routing statistics. | +| `--kt-expert-placement-strategy` | Expert placement strategy. Default: `uniform`. See [Expert Scheduling Tutorial](experts-sched-Tutorial.md) for other options. | +| `--chunked-prefill-size` | Maximum tokens per prefill batch. | +| `--max-total-tokens` | Maximum total tokens in KV cache. | +| `--tool-call-parser` | Tool call parser for function calling support (use `qwen3_coder`). | +| `--fp8-gemm-backend` | GEMM backend for FP8 computation. | + +## Step 3: Send Inference Requests + +Once the server is running (default: `http://localhost:30000`), you can interact with the model in several ways: + +### Option A: Interactive Chat with KT CLI + +The easiest way to chat with the model: + +```bash +kt chat +``` + +This opens an interactive terminal chat session. Type your messages and press Enter to send. Use `Ctrl+C` to exit. + +### Option B: OpenAI-Compatible API + +The server exposes an OpenAI-compatible API at `http://localhost:30000/v1`. + +**curl example (streaming):** + +```bash +curl http://localhost:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-Coder-Next", + "messages": [{"role": "user", "content": "Write a Python function to compute the Fibonacci sequence."}], + "stream": true + }' +``` + +**curl example (non-streaming):** + +```bash +curl -s http://localhost:30000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-Coder-Next", + "messages": [{"role": "user", "content": "Hello! What can you help me with?"}], + "stream": false + }' +``` + +## Performance + +The following benchmarks were measured with single concurrency (Prefill tps / Decode tps): + +| GPU | CPU | PCIe | Precision | 64 tokens | 2048 tokens | 8192 tokens | 32768 tokens | +|-----|-----|------|-----------|-------------|-------------|-------------|--------------| +| 1 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | FP8 | 362 / 75.9 | 1746 / 75.6 | 2407 / 69.1 | 6233 / 51.7 | + +## Troubleshooting + +### OOM (Out of Memory) Issues + +Layerwise prefill requires extra VRAM. If you encounter OOM, adjust these parameters when launching the server: + +| Parameter | VRAM Impact | +|-----------|-------------| +| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage | +| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation | +| `--max-total-tokens` | Reduces KV cache VRAM usage | +| `--mem-fraction-static` | Lower values reserve more VRAM headroom (default: 0.80) | + +**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill. + +## Additional Resources + +- [Qwen3-Coder-Next Model Card](https://huggingface.co/Qwen/Qwen3-Coder-Next) +- [KT-Kernel Documentation](../../../kt-kernel/README.md) +- [SGLang GitHub](https://github.com/sgl-project/sglang) +- [KT-Kernel Parameters Reference](../../../kt-kernel/README.md#kt-kernel-parameters) diff --git a/kt-kernel/README.md b/kt-kernel/README.md index 467cfe6..9db9960 100644 --- a/kt-kernel/README.md +++ b/kt-kernel/README.md @@ -23,14 +23,14 @@ High-performance kernel operations for KTransformers, featuring CPU-optimized Mo - [hwloc Not Found](#hwloc-not-found) - [Weight Quantization](#weight-quantization) - [Before Commit!](#before-commit) + ## Note **Current Support Status:** +- ✅ **Native Precision with AVX512/AMX**: Supported with AVX512 CPUs in `FP8`, `BF16` and `RAWINT4` format - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/Native-Precision-Tutorial.md) - ✅ **Intel CPUs with AMX**: Fully supported (using weights converted to INT4/INT8 format) - ✅ **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights) - ✅ **AMD CPUs with BLIS**: Supported (for int8 prefill & decode) - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/amd_blis.md) -- ✅ **Kimi-K2 Native INT4 (RAWINT4)**: Supported on AVX512 CPUs (CPU-GPU shared INT4 weights) - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/Kimi-K2-Thinking-Native.md) -- ✅ **FP8 weights (e.g., MiniMax-M2.1)**: Supported on AVX512 CPUs (CPU-GPU shared FP8 weights) - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/MiniMax-M2.1-Tutorial.md) **KT-CLI** @@ -39,6 +39,7 @@ We are developing a simpler way to use KTransformers. Check out the [KT-CLI Guid ## Features - **CPU-Optimized MoE Kernels**: High-throughput MoE expert kernels optimized for instruction sets. +- **AVX512 Native Precision Backend**: FP8 / BF16 / INT4 native MoE backend for AVX512-capable servers. - **AMX INT4/INT8 Backend**: INT4 / INT8 quantized expert inference backend for AMX-capable servers. - **Llamafile CPU Backend**: AVX2/AVX512-based MoE backend built on Llamafile for universal CPU deployment. - **NUMA-Aware Execution**: Thread pool and memory layout designed for multi-socket / multi-NUMA machines. @@ -69,9 +70,6 @@ pip install kt-kernel - CPU with AVX2 support (Intel Haswell 2013+, AMD Zen+) - Optional: NVIDIA GPU with compute capability 8.0+ for CUDA features -<<<<<<< HEAD -**GPU Compatibility (Optional):** -======= #### CUDA Installation (GPU Acceleration) For NVIDIA GPU-accelerated inference: @@ -95,7 +93,6 @@ pip install kt-kernel-cuda - NVIDIA driver with CUDA 11.8+ or 12.x support (no CUDA toolkit needed) **GPU Compatibility Matrix:** ->>>>>>> main | GPU Architecture | Compute Capability | Supported | Example GPUs | |-----------------|-------------------|-----------|-------------| @@ -192,6 +189,8 @@ Simply run the install script - it will auto-detect your CPU and optimize for be | **LLAMAFILE** | AVX2 | Intel Haswell (2013+), AMD Zen+ | Universal compatibility | | **RAWINT4** | AVX512F + AVX512BW | Intel Skylake-X (2017+), Ice Lake, Cascade Lake | Software fallbacks for VNNI/BF16 | | **AMXINT4/INT8** | AMX | Intel Sapphire Rapids (2023+) | Best performance, requires AMX hardware | +| **FP8** | AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI | Intel Cooper Lake (2020+), Sapphire Rapids (2023+); AMD Zen 4+ (e.g., EPYC 9355) | Native Precision (e.g., DeepSeek V3.2, MiniMax M2.1) | +| **BF16** | AVX512F + AVX512BW + AVX512_BF16 | Intel Cooper Lake (2020+), Sapphire Rapids (2023+); AMD Zen 4+ (e.g., EPYC 9355) | Native Precision (e.g., Qwen3-235B-A22B, GLM-4.7) | **Software Fallback Support (AVX512 backends):** - ✅ VNNI fallback: Uses AVX512BW instructions @@ -329,7 +328,7 @@ See [KT-Kernel Parameters](#kt-kernel-parameters) section below for detailed par ### Complete Example: Qwen3-30B-A3B -This example demonstrates the full workflow from downloading weights to launching the server, showing both **AMX backend** and **LLAMAFILE backend** options. +This example demonstrates the full workflow from downloading weights to launching the server, showing **Native backend**, **AMX backend** and **LLAMAFILE backend** options. **Hardware Configuration:** - **GPU**: NVIDIA RTX 4090 24GB @@ -353,10 +352,52 @@ NUMA node(s): 2 - `--kt-threadpool-count 2`: 2 NUMA nodes detected (dual-socket system) - `--kt-num-gpu-experts 32`: With 24GB GPU memory, we can fit ~32 experts on GPU for this model (varies by model architecture and actual memory usage) - `--kt-max-deferred-experts-per-token 2`: Enable pipelined execution; allows CPU to process next batch while GPU completes current batch +- `--kt-gpu-prefill-token-threshold 2048`: Use layerwise prefill strategy when token count exceeds 2048 (for native backends only) --- -#### Option A: AMX Backend (AMXINT8) +#### Option A: Native Backend (BF16) + +For AVX512 CPUs with BF16 support. + +**Step 1: Download model weights** + +```bash +# Install huggingface-cli if not already installed +pip install huggingface-hub +# Download model from Hugging Face +huggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /mnt/data/models/Qwen3-30B-A3B +``` + +**Step 2: Launch SGLang server** + +```bash +python -m sglang.launch_server \ + --host 0.0.0.0 \ + --port 30000 \ + --model /mnt/data/models/Qwen3-30B-A3B \ + --kt-weight-path /mnt/data/models/Qwen3-30B-A3B \ + --kt-cpuinfer 64 \ + --kt-threadpool-count 2 \ + --kt-num-gpu-experts 32 \ + --kt-method BF16 \ + --attention-backend flashinfer \ + --trust-remote-code \ + --mem-fraction-static 0.80 \ + --chunked-prefill-size 16384 \ + --max-running-requests 4 \ + --served-model-name Qwen3 \ + --enable-mixed-chunk \ + --tensor-parallel-size 1 \ + --enable-p2p-check \ + --disable-shared-experts-fusion \ + --kt-gpu-prefill-token-threshold 4096 \ + --kt-enable-dynamic-expert-update +``` + +--- + +#### Option B: AMX Backend (AMXINT8) For Intel CPUs with AMX instruction set support. @@ -402,7 +443,7 @@ python -m sglang.launch_server \ --- -#### Option B: LLAMAFILE Backend (GGUF) +#### Option C: LLAMAFILE Backend (GGUF) For universal CPUs (no AMX required), using pre-quantized GGUF weights directly. @@ -445,21 +486,24 @@ python -m sglang.launch_server \ | Parameter | Description | Example Value | |-----------|-------------|---------------| -| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8` or `LLAMAFILE` | +| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8`, `FP8_PERCHANNEL`, `BF16` or `LLAMAFILE` | | `--kt-weight-path` | Path to quantized CPU weights | `/path/to/cpu-weights` | | `--kt-cpuinfer` | Number of CPU inference threads | `64` (adjust based on CPU cores) | | `--kt-threadpool-count` | Number of thread pools for parallel execution | `2` (typically 1-4) | | `--kt-num-gpu-experts` | Number of experts to keep on GPU | `32` (remaining experts go to CPU) | | `--kt-max-deferred-experts-per-token` | Number of experts per token to defer for pipelined execution | `2` (0 to disable, 1-4 recommended) | -| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (FP8 and RAWINT4 only) | ~`1024` | +| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (native backend only) | ~`1024-4096` | +| `--kt-enable-dynamic-expert-update` | Enable dynamic expert placement updates during prefill based on actual routing statistics | (flag, no value needed) | +| `--kt-expert-placement-strategy` | Strategy for initial GPU expert placement | `uniform`, `frequency`, `front-loading`, or `random` | **Parameter Guidelines:** - **`kt-method`**: Choose based on your CPU and weight format: - `AMXINT4`: Best performance on AMX CPUs with INT4 quantized weights (May cause huge accuracy drop for some models, e.g., Qwen3-30B-A3B) - `AMXINT8`: Higher accuracy with INT8 quantized weights on AMX CPUs - - `RAWINT4`: Native INT4 weights shared by CPU and GPU (AMX backend only, currently supports Kimi-K2-Thinking model). See [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details. - - `FP8`: FP8 weights shared by CPU and GPU + - `RAWINT4`: Native INT4 weights shared by CPU and GPU (currently supports Kimi-K2-Thinking model). See [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details. + - `FP8`, `FP8_PERCHANNEL`: FP8 weights shared by CPU and GPU + - `BF16`: BF16 weights shared by CPU and GPU - `LLAMAFILE`: GGUF-based backend - **`kt-cpuinfer`**: Set to the number of **physical CPU cores** (not hyperthreads). @@ -490,6 +534,19 @@ python -m sglang.launch_server \ - **> threshold**: Uses layerwise GPU prefill. Performance scales better with longer sequences, but requires one MoE layer extra VRAM (e.g., ~9GB+ for Kimi-K2-Thinking and ~3.6GB for MiniMax-M2.1). - Only applicable when `--kt-method RAWINT4` or `--kt-method FP8` is used. +- **`kt-enable-dynamic-expert-update`**: Enables dynamic expert placement updates during inference. + - During layerwise prefill, the system collects actual routing statistics and redistributes GPU experts accordingly. + - Requires `--kt-gpu-prefill-token-threshold` to be set, and prefill length must be ≥ the threshold value. + - Particularly effective at lower GPU expert ratios (10%-70%), where it can significantly outperform static strategies. + - See [Expert Scheduling Tutorial](../doc/en/kt-kernel/experts-sched-Tutorial.md) for benchmarks and details. + +- **`kt-expert-placement-strategy`**: Determines which experts are placed on GPU at server startup. + - `uniform`: Distributes GPU experts evenly across all MoE layers. Default option, no prior statistics needed. + - `frequency`: Places the most frequently activated experts on GPU. Best performance when activation statistics are available; requires `--init-expert-location` pointing to a `.pt` statistics file. + - `front-loading`: Fills GPU experts from the first MoE layer onwards. + - `random`: Randomly selects experts with a fixed seed (42). + - See [Expert Scheduling Tutorial](../doc/en/kt-kernel/experts-sched-Tutorial.md) for strategy comparison. + ## Direct Python API Usage For standalone usage without SGLang, you can use KT-Kernel directly via Python API: