mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-03-14 18:37:23 +00:00
fix
This commit is contained in:
@@ -84,6 +84,63 @@ pip install .
|
||||
```bash
|
||||
python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed successfully')"
|
||||
```
|
||||
|
||||
## Weight Quantization
|
||||
|
||||
KT-Kernel provides a weight conversion tool to quantize model weights from FP8/FP16/BF16 to INT4/INT8 format optimized for AMX inference.
|
||||
|
||||
### Quantization Methods
|
||||
|
||||
- **INT4**: 4-bit quantization for maximum memory efficiency
|
||||
- **INT8**: 8-bit quantization for better accuracy
|
||||
|
||||
### Supported Input Formats
|
||||
|
||||
- **FP8**: 8-bit floating point with automatic dequantization
|
||||
- **FP16**: 16-bit floating point
|
||||
- **BF16**: BFloat16 format
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Quantize BF16 model to INT4
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/bf16/model \
|
||||
--input-type bf16 \
|
||||
--output /path/to/output \
|
||||
--quant-method int4
|
||||
|
||||
# Quantize FP16 model to INT8
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/fp16/model \
|
||||
--input-type fp16 \
|
||||
--output /path/to/output \
|
||||
--quant-method int8
|
||||
|
||||
# Quantize FP8 model to INT4
|
||||
python scripts/convert_weights.py \
|
||||
--input-path /path/to/fp8/model \
|
||||
--input-type fp8 \
|
||||
--output /path/to/output \
|
||||
--quant-method int4
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
The converted weights are saved in SafeTensors format with NUMA-aware layout:
|
||||
```
|
||||
output_dir/
|
||||
├── model-00001-of-00050.safetensors
|
||||
├── model-00002-of-00050.safetensors
|
||||
├── ...
|
||||
├── config.json
|
||||
└── tokenizer files...
|
||||
```
|
||||
|
||||
Each expert's weights are split across NUMA nodes for optimal memory access:
|
||||
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights
|
||||
- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales
|
||||
|
||||
## Before Commit!
|
||||
your msg should match: Conventional Commits (https://www.conventionalcommits.org/) <br>and format your code before commit:
|
||||
```shell
|
||||
|
||||
@@ -191,7 +191,7 @@ class AMXMoEWrapper:
|
||||
moe_intermediate_size: int,
|
||||
num_gpu_experts: int,
|
||||
cpuinfer_threads: int,
|
||||
subpool_count: int,
|
||||
threadpool_count: int,
|
||||
amx_weight_path: str,
|
||||
chunked_prefill_size: int,
|
||||
cpu_save: bool = False,
|
||||
@@ -207,7 +207,7 @@ class AMXMoEWrapper:
|
||||
moe_intermediate_size: MoE intermediate size
|
||||
num_gpu_experts: Number of experts to run on GPU
|
||||
cpuinfer_threads: Number of CPU inference threads
|
||||
subpool_count: Number of NUMA subpools
|
||||
threadpool_count: Number of NUMA subpools
|
||||
amx_weight_path: Path to AMX weights
|
||||
chunked_prefill_size: Maximum prefill chunk size
|
||||
cpu_save: Whether to save weights to CPU memory
|
||||
@@ -227,13 +227,13 @@ class AMXMoEWrapper:
|
||||
if AMXMoEWrapper._cpu_infer_instance is None:
|
||||
worker_config = cpuinfer_ext.WorkerPoolConfig()
|
||||
|
||||
subpool_numa_map = list(range(subpool_count))
|
||||
subpool_numa_map = list(range(threadpool_count))
|
||||
subpool_thread_count = [
|
||||
cpuinfer_threads // subpool_count + (1 if i < cpuinfer_threads % subpool_count else 0)
|
||||
for i in range(subpool_count)
|
||||
cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
|
||||
for i in range(threadpool_count)
|
||||
]
|
||||
|
||||
worker_config.subpool_count = subpool_count
|
||||
worker_config.subpool_count = threadpool_count
|
||||
worker_config.subpool_numa_map = subpool_numa_map
|
||||
worker_config.subpool_thread_count = subpool_thread_count
|
||||
AMXMoEWrapper._cpu_infer_instance = cpuinfer_ext.CPUInfer(worker_config)
|
||||
@@ -261,6 +261,64 @@ class AMXMoEWrapper:
|
||||
self.up_scales = None
|
||||
self.down_scales = None
|
||||
|
||||
def load_weights_from_tensors(
|
||||
self,
|
||||
gate_proj: torch.Tensor,
|
||||
up_proj: torch.Tensor,
|
||||
down_proj: torch.Tensor,
|
||||
physical_to_logical_map_cpu: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Load and quantize weights from BF16/FP16 tensors (online quantization).
|
||||
|
||||
Args:
|
||||
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
|
||||
up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
|
||||
down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
|
||||
physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
|
||||
"""
|
||||
# Store tensors as instance variables to keep them alive
|
||||
self.gate_proj = gate_proj.contiguous()
|
||||
self.up_proj = up_proj.contiguous()
|
||||
self.down_proj = down_proj.contiguous()
|
||||
|
||||
# Configure MoE with online quantization (cpu_save mode)
|
||||
moe_config = MOEConfig(
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
self.hidden_size,
|
||||
self.moe_intermediate_size,
|
||||
self.num_gpu_experts,
|
||||
)
|
||||
moe_config.layer_idx = self.layer_idx
|
||||
moe_config.pool = self.cpu_infer.backend_
|
||||
moe_config.max_len = self.chunked_prefill_size
|
||||
|
||||
# Enable save mode for online quantization
|
||||
moe_config.save = True
|
||||
moe_config.load = False
|
||||
|
||||
# Set weight pointers
|
||||
moe_config.gate_proj = self.gate_proj.data_ptr()
|
||||
moe_config.up_proj = self.up_proj.data_ptr()
|
||||
moe_config.down_proj = self.down_proj.data_ptr()
|
||||
|
||||
# Set output path for quantized weights
|
||||
moe_config.path = self.amx_weight_path
|
||||
|
||||
# Create MoE module based on AMX method
|
||||
amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
|
||||
if amx_method == "AMXINT4":
|
||||
self.moe = AMXInt4_MOE(moe_config)
|
||||
elif amx_method == "AMXINT8":
|
||||
self.moe = AMXInt8_MOE(moe_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported AMX method: {amx_method}")
|
||||
|
||||
# Submit quantization and save task
|
||||
self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))
|
||||
self.cpu_infer.sync()
|
||||
|
||||
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
|
||||
"""
|
||||
Load weights for this layer and initialize the MoE module.
|
||||
|
||||
@@ -19,6 +19,15 @@ from safetensors.torch import save_file
|
||||
from compressed_tensors.compressors import pack_to_int32, unpack_from_int32
|
||||
import gc
|
||||
import time
|
||||
import json
|
||||
import sys
|
||||
import glob
|
||||
import numpy as np
|
||||
|
||||
# Add parent directory to path to import kt_kernel
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from kt_kernel import AMXMoEWrapper
|
||||
|
||||
import cpuinfer_ext
|
||||
|
||||
|
||||
@@ -31,6 +40,66 @@ REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def load_model_config(input_path: str, input_type: str = None) -> Dict:
|
||||
"""Load model configuration from config.json
|
||||
|
||||
Args:
|
||||
input_path: Path to directory containing config.json
|
||||
input_type: Input weight type (fp8/fp16/bf16/awq), used to validate FP8 config
|
||||
|
||||
Returns:
|
||||
Dictionary with model configuration
|
||||
"""
|
||||
config_path = os.path.join(input_path, "config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(f"config.json not found in {input_path}")
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Extract required fields with fallbacks
|
||||
model_config = {
|
||||
"num_experts": config.get("n_routed_experts", config.get("num_experts")),
|
||||
"num_experts_per_tok": config.get("num_experts_per_tok", 2),
|
||||
"hidden_size": config.get("hidden_size"),
|
||||
"moe_intermediate_size": config.get("moe_intermediate_size", config.get("intermediate_size")),
|
||||
}
|
||||
|
||||
# Validate required fields
|
||||
missing_fields = [k for k, v in model_config.items() if v is None]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required config fields: {missing_fields}")
|
||||
|
||||
# For FP8 input, extract and validate quantization_config
|
||||
if input_type == "fp8":
|
||||
quant_config = config.get("quantization_config")
|
||||
if quant_config is None:
|
||||
raise ValueError(
|
||||
"FP8 input type specified but 'quantization_config' not found in config.json. "
|
||||
"Expected quantization_config with weight_block_size field."
|
||||
)
|
||||
|
||||
weight_block_size = quant_config.get("weight_block_size")
|
||||
if weight_block_size is None:
|
||||
raise ValueError(
|
||||
"FP8 quantization_config found but 'weight_block_size' field is missing. "
|
||||
"Expected format: 'weight_block_size': [128, 128]"
|
||||
)
|
||||
|
||||
if not isinstance(weight_block_size, list) or len(weight_block_size) != 2:
|
||||
raise ValueError(
|
||||
f"Invalid weight_block_size format: {weight_block_size}. "
|
||||
"Expected a list of two integers, e.g., [128, 128]"
|
||||
)
|
||||
|
||||
model_config["fp8_weight_block_size"] = weight_block_size
|
||||
print(f"FP8 quantization config detected:")
|
||||
print(f" format: {quant_config.get('fmt', 'unknown')}")
|
||||
print(f" weight_block_size: {weight_block_size}")
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
def pack(imatrix: torch.Tensor):
|
||||
"""
|
||||
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
|
||||
@@ -150,44 +219,32 @@ class ConverterBase:
|
||||
tensor transformation for a given quantization method (e.g., awq, int4, int8).
|
||||
"""
|
||||
|
||||
def __init__(self, input_path: str, output_path: str, bf16_path: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model_config: Dict,
|
||||
cpuinfer_threads: int = 60,
|
||||
subpool_count: int = 2,
|
||||
input_type: str = None,
|
||||
):
|
||||
self.input_path = input_path
|
||||
self.output_path = output_path
|
||||
self.bf16_path = bf16_path
|
||||
self.tensor_file_map: Dict[str, str] = {}
|
||||
self.file_handle_map: Dict[str, any] = {}
|
||||
self.model_config = model_config
|
||||
self.cpuinfer_threads = cpuinfer_threads
|
||||
self.subpool_count = subpool_count
|
||||
self.input_type = input_type
|
||||
self.tensor_file_map: Dict[str, str] = {} # key -> filename
|
||||
self.file_handle_map: Dict[str, any] = {} # filename -> file
|
||||
|
||||
# Extract commonly used config values for convenience
|
||||
self.num_experts = model_config["num_experts"]
|
||||
self.num_experts_per_tok = model_config["num_experts_per_tok"]
|
||||
self.hidden_size = model_config["hidden_size"]
|
||||
self.moe_intermediate_size = model_config["moe_intermediate_size"]
|
||||
|
||||
# Load input safetensors files
|
||||
self._load_input_files()
|
||||
if bf16_path:
|
||||
self.tensor_file_map_bf16: Dict[str, str] = {}
|
||||
self.file_handle_map_bf16: Dict[str, any] = {}
|
||||
self._load_bf16_files()
|
||||
|
||||
def _load_bf16_files(self):
|
||||
"""Load all bf16 safetensors files from bf16 directory"""
|
||||
print(f"Loading bf16 safetensors files from {self.bf16_path}")
|
||||
|
||||
found_safetensor = False
|
||||
for root, _, files in os.walk(self.bf16_path):
|
||||
files = sorted(files)
|
||||
for file in files:
|
||||
if file.endswith(".safetensors"):
|
||||
found_safetensor = True
|
||||
file_path = os.path.join(root, file)
|
||||
try:
|
||||
handle = safe_open(file_path, framework="pt")
|
||||
self.file_handle_map_bf16[file] = handle
|
||||
for key in handle.keys():
|
||||
self.tensor_file_map_bf16[key] = file
|
||||
print(f" Loaded: {file} ({len(list(handle.keys()))} tensors)")
|
||||
except Exception as e:
|
||||
print(f" Error loading {file}: {e}")
|
||||
|
||||
if not found_safetensor:
|
||||
raise FileNotFoundError(f"No safetensors files found in {self.bf16_path}")
|
||||
|
||||
print(f"Total tensors loaded: {len(self.tensor_file_map)}")
|
||||
|
||||
def _load_input_files(self):
|
||||
"""Load all safetensors files from input directory"""
|
||||
@@ -223,15 +280,8 @@ class ConverterBase:
|
||||
handle = self.file_handle_map[file]
|
||||
return handle.get_tensor(key)
|
||||
|
||||
def _load_tensor_bf16(self, key: str) -> torch.Tensor:
|
||||
"""Load tensor by key"""
|
||||
if key not in self.tensor_file_map_bf16:
|
||||
raise KeyError(f"Key {key} not found")
|
||||
|
||||
file = self.tensor_file_map_bf16[key]
|
||||
handle = self.file_handle_map_bf16[file]
|
||||
return handle.get_tensor(key)
|
||||
|
||||
# layers_id -> list[experts_id]
|
||||
def _find_expert_layers(self) -> Dict[int, List[int]]:
|
||||
"""Find all layers and experts in the model"""
|
||||
layers = defaultdict(set)
|
||||
@@ -316,7 +366,7 @@ class ConverterBase:
|
||||
print(f"Saving {len(all_tensors)} tensors...")
|
||||
|
||||
# Split into multiple files if too large
|
||||
max_tensors_per_file = 2000 # Adjust based on memory constraints
|
||||
max_tensors_per_file = 3000 # Adjust based on memory constraints
|
||||
tensor_items = list(all_tensors.items())
|
||||
|
||||
if len(tensor_items) <= max_tensors_per_file:
|
||||
@@ -430,72 +480,328 @@ class AWQToColumnMajorConverter(ConverterBase):
|
||||
return output_tensors
|
||||
|
||||
|
||||
class Int4ToColumnMajorConverter(ConverterBase):
|
||||
"""Convert raw INT4 safetensors to NUMA-sliced column-major format.
|
||||
class OnlineQuantConverter(ConverterBase):
|
||||
"""Convert FP8/FP16/BF16 weights to quantized format using AMXMoEWrapper.
|
||||
|
||||
NOTE: Implement `_convert_layer_experts` with the correct INT4 packing rules.
|
||||
Performs online quantization (FP8/FP16/BF16 -> INT4/INT8) using AMXMoEWrapper
|
||||
with NUMA-aware memory management and automatic weight saving.
|
||||
"""
|
||||
|
||||
def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert all experts in a layer to our numa int4 format"""
|
||||
output_tensors = {}
|
||||
def __init__(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
model_config: Dict,
|
||||
cpuinfer_threads: int = 60,
|
||||
subpool_count: int = 2,
|
||||
input_type: str = None,
|
||||
quant_method: str = "int4",
|
||||
):
|
||||
super().__init__(input_path, output_path, model_config, cpuinfer_threads, subpool_count, input_type)
|
||||
self.quant_method = quant_method
|
||||
|
||||
# For FP8, get block size from model_config
|
||||
if input_type == "fp8":
|
||||
self.fp8_block_size = model_config.get("fp8_weight_block_size", [128, 128])
|
||||
else:
|
||||
self.fp8_block_size = None
|
||||
|
||||
def _dequantize_fp8_blockwise(self, fp8_weight: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
|
||||
"""Dequantize FP8 weight with block-wise scaling.
|
||||
|
||||
Args:
|
||||
fp8_weight: FP8 weight tensor of shape [H, W]
|
||||
scale_inv: Scale inverse tensor of shape [H//block_size, W//block_size]
|
||||
|
||||
Returns:
|
||||
Dequantized BF16 weight tensor of shape [H, W]
|
||||
"""
|
||||
H, W = fp8_weight.shape
|
||||
num_blocks_h, num_blocks_w = scale_inv.shape
|
||||
|
||||
# Infer block size from shapes
|
||||
block_h = H // num_blocks_h
|
||||
block_w = W // num_blocks_w
|
||||
|
||||
# Reshape fp8_weight to [num_blocks_h, block_h, num_blocks_w, block_w]
|
||||
fp8_reshaped = fp8_weight.view(num_blocks_h, block_h, num_blocks_w, block_w)
|
||||
|
||||
# Reshape scale_inv to [num_blocks_h, 1, num_blocks_w, 1] for broadcasting
|
||||
scale_inv_reshaped = scale_inv.view(num_blocks_h, 1, num_blocks_w, 1)
|
||||
|
||||
# Dequantize: convert to bf16 and multiply by scale_inv
|
||||
dequantized = fp8_reshaped.to(torch.bfloat16) * scale_inv_reshaped
|
||||
|
||||
# Reshape back to [H, W]
|
||||
dequantized = dequantized.view(H, W).contiguous()
|
||||
|
||||
return dequantized
|
||||
|
||||
def _load_binary_tensor(self, file_path: str) -> torch.Tensor:
|
||||
"""Load .kt format binary tensor file
|
||||
|
||||
Args:
|
||||
file_path: Path to .kt binary file
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Loaded tensor
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
binary_data = f.read()
|
||||
|
||||
# Determine dtype based on file name
|
||||
if 'scale' in file_path:
|
||||
# Scale tensors are typically float32
|
||||
np_array = np.frombuffer(binary_data, dtype=np.float32)
|
||||
else:
|
||||
# Quant tensors are typically int8
|
||||
np_array = np.frombuffer(binary_data, dtype=np.int8)
|
||||
|
||||
tensor = torch.from_numpy(np_array.copy())
|
||||
return tensor
|
||||
|
||||
def _load_layer_tensors_from_disk(self, layer_idx: int) -> Dict[str, torch.Tensor]:
|
||||
"""Load all quantized tensors from _layer_{layer_idx} folder
|
||||
|
||||
Args:
|
||||
layer_idx: Layer index
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: Dictionary with keys in format:
|
||||
'blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.{weight|scale}'
|
||||
"""
|
||||
layer_path = os.path.join(self.output_path, f"_layer_{layer_idx}")
|
||||
if not os.path.exists(layer_path):
|
||||
raise FileNotFoundError(f"Layer folder not found: {layer_path}")
|
||||
|
||||
tensors = {}
|
||||
|
||||
# Get AMX method from quant_method parameter (INT4/INT8)
|
||||
# Map quant_method to AMX_METHOD format
|
||||
quant_to_amx_map = {
|
||||
"int4": "INT4",
|
||||
"int8": "INT8",
|
||||
}
|
||||
amx_method = quant_to_amx_map.get(self.quant_method, "INT4")
|
||||
|
||||
# Iterate through all NUMA folders
|
||||
for numa_idx in range(self.subpool_count):
|
||||
numa_folder = os.path.join(layer_path, f"_numa_{numa_idx}")
|
||||
if not os.path.exists(numa_folder):
|
||||
continue
|
||||
|
||||
# Iterate through all experts
|
||||
for expert_id in range(self.num_experts):
|
||||
# For each projection (down, gate, up)
|
||||
proj_mappings = [
|
||||
('down', 'ffn_down_exps'),
|
||||
('gate', 'ffn_gate_exps'),
|
||||
('up', 'ffn_up_exps')
|
||||
]
|
||||
|
||||
for proj_name, proj_key in proj_mappings:
|
||||
# Build file patterns
|
||||
quant_pattern = os.path.join(
|
||||
numa_folder,
|
||||
f'{amx_method}_{proj_name}_{expert_id}_*Byte_quant_.kt'
|
||||
)
|
||||
scale_pattern = os.path.join(
|
||||
numa_folder,
|
||||
f'{amx_method}_{proj_name}_{expert_id}_*Byte_scale_.kt'
|
||||
)
|
||||
|
||||
# Find files using glob
|
||||
quant_files = glob.glob(quant_pattern)
|
||||
scale_files = glob.glob(scale_pattern)
|
||||
|
||||
# Build keys (following merge_small_tensor.py format)
|
||||
weight_key = f"blk.{layer_idx}.{proj_key}.{expert_id}.numa.{numa_idx}.weight"
|
||||
scale_key = f"blk.{layer_idx}.{proj_key}.{expert_id}.numa.{numa_idx}.scale"
|
||||
|
||||
# Load quant tensor
|
||||
if quant_files:
|
||||
if len(quant_files) > 1:
|
||||
raise ValueError(f"Multiple quant files found: {quant_files}")
|
||||
tensors[weight_key] = self._load_binary_tensor(quant_files[0])
|
||||
|
||||
# Load scale tensor
|
||||
if scale_files:
|
||||
if len(scale_files) > 1:
|
||||
raise ValueError(f"Multiple scale files found: {scale_files}")
|
||||
tensors[scale_key] = self._load_binary_tensor(scale_files[0])
|
||||
|
||||
return tensors
|
||||
|
||||
def _remove_layer_folder(self, layer_idx: int):
|
||||
"""Remove _layer_{layer_idx} folder and all its contents
|
||||
|
||||
Args:
|
||||
layer_idx: Layer index
|
||||
"""
|
||||
import shutil
|
||||
|
||||
layer_path = os.path.join(self.output_path, f"_layer_{layer_idx}")
|
||||
if os.path.exists(layer_path):
|
||||
shutil.rmtree(layer_path)
|
||||
print(f" Removed temporary folder: {layer_path}")
|
||||
|
||||
def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert all experts in a layer using online quantization via AMXMoEWrapper"""
|
||||
start_time = time.time()
|
||||
print(f"Converting layer {layer_idx} with {len(expert_ids)} experts...")
|
||||
print(f"Converting layer {layer_idx} with {len(expert_ids)} experts via online quantization...")
|
||||
|
||||
# Load all expert weights for this layer
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
|
||||
# Pre-compute projection name mappings
|
||||
proj_mappings = {"up_proj": "ffn_up_exps", "gate_proj": "ffn_gate_exps", "down_proj": "ffn_down_exps"}
|
||||
for expert_id in expert_ids:
|
||||
# Load expert's all tensors for this projection at once
|
||||
# up_expert_weights_out = torch.tensor()
|
||||
# gate_expert_weights_out = torch.tensor()
|
||||
# down_expert_weights_out = torch.tensor()
|
||||
for proj_name, out_proj in proj_mappings.items():
|
||||
weight_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj_name}.weight"
|
||||
gate_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight"
|
||||
up_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight"
|
||||
down_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight"
|
||||
|
||||
if weight_key in self.tensor_file_map_bf16:
|
||||
weight = self._load_tensor_bf16(weight_key)
|
||||
if proj_name == "up_proj":
|
||||
up_expert_weights = weight
|
||||
up_output_tensor = torch.empty(weight.numel(), dtype=torch.uint8).continuous()
|
||||
output_tensors[f"blk.{layer_idx}.{out_proj}.{expert_id}.weight"] = up_output_tensor
|
||||
elif proj_name == "gate_proj":
|
||||
gate_expert_weights = weight
|
||||
gate_output_tensor = torch.empty(weight.numel(), dtype=torch.uint8).continuous()
|
||||
output_tensors[f"blk.{layer_idx}.{out_proj}.{expert_id}.weight"] = gate_output_tensor
|
||||
elif proj_name == "down_proj":
|
||||
down_expert_weights = weight
|
||||
down_output_tensor = torch.empty(weight.numel(), dtype=torch.uint8).continuous()
|
||||
output_tensors[f"blk.{layer_idx}.{out_proj}.{expert_id}.weight"] = down_output_tensor
|
||||
if gate_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing gate weight for layer {layer_idx}, expert {expert_id}")
|
||||
if up_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing up weight for layer {layer_idx}, expert {expert_id}")
|
||||
if down_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing down weight for layer {layer_idx}, expert {expert_id}")
|
||||
|
||||
# call c++ api to get qweights and scales
|
||||
# Load weights based on input type
|
||||
if self.input_type == "fp8":
|
||||
# Load FP8 weights and their scale_inv tensors
|
||||
gate_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv"
|
||||
up_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv"
|
||||
down_scale_key = f"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv"
|
||||
|
||||
if gate_scale_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}")
|
||||
if up_scale_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}")
|
||||
if down_scale_key not in self.tensor_file_map:
|
||||
raise KeyError(f"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}")
|
||||
|
||||
# Load FP8 weights and scales
|
||||
gate_fp8 = self._load_tensor(gate_key)
|
||||
up_fp8 = self._load_tensor(up_key)
|
||||
down_fp8 = self._load_tensor(down_key)
|
||||
|
||||
gate_scale_inv = self._load_tensor(gate_scale_key)
|
||||
up_scale_inv = self._load_tensor(up_scale_key)
|
||||
down_scale_inv = self._load_tensor(down_scale_key)
|
||||
|
||||
# Dequantize FP8 to BF16 using block-wise scaling
|
||||
gate_weight = self._dequantize_fp8_blockwise(gate_fp8, gate_scale_inv)
|
||||
up_weight = self._dequantize_fp8_blockwise(up_fp8, up_scale_inv)
|
||||
down_weight = self._dequantize_fp8_blockwise(down_fp8, down_scale_inv)
|
||||
|
||||
elif self.input_type == "fp16":
|
||||
# Load FP16 and convert to BF16
|
||||
gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)
|
||||
up_weight = self._load_tensor(up_key).to(torch.bfloat16)
|
||||
down_weight = self._load_tensor(down_key).to(torch.bfloat16)
|
||||
|
||||
elif self.input_type == "bf16":
|
||||
# Load BF16 directly
|
||||
gate_weight = self._load_tensor(gate_key)
|
||||
up_weight = self._load_tensor(up_key)
|
||||
down_weight = self._load_tensor(down_key)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")
|
||||
|
||||
gate_weights.append(gate_weight)
|
||||
up_weights.append(up_weight)
|
||||
down_weights.append(down_weight)
|
||||
|
||||
# Stack weights into single tensors: [num_experts, ...]
|
||||
gate_proj = torch.stack(gate_weights, dim=0).contiguous()
|
||||
up_proj = torch.stack(up_weights, dim=0).contiguous()
|
||||
down_proj = torch.stack(down_weights, dim=0).contiguous()
|
||||
|
||||
print(f" Loaded weights shapes:")
|
||||
print(f" gate_proj: {gate_proj.shape}")
|
||||
print(f" up_proj: {up_proj.shape}")
|
||||
print(f" down_proj: {down_proj.shape}")
|
||||
|
||||
# Create physical_to_logical_map: identity mapping where position i maps to expert i
|
||||
physical_to_logical_map = torch.arange(self.num_experts, dtype=torch.int64)
|
||||
|
||||
# Create AMXMoEWrapper instance for this layer
|
||||
# num_gpu_experts=0 since we're converting all experts to CPU format
|
||||
wrapper = AMXMoEWrapper(
|
||||
layer_idx=layer_idx,
|
||||
num_experts=self.num_experts,
|
||||
num_experts_per_tok=self.num_experts_per_tok,
|
||||
hidden_size=self.hidden_size,
|
||||
moe_intermediate_size=self.moe_intermediate_size,
|
||||
num_gpu_experts=0, # All experts on CPU for conversion
|
||||
cpuinfer_threads=self.cpuinfer_threads,
|
||||
subpool_count=self.subpool_count,
|
||||
amx_weight_path=self.output_path, # Output path for quantized weights
|
||||
chunked_prefill_size=512, # Arbitrary value, not critical for conversion
|
||||
cpu_save=True, # Enable saving quantized weights to output
|
||||
)
|
||||
|
||||
# Load and quantize weights from tensors
|
||||
# This triggers the quantization process and saves to disk
|
||||
wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)
|
||||
|
||||
# Clean up to free memory
|
||||
del gate_weights, up_weights, down_weights
|
||||
del gate_proj, up_proj, down_proj
|
||||
gc.collect()
|
||||
|
||||
# Load quantized tensors from disk
|
||||
print(f" Loading quantized tensors from disk...")
|
||||
layer_tensors = self._load_layer_tensors_from_disk(layer_idx)
|
||||
print(f" Loaded {len(layer_tensors)} tensors")
|
||||
|
||||
# Remove temporary layer folder
|
||||
self._remove_layer_folder(layer_idx)
|
||||
|
||||
gc.collect()
|
||||
elapsed = time.time() - start_time
|
||||
print(f" Generated {len(output_tensors)} tensors in {elapsed:.2f}s")
|
||||
return output_tensors
|
||||
print(f" Layer {layer_idx} quantized and saved in {elapsed:.2f}s")
|
||||
|
||||
# Return loaded tensors
|
||||
return layer_tensors
|
||||
|
||||
class Int8ToColumnMajorConverter(ConverterBase):
|
||||
"""Convert raw INT8 safetensors to NUMA-sliced column-major format.
|
||||
|
||||
NOTE: Implement `_convert_layer_experts` with the correct INT8 transformation rules.
|
||||
"""
|
||||
|
||||
def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError("INT8 converter not implemented yet. Please implement transformation logic.")
|
||||
|
||||
"""
|
||||
Example usage(test passed):
|
||||
python convert_weights.py --input-path /mnt/data3/models/DeepSeek-V3.1 --input-type fp8 --output /mnt/data3/models/DeepSeek-V3.1-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
|
||||
python convert_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --subpool-count 2
|
||||
"""
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert AWQ SafeTensors to column major 1D format")
|
||||
parser.add_argument("--input", "-i", required=True, help="Input directory with raw AWQ safetensors")
|
||||
parser.add_argument("--bf16_path", help="Path to bf16 weights if needed for mixed precision(amx for int4&int8)")
|
||||
parser.add_argument("--output", "-o", required=True, help="Output directory for hybrid safetensors")
|
||||
parser = argparse.ArgumentParser(description="Convert SafeTensors to column major 1D format")
|
||||
parser.add_argument("--input-path", "-i", required=True, help="Input directory with safetensors")
|
||||
parser.add_argument(
|
||||
"--quant_method",
|
||||
"--input-type",
|
||||
choices=["awq", "fp8", "fp16", "bf16"],
|
||||
required=True,
|
||||
help="Input weight type (awq/fp8/fp16/bf16)",
|
||||
)
|
||||
parser.add_argument("--output", "-o", required=True, help="Output directory for converted safetensors")
|
||||
parser.add_argument(
|
||||
"--quant-method",
|
||||
choices=["int4", "int8", "awq"],
|
||||
default="int4",
|
||||
help="Quantization method used in input (default: int4)",
|
||||
help="Quantization method for output (default: int4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpuinfer-threads",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of CPU inference threads (default: 60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subpool-count",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of NUMA subpools for thread distribution (default: 2)",
|
||||
)
|
||||
parser.add_argument("--gpu", action="store_true", help="Use GPU for conversion if available")
|
||||
|
||||
@@ -503,21 +809,38 @@ def main():
|
||||
device = torch.device("cuda" if args.gpu and torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.input):
|
||||
print(f"Error: Input path does not exist: {args.input}")
|
||||
if not os.path.exists(args.input_path):
|
||||
print(f"Error: Input path does not exist: {args.input_path}")
|
||||
return 1
|
||||
try:
|
||||
# Load model configuration from config.json
|
||||
print("Loading model configuration...")
|
||||
model_config = load_model_config(args.input_path, args.input_type)
|
||||
print(f"Model config: {model_config}")
|
||||
print(f" num_experts: {model_config['num_experts']}")
|
||||
print(f" num_experts_per_tok: {model_config['num_experts_per_tok']}")
|
||||
print(f" hidden_size: {model_config['hidden_size']}")
|
||||
print(f" moe_intermediate_size: {model_config['moe_intermediate_size']}")
|
||||
print(f"CPU inference config:")
|
||||
print(f" cpuinfer_threads: {args.cpuinfer_threads}")
|
||||
print(f" subpool_count: {args.subpool_count}")
|
||||
print()
|
||||
|
||||
# Create converter by quantization method
|
||||
quant_method = args.quant_method.lower()
|
||||
if quant_method == "awq":
|
||||
converter = AWQToColumnMajorConverter(args.input, args.output)
|
||||
elif quant_method == "int4" and args.bf16_path:
|
||||
converter = Int4ToColumnMajorConverter(args.input, args.output, args.bf16_path)
|
||||
elif quant_method == "int8" and args.bf16_path:
|
||||
converter = Int8ToColumnMajorConverter(args.input, args.output, args.bf16_path)
|
||||
converter = AWQToColumnMajorConverter(
|
||||
args.input_path, args.output, model_config, args.cpuinfer_threads, args.subpool_count
|
||||
)
|
||||
elif quant_method in ["int4", "int8"] and args.input_type in ["fp8", "fp16", "bf16"]:
|
||||
# Use OnlineQuantConverter for both INT4 and INT8 quantization
|
||||
converter = OnlineQuantConverter(
|
||||
args.input_path, args.output, model_config, args.cpuinfer_threads, args.subpool_count, args.input_type, quant_method
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quant_method: {args.quant_method} or missing bf16_path for int4/int8")
|
||||
raise ValueError(
|
||||
f"Unsupported quant_method: {args.quant_method} or incompatible input_type: {args.input_type}"
|
||||
)
|
||||
|
||||
# Run conversion
|
||||
converter.convert()
|
||||
|
||||
Reference in New Issue
Block a user