fix(scripts): resolve OOM when converting gpu weights and update README (#1640)

Jianwei Dong
2025-12-01 14:15:14 +08:00
committed by GitHub
parent e637fedc65
commit fd78fe520a
2 changed files with 266 additions and 73 deletions


@@ -3,7 +3,7 @@
KT-Kernel provides weight conversion tools for CPU-GPU hybrid inference (e.g., integrating KTransformers with SGLang). Both tools work together to enable heterogeneous expert placement:
- **CPU Weights (`convert_cpu_weights.py`)**: Quantize weights to INT4/INT8 with AMX optimization for CPU-resident "cold" experts
- **GPU Weights (`convert_gpu_weights.py`)**: Apply GPTQ/RTN quantization (W4A16/W8A16) for GPU-resident "hot" experts
---
@@ -165,43 +165,118 @@ pip install accelerate transformers llmcompressor datasets
**Required packages:**
- `accelerate`: For distributed model loading and device mapping
- `transformers`: For model and tokenizer loading
- `llmcompressor`: For quantization (supports GPTQ and RTN methods)
- `datasets`: For calibration data loading (GPTQ only)
**Documentation:** This tool is based on llmcompressor. For more details, see [llmcompressor quantization guide](https://docs.vllm.ai/projects/llm-compressor/en/latest/getting-started/compress/#select-a-quantization-method-and-scheme).
### Overview
Apply weight quantization to model weights for GPU-resident "hot" experts (frequently accessed) in CPU-GPU hybrid inference. This tool works together with `convert_cpu_weights.py` to enable heterogeneous expert placement:
- **GPU-resident experts** ("hot" experts) use GPTQ quantization (this tool) for efficient GPU memory usage
- **GPU-resident experts** ("hot" experts) use GPTQ/RTN quantization (this tool) for efficient GPU memory usage
- **CPU-resident experts** ("cold" experts) use AMX-optimized INT4/INT8 quantization (convert_cpu_weights.py)
- **Attention layers, gates, and shared experts** remain in higher precision
This approach maximizes throughput and resource utilization by intelligently distributing experts across CPUs and GPUs.
### Quantization Methods
#### 1. GPTQ (Calibration-based, Default)
**Pros:**
- Higher accuracy through calibration-based quantization
- Recommended for production deployments
**Cons:**
- Requires calibration dataset
- Slower quantization process
- Higher memory requirements (needs Hessian matrix)
#### 2. RTN (Round-To-Nearest)
**Pros:**
- Fast quantization (no calibration needed)
- Lower memory requirements
- Good for quick testing and prototyping
**Cons:**
- Slightly lower accuracy compared to GPTQ
- No calibration optimization
### Quantization Types
- **W4A16**: 4-bit weights, 16-bit activations (INT4)
- **W8A16**: 8-bit weights, 16-bit activations (INT8)
### Basic Usage
#### GPTQ Quantization (Recommended for Production)
```bash
python scripts/convert_gpu_weights.py \
--model_id /path/to/model \
--output_dir /path/to/output \
--quant_method GPTQ \
--quant_type W4A16
```
#### RTN Quantization (Fast, for Testing)
```bash
python scripts/convert_gpu_weights.py \
--model_id /path/to/model \
--output_dir /path/to/output \
--quant_method RTN \
--quant_type W4A16
```
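Both methods are driven by llmcompressor recipe objects; the sketch below shows how the two `--quant_method` choices differ, mirroring what `convert_gpu_weights.py` sets up internally (the `ignore` list here is only an illustrative placeholder — the script builds its own exclusion patterns for attention, gates, and shared experts):
```python
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.quantization.gptq import GPTQModifier

ignore_patterns = ["lm_head"]  # placeholder; the real script excludes many more modules

# GPTQ: calibration-based, needs a dataset and a dampening fraction
gptq_recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=ignore_patterns,
    dampening_frac=0.1,
)

# RTN: plain round-to-nearest, no calibration data required
rtn_recipe = QuantizationModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=ignore_patterns,
)

# The GPTQ path passes a calibration dataset to oneshot(); the RTN path does not:
# oneshot(model=model, dataset=ds, recipe=gptq_recipe, output_dir=out)  # GPTQ
# oneshot(model=model, recipe=rtn_recipe, output_dir=out)               # RTN
```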
### Memory Requirements
Understanding memory requirements is crucial for successful quantization. The requirements differ significantly between RTN and GPTQ methods.
#### RTN Memory Requirements
Beyond the model weights themselves, RTN only needs extra memory for the quantization parameters (scales/zero-points):
| Component | Requirement |
|-----------|-------------|
| **DRAM (CPU Memory)** | ≥ Total model parameters |
| **VRAM (GPU Memory)** | ≥ Single layer parameters |
**Example: DeepSeek-R1-0528-BF16 (684B parameters)**
- DRAM: ~1368 GB (684B params × 2 bytes)
- VRAM: ~22.4 GB (1 layer)
#### GPTQ Memory Requirements
GPTQ requires additional memory for Hessian matrices during calibration:
| Component | Requirement |
|-----------|-------------|
| **DRAM (CPU Memory)** | ≥ Total model parameters |
| **VRAM (GPU Memory)** | ≥ Single layer parameters × 2 |
The Hessian matrix is approximately the same size as the layer weights and is used to improve accuracy recovery during quantization.
**Example: DeepSeek-R1-0528-BF16 (684B parameters)**
- DRAM: ~1368 GB (684B params × 2 bytes)
- VRAM: ~44.8 GB (1 layer × 2 for Hessian matrix)
#### Method Comparison
| Method | Speed | VRAM | Accuracy | Use Case |
|--------|-------|------|----------|----------|
| **RTN** | Fast | Low (~22GB) | Good | Testing, prototyping |
| **GPTQ** | Slow | High (~45GB) | Better | Production deployment |
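As a rough sanity check before launching a conversion, the rules above can be turned into a back-of-the-envelope estimate. This is an illustrative helper, not part of the tool; it assumes BF16 source weights (2 bytes/param) and evenly sized layers, and the 61-layer figure for DeepSeek-R1 is inferred from the example numbers above:
```python
def estimate_memory_gb(total_params_b, num_layers, method="GPTQ", bytes_per_param=2):
    """Rough DRAM/VRAM estimate in GB, following the tables above."""
    dram_gb = total_params_b * bytes_per_param                # whole model must fit in CPU RAM
    layer_gb = total_params_b * bytes_per_param / num_layers  # one layer sits on the GPU at a time
    vram_gb = layer_gb * (2 if method == "GPTQ" else 1)       # GPTQ also holds a Hessian of roughly layer size
    return dram_gb, vram_gb

# DeepSeek-R1-0528-BF16: ~684B params over ~61 transformer layers (assumed even split)
print(estimate_memory_gb(684, 61, method="RTN"))   # ~(1368, 22.4)
print(estimate_memory_gb(684, 61, method="GPTQ"))  # ~(1368, 44.9)
```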
### Advanced Options
#### Calibration Configuration (GPTQ Only)
For GPTQ quantization, control the calibration process for better quantization quality:
```bash
python scripts/convert_gpu_weights.py \
--model_id /path/to/model \
--output_dir /path/to/output \
--quant_method GPTQ \
--quant_type W4A16 \
--num_calibration_samples 512 \
--max_sequence_length 2048 \
@@ -209,53 +284,91 @@ python scripts/convert_gpu_weights.py \
--dataset_split train_sft
```
**Options (GPTQ only):**
- `--num_calibration_samples`: Number of samples for calibration (default: 512)
- `--max_sequence_length`: Maximum sequence length (default: 2048)
- `--dataset`: HuggingFace dataset for calibration
- `--dataset_split`: Dataset split to use
- `--dampening_frac`: Dampening fraction to reduce quantization noise (default: 0.1)
#### Memory Management
Use `--max_gpu_memory` to limit GPU memory usage and offload remaining layers to CPU:
```bash
python scripts/convert_gpu_weights.py \
--model_id /path/to/model \
--output_dir /path/to/output \
--quant_method GPTQ \
--quant_type W4A16 \
--max_gpu_memory "40GiB"
```
**Recommended settings for GPTQ:**
| GPU VRAM | Suggested `--max_gpu_memory` | Notes |
|----------|------------------------------|-------|
| 24 GiB | 10-12 GiB | Reserve ~50% for Hessian |
| 48 GiB | 24-30 GiB | Reserve ~40% for Hessian |
| 80 GiB | 40-50 GiB | Reserve ~40% for Hessian |
Reserve 40-50% of GPU memory for GPTQ's Hessian matrix computation.
**Recommended settings for RTN:**
| GPU VRAM | Suggested `--max_gpu_memory` | Notes |
|----------|------------------------------|-------|
| 24 GiB | 18-20 GiB | No Hessian needed |
| 48 GiB | 40-45 GiB | No Hessian needed |
| 80 GiB | 70-75 GiB | No Hessian needed |
**Options:**
- `--max_gpu_memory`: Maximum GPU memory for model weights per device (e.g., '40GiB')
- `--max_cpu_memory`: Maximum CPU memory (default: 1000GiB when `--max_gpu_memory` is set)
**Important:** llmcompressor does not support disk offloading. Ensure your machine has enough GPU + CPU memory to load the entire model. If you still encounter OOM:
1. Use RTN instead of GPTQ (requires less memory)
2. Reduce `--num_calibration_samples` (GPTQ only, e.g., 256)
3. Reduce `--max_sequence_length` (GPTQ only, e.g., 1024)
4. Use `--force_cpu` to run entirely on CPU (slower but avoids GPU OOM)
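Conceptually, `--max_gpu_memory` caps the per-device budget that accelerate uses when deciding where each module lives, and anything over the budget is placed on CPU. The sketch below illustrates the idea only — it is not the script's exact code, and the path and memory values are placeholders:
```python
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "/path/to/model"  # illustrative path
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# Build the model structure without allocating real weights, then ask accelerate
# where each module fits given the per-device budgets.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

device_map = infer_auto_device_map(
    empty_model,
    max_memory={0: "40GiB", "cpu": "1000GiB"},  # roughly what --max_gpu_memory / --max_cpu_memory control
)
# Modules that exceed the 40GiB GPU budget land on "cpu"; nothing spills to disk,
# which is why the whole model must fit in GPU + CPU memory combined.
```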
### Examples
#### Example 1: GPTQ Quantization for Production (Qwen3-Next-80B, W4A16)
```bash
python scripts/convert_gpu_weights.py \
--model_id /mnt/data/models/Qwen3-Next-80B-A3B-Instruct \
--output_dir /mnt/data/models/Qwen3-Next-80B-A3B-Instruct-GPTQ-W4A16 \
--quant_method GPTQ \
--quant_type W4A16 \
--num_calibration_samples 512 \
--max_sequence_length 2048 \
--max_gpu_memory "40GiB" \
--trust_remote_code
```
#### Example 2: RTN Quantization for Fast Testing (DeepSeek-R1, W4A16)
```bash
python scripts/convert_gpu_weights.py \
--model_id /mnt/data/models/DeepSeek-R1-0528-BF16 \
--output_dir /mnt/data/models/DeepSeek-R1-0528-RTN-W4A16 \
--quant_method RTN \
--quant_type W4A16 \
--max_gpu_memory "70GiB" \
--trust_remote_code
```
#### Example 3: GPTQ with Custom Calibration Dataset (GLM-4.5-Air, W8A16)
```bash
python scripts/convert_gpu_weights.py \
--model_id /mnt/data/models/GLM-4.5-Air \
--output_dir /mnt/data/models/GLM-4.5-Air-GPTQ-W8A16 \
--quant_method GPTQ \
--quant_type W8A16 \
--dataset "tatsu-lab/alpaca" \
--dataset_split "train" \
--num_calibration_samples 256 \
--max_gpu_memory "40GiB" \
--trust_remote_code
```


@@ -3,32 +3,49 @@
GPU Weight Quantization Tool for KTransformers
This script quantizes model weights for CPU-GPU hybrid inference when integrating
KTransformers with SGLang. It supports multiple quantization methods (GPTQ, RTN) and
applies selective quantization to GPU-resident layers while preserving certain
components (e.g., attention, gates, shared experts) in higher precision.
Usage:
    python convert_gpu_weights.py --model_id /path/to/model --output_dir /path/to/output --quant_method GPTQ --quant_type W4A16
Example (GPTQ with calibration for best accuracy):
    python convert_gpu_weights.py \
        --model_id /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct \
        --output_dir /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-GPU-weight \
        --quant_method GPTQ \
        --quant_type W4A16
Example (RTN for fast quantization without calibration):
    python convert_gpu_weights.py \
        --model_id /mnt/data/models/GLM-4.5-Air \
        --output_dir /mnt/data/models/GLM-4.5-Air-GPU-weights-rtn \
        --quant_method RTN \
        --quant_type W4A16
"""
import os
import sys
import warnings
import argparse
# IMPORTANT: Parse force_cpu argument BEFORE importing torch
# CUDA_VISIBLE_DEVICES must be set before torch initializes CUDA
if __name__ == "__main__":
    # Quick check for --force_cpu flag before full argument parsing
    if "--force_cpu" in sys.argv:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        warnings.filterwarnings("ignore", message="Can't initialize NVML")
        print("🔧 Forced CPU-only mode (CUDA_VISIBLE_DEVICES set before torch import)")
# Now it's safe to import torch and other GPU-dependent libraries
import torch
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization.gptq import GPTQModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from datasets import load_dataset
@@ -40,33 +57,46 @@ def parse_args():
parser.add_argument("--output_dir", type=str, required=True, help="Path to save the quantized model")
# Optional arguments
parser.add_argument(
"--quant_method",
type=str,
choices=["GPTQ", "RTN"],
default="GPTQ",
help="Quantization method: GPTQ (calibration-based) or RTN (round-to-nearest, no calibration). Default: GPTQ",
)
parser.add_argument(
"--quant_type",
type=str,
choices=["W4A16", "W8A16"],
default="W8A16",
help="Quantization type: W4A16 (GPTQ4) or W8A16 (GPTQ8). Default: W8A16",
help="Quantization type: W4A16 (INT4) or W8A16 (INT8). Default: W8A16",
)
parser.add_argument(
"--num_calibration_samples", type=int, default=512, help="Number of calibration samples. Default: 512"
"--num_calibration_samples",
type=int,
default=512,
help="Number of calibration samples (GPTQ only). Default: 512",
)
parser.add_argument(
"--max_sequence_length", type=int, default=2048, help="Maximum sequence length for calibration. Default: 2048"
"--max_sequence_length",
type=int,
default=2048,
help="Maximum sequence length for calibration (GPTQ only). Default: 2048",
)
parser.add_argument(
"--dampening_frac",
type=float,
default=0.1,
help="Dampening fraction to mitigate quantization noise. Default: 0.1",
help="Dampening fraction to mitigate quantization noise (GPTQ only). Default: 0.1",
)
parser.add_argument(
"--dataset",
type=str,
default="HuggingFaceH4/ultrachat_200k",
help="Dataset for calibration. Default: HuggingFaceH4/ultrachat_200k",
help="Dataset for calibration (GPTQ only). Default: HuggingFaceH4/ultrachat_200k",
)
parser.add_argument(
"--dataset_split", type=str, default="train_sft", help="Dataset split to use. Default: train_sft"
"--dataset_split", type=str, default="train_sft", help="Dataset split to use (GPTQ only). Default: train_sft"
)
parser.add_argument(
"--force_cpu", action="store_true", help="Force all computations to CPU (sets CUDA_VISIBLE_DEVICES='')"
@@ -118,15 +148,24 @@ def parse_args():
def setup_environment(force_cpu=False):
"""
Setup environment variables and warnings.
Verify environment setup (actual setup happens before torch import).
Args:
force_cpu: If True, forces all computations to CPU by hiding GPUs
force_cpu: If True, was requested to force CPU-only mode
Note:
CUDA_VISIBLE_DEVICES must be set BEFORE importing torch.
The actual environment setup is done at module import time.
"""
if force_cpu:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
warnings.filterwarnings("ignore", message="Can't initialize NVML")
print("🔧 Forced CPU-only mode")
# Verify the environment variable was set correctly
cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible != "":
print("⚠️ Warning: force_cpu was requested but CUDA_VISIBLE_DEVICES is not empty")
print(f" Current value: '{cuda_visible}'")
print(" This may happen if imported as a module. Recommend running as script.")
else:
print("✅ CPU-only mode verified (CUDA_VISIBLE_DEVICES is empty)")
def get_torch_dtype(dtype_str):
@@ -242,9 +281,21 @@ def main():
"""
Main function for GPU weight quantization.
This performs GPTQ quantization on model weights intended for GPU execution
in CPU-GPU hybrid inference scenarios. The quantization is selective:
- Expert MLP weights are quantized to INT4/INT8 (GPTQ)
This performs weight quantization on model weights intended for GPU execution
in CPU-GPU hybrid inference scenarios. Supports two quantization methods:
1. GPTQ (default): Calibration-based quantization for better accuracy
- Requires calibration dataset
- Higher accuracy but slower
- Recommended for production use
2. RTN (Round-To-Nearest): Fast quantization without calibration
- No calibration dataset needed
- Faster but may have lower accuracy
- Good for quick testing or prototyping
The quantization is selective:
- Expert MLP weights are quantized to INT4/INT8
- Attention layers, gates, and shared experts remain in original precision
- Dense layers (if present) are excluded from quantization
@@ -262,9 +313,13 @@ def main():
print(f"🚀 Starting quantization process")
print(f" Model: {args.model_id}")
print(f" Output: {args.output_dir}")
print(f" Quantization: {args.quant_type}")
print(f" Calibration samples: {args.num_calibration_samples}")
print(f" Max sequence length: {args.max_sequence_length}")
print(f" Quantization method: {args.quant_method}")
print(f" Quantization type: {args.quant_type}")
if args.quant_method == "GPTQ":
print(f" Calibration samples: {args.num_calibration_samples}")
print(f" Max sequence length: {args.max_sequence_length}")
else:
print(f" Calibration: Not required for {args.quant_method}")
    # --------------------------------------------------------------------
    # 0) Check for dense layers and update ignore patterns
@@ -361,24 +416,36 @@ def main():
    # --------------------------------------------------------------------
    # 3) Prepare calibration dataset
    # GPTQ needs calibration data to compute optimal quantization parameters
    if args.quant_method == "GPTQ":
        ds = load_and_prepare_dataset(
            args.dataset,
            args.dataset_split,
            args.num_calibration_samples,
            args.max_sequence_length,
            tokenizer,
            args.random_seed,
        )
    # --------------------------------------------------------------------
    # 4) Create quantization recipe with selective layer exclusion
    print(f"⚙️ Setting up {args.quant_method} {args.quant_type} quantization recipe...")
    if args.quant_method == "GPTQ":
        # GPTQ: calibration-based quantization for better accuracy
        recipe = GPTQModifier(
            targets="Linear",  # Target all Linear layers
            scheme=args.quant_type,  # W4A16 or W8A16
            ignore=updated_ignore_patterns,  # Exclude specific patterns
            dampening_frac=args.dampening_frac,
        )
    elif args.quant_method == "RTN":
        # RTN (Round-To-Nearest): fast quantization without calibration
        recipe = QuantizationModifier(
            targets="Linear",  # Target all Linear layers
            scheme=args.quant_type,  # W4A16 or W8A16
            ignore=updated_ignore_patterns,  # Exclude specific patterns
        )
    else:
        raise ValueError(f"Unsupported quantization method: {args.quant_method}")
print("🔧 Ignoring the following patterns from quantization:")
for i, pattern in enumerate(updated_ignore_patterns):
@@ -386,19 +453,32 @@ def main():
print(f" {marker} {pattern}")
    # --------------------------------------------------------------------
    # 5) Perform one-shot quantization
    # GPTQ: calibration-based quantization to minimize accuracy loss
    # RTN: fast round-to-nearest quantization without calibration
    print("🎯 Starting one-shot quantization...")
    if args.quant_method == "GPTQ":
        # GPTQ requires calibration dataset
        oneshot(
            model=model,
            dataset=ds,
            recipe=recipe,
            output_dir=args.output_dir,
            max_seq_length=args.max_sequence_length,
            num_calibration_samples=args.num_calibration_samples,
        )
    elif args.quant_method == "RTN":
        # RTN does not require calibration dataset
        oneshot(
            model=model,
            recipe=recipe,
            output_dir=args.output_dir,
        )
    else:
        raise ValueError(f"Unsupported quantization method: {args.quant_method}")
print(f"\n✅ Quantized model written to: {args.output_dir}")
print(f" Quantization method: {args.quant_method}")
print(f" Quantization type: {args.quant_type}")
print(f" Ignored patterns remain in {args.torch_dtype}")
print("🎉 Quantization completed successfully!")