diff --git a/doc/en/Qwen3.5.md b/doc/en/Qwen3.5.md
new file mode 100644
index 0000000..335c8c1
--- /dev/null
+++ b/doc/en/Qwen3.5.md
@@ -0,0 +1,155 @@
+# Running Qwen3.5 with SGLang and KT-Kernel
+
+This tutorial demonstrates how to run inference for the Qwen3.5 (MoE-400B) model using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to the CPU.
+
+## Table of Contents
+
+- [Running Qwen3.5 with SGLang and KT-Kernel](#running-qwen35-with-sglang-and-kt-kernel)
+  - [Table of Contents](#table-of-contents)
+  - [Hardware Requirements](#hardware-requirements)
+  - [Prerequisites](#prerequisites)
+  - [Step 1: Download Model Weights](#step-1-download-model-weights)
+  - [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)
+    - [Launch Command (4x RTX 4090 Example)](#launch-command-4x-rtx-4090-example)
+  - [Step 3: Send Inference Requests](#step-3-send-inference-requests)
+    - [Basic Chat Completion Request](#basic-chat-completion-request)
+    - [Example Response](#example-response)
+
+## Hardware Requirements
+
+**Minimum Configuration:**
+- **GPU**: 4x NVIDIA RTX 4090 (or equivalent, with at least 96GB total VRAM)
+- **CPU**: x86 CPU with AVX512F support (e.g., Intel Sapphire Rapids)
+- **RAM**: At least 800GB system memory
+- **Storage**: ~800GB for model weights (BF16)
+
+## Prerequisites
+
+Before starting, ensure you have:
+
+1. **KT-Kernel installed**:
+
+```bash
+git clone https://github.com/kvcache-ai/ktransformers.git
+cd ktransformers
+git checkout qwen3.5
+git submodule update --init --recursive
+cd kt-kernel && ./install.sh
+```
+
+2. **SGLang installed** - Follow the [SGLang integration steps](./kt-kernel_intro.md#integration-with-sglang).
+
+Note: Currently, you need to clone our custom SGLang repository:
+
+```bash
+git clone https://github.com/kvcache-ai/sglang.git
+cd sglang
+git checkout qwen3.5
+pip install -e "python[all]"
+# If SGLang fails to launch with a cuDNN-related error, reinstall cuDNN:
+pip install nvidia-cudnn-cu12==9.16.0.29
+```
+
+3. **CUDA toolkit** - Compatible with your GPU (CUDA 12.8+ recommended)
+4. **Hugging Face CLI** - For downloading models:
+
+   ```bash
+   pip install huggingface-hub
+   ```
+
+## Step 1: Download Model Weights
+
+```bash
+# Create a directory for models
+mkdir -p /path/to/models
+cd /path/to/models
+
+# Download Qwen3.5 (BF16)
+huggingface-cli download Qwen/Qwen3.5 \
+    --local-dir /path/to/qwen3.5
+```
+
+**Note:** Replace the `/path/to/...` placeholders with your actual storage paths throughout this tutorial.
+
+## Step 2: Launch SGLang Server
+
+Start the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.
+
+### Launch Command (4x RTX 4090 Example)
+
+```bash
+python -m sglang.launch_server \
+    --host 0.0.0.0 \
+    --port 30005 \
+    --model /path/to/qwen3.5 \
+    --kt-weight-path /path/to/qwen3.5 \
+    --kt-cpuinfer 60 \
+    --kt-threadpool-count 2 \
+    --kt-num-gpu-experts 1 \
+    --kt-method BF16 \
+    --attention-backend triton \
+    --trust-remote-code \
+    --mem-fraction-static 0.98 \
+    --chunked-prefill-size 4096 \
+    --max-running-requests 32 \
+    --max-total-tokens 32000 \
+    --served-model-name qwen3.5 \
+    --enable-mixed-chunk \
+    --tensor-parallel-size 4 \
+    --enable-p2p-check \
+    --disable-shared-experts-fusion \
+    --disable-custom-all-reduce
+```
+
+See [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.
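+
+Loading the BF16 expert weights into system memory can take a long time, so the server will not answer requests immediately after launch. As a simple readiness check, the sketch below polls the OpenAI-compatible `/v1/models` endpoint until it responds; it is an illustration only (not part of KT-Kernel), and the host, port, and timing are assumptions based on the launch command above:
+
+```python
+import time
+
+import requests
+
+# Poll the model-list endpoint until the server starts answering.
+URL = "http://localhost:30005/v1/models"
+
+for _ in range(360):  # wait up to ~1 hour while weights load
+    try:
+        if requests.get(URL, timeout=5).status_code == 200:
+            print("Server is ready.")
+            break
+    except requests.RequestException:
+        pass  # server not accepting connections yet
+    time.sleep(10)
+else:
+    print("Gave up waiting for the server.")
+```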
+
+## Step 3: Send Inference Requests
+
+Once the server is running, you can send inference requests using the OpenAI-compatible API.
+
+### Basic Chat Completion Request
+
+```bash
+curl -s http://localhost:30005/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "qwen3.5",
+        "stream": false,
+        "messages": [
+            {"role": "user", "content": "hi, who are you?"}
+        ]
+    }'
+```
+
+### Example Response
+
+```json
+{
+  "id": "c79f6d63e04f4874acb8853d218e1bf1",
+  "object": "chat.completion",
+  "created": 1770880035,
+  "model": "qwen3.5",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "Hello! I'm **Qwen**, a large language model developed by **Alibaba Cloud**. I'm designed to provide helpful, accurate, and safe information across a wide range of topics—whether you have questions, need help with writing, coding, analysis, or just want to explore ideas together.\n\nHow can I assist *you* today?",
+        "reasoning_content": null,
+        "tool_calls": null
+      },
+      "logprobs": null,
+      "finish_reason": "stop",
+      "matched_stop": 248046
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 16,
+    "total_tokens": 527,
+    "completion_tokens": 511,
+    "prompt_tokens_details": null,
+    "reasoning_tokens": 0
+  },
+  "metadata": {
+    "weight_version": "default"
+  }
+}
+```
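+
+Since the API is OpenAI-compatible, the official `openai` Python client works as well. The snippet below is a minimal sketch: it assumes `pip install openai`, and the API key is an arbitrary placeholder because the local server does not validate it:
+
+```python
+from openai import OpenAI
+
+# Point the client at the local SGLang server launched in Step 2.
+client = OpenAI(base_url="http://localhost:30005/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+    model="qwen3.5",  # must match --served-model-name
+    messages=[{"role": "user", "content": "hi, who are you?"}],
+)
+print(response.choices[0].message.content)
+```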
in key and f".{gate}.weight" in key: @@ -572,6 +589,9 @@ class BF16SafeTensorLoader(SafeTensorLoader): def load_experts(self, base_key: str, device: str = "cpu"): """Load BF16 expert weights (no scales needed).""" + if self._detected_format == "packed": + return self._load_experts_packed(base_key, device) + experts_prefix = self._get_experts_prefix(base_key) gate_name, up_name, down_name = self._get_proj_names() @@ -601,6 +621,49 @@ class BF16SafeTensorLoader(SafeTensorLoader): "down": down_weights, } + def _resolve_packed_experts_prefix(self, base_key: str) -> str: + """Resolve the experts prefix for packed format, trying fallbacks.""" + # Direct: model.layers.{N}.mlp.experts + experts_prefix = f"{base_key}.mlp.experts" + if self.has_tensor(f"{experts_prefix}.gate_up_proj"): + return experts_prefix + + # VL models: model.layers.{N} -> model.language_model.layers.{N} + parts = base_key.split(".", 1) + if len(parts) == 2: + alt_base = f"{parts[0]}.language_model.{parts[1]}" + experts_prefix = f"{alt_base}.mlp.experts" + if self.has_tensor(f"{experts_prefix}.gate_up_proj"): + return experts_prefix + + raise ValueError(f"No packed experts found for base_key '{base_key}'.") + + def _load_experts_packed(self, base_key: str, device: str = "cpu"): + """Load packed expert weights (Qwen3.5 MoE style). + + Packed format stores all experts in stacked 3D tensors: + - gate_up_proj: [num_experts, 2 * intermediate_size, hidden_size] + - down_proj: [num_experts, hidden_size, intermediate_size] + """ + experts_prefix = self._resolve_packed_experts_prefix(base_key) + + gate_up_key = f"{experts_prefix}.gate_up_proj" + down_key = f"{experts_prefix}.down_proj" + + gate_up = self.load_tensor(gate_up_key, device) # [E, 2*I, H] + down = self.load_tensor(down_key, device) # [E, H, I] + + mid = gate_up.shape[1] // 2 + gate_list = [gate_up[i, :mid, :].contiguous() for i in range(gate_up.shape[0])] + up_list = [gate_up[i, mid:, :].contiguous() for i in range(gate_up.shape[0])] + down_list = [down[i].contiguous() for i in range(down.shape[0])] + + return { + "gate": gate_list, + "up": up_list, + "down": down_list, + } + class CompressedSafeTensorLoader(SafeTensorLoader): """Loader for compressed SafeTensor layouts (RAWINT4 weights)."""