diff --git a/kt-kernel/README.md b/kt-kernel/README.md index 2fdc0bd..664714b 100644 --- a/kt-kernel/README.md +++ b/kt-kernel/README.md @@ -84,6 +84,63 @@ pip install . ```bash python -c "from kt_kernel import AMXMoEWrapper; print('✓ kt-kernel installed successfully')" ``` + +## Weight Quantization + +KT-Kernel provides a weight conversion tool to quantize model weights from FP8/FP16/BF16 to INT4/INT8 format optimized for AMX inference. + +### Quantization Methods + +- **INT4**: 4-bit quantization for maximum memory efficiency +- **INT8**: 8-bit quantization for better accuracy + +### Supported Input Formats + +- **FP8**: 8-bit floating point with automatic dequantization +- **FP16**: 16-bit floating point +- **BF16**: BFloat16 format + +### Basic Usage + +```bash +# Quantize BF16 model to INT4 +python scripts/convert_weights.py \ + --input-path /path/to/bf16/model \ + --input-type bf16 \ + --output /path/to/output \ + --quant-method int4 + +# Quantize FP16 model to INT8 +python scripts/convert_weights.py \ + --input-path /path/to/fp16/model \ + --input-type fp16 \ + --output /path/to/output \ + --quant-method int8 + +# Quantize FP8 model to INT4 +python scripts/convert_weights.py \ + --input-path /path/to/fp8/model \ + --input-type fp8 \ + --output /path/to/output \ + --quant-method int4 +``` + +### Output Format + +The converted weights are saved in SafeTensors format with NUMA-aware layout: +``` +output_dir/ +├── model-00001-of-00050.safetensors +├── model-00002-of-00050.safetensors +├── ... +├── config.json +└── tokenizer files... +``` + +Each expert's weights are split across NUMA nodes for optimal memory access: +- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights +- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales + ## Before Commit! your msg should match: Conventional Commits (https://www.conventionalcommits.org/)
def load_weights_from_tensors(
    self,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    physical_to_logical_map_cpu: torch.Tensor,
):
    """
    Load and quantize weights from BF16/FP16 tensors (online quantization).

    The projection tensors are handed to the native MoE module as raw data
    pointers, so their device and shape are validated here first; a mismatch
    would otherwise surface as an out-of-bounds read in native code instead
    of a Python exception.

    Args:
        gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
        up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
        down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs

    Raises:
        ValueError: If a projection tensor is not on the CPU or its shape does
            not match this wrapper's configured dimensions.
        NotImplementedError: If the AMX_METHOD environment variable is neither
            "AMXINT4" nor "AMXINT8".
    """
    gate_up_shape = (self.num_experts, self.moe_intermediate_size, self.hidden_size)
    down_shape = (self.num_experts, self.hidden_size, self.moe_intermediate_size)
    checks = (
        ("gate_proj", gate_proj, gate_up_shape),
        ("up_proj", up_proj, gate_up_shape),
        ("down_proj", down_proj, down_shape),
    )
    for name, tensor, expected_shape in checks:
        if tensor.device.type != "cpu":
            raise ValueError(f"{name} must be a CPU tensor, got device {tensor.device}")
        if tuple(tensor.shape) != expected_shape:
            raise ValueError(f"{name} has shape {tuple(tensor.shape)}, expected {expected_shape}")
    # NOTE(review): the element type the native kernels expect is not visible
    # here (the converter always passes BF16) - confirm before passing FP16.

    # Store tensors as instance variables to keep them alive
    self.gate_proj = gate_proj.contiguous()
    self.up_proj = up_proj.contiguous()
    self.down_proj = down_proj.contiguous()

    # data_ptr() is only meaningful for a densely laid-out tensor; the local
    # reference keeps the (possibly copied) map alive until sync() returns.
    expert_map = physical_to_logical_map_cpu.contiguous()

    # Configure MoE with online quantization (cpu_save mode)
    moe_config = MOEConfig(
        self.num_experts,
        self.num_experts_per_tok,
        self.hidden_size,
        self.moe_intermediate_size,
        self.num_gpu_experts,
    )
    moe_config.layer_idx = self.layer_idx
    moe_config.pool = self.cpu_infer.backend_
    moe_config.max_len = self.chunked_prefill_size

    # Enable save mode for online quantization
    moe_config.save = True
    moe_config.load = False

    # Set weight pointers
    moe_config.gate_proj = self.gate_proj.data_ptr()
    moe_config.up_proj = self.up_proj.data_ptr()
    moe_config.down_proj = self.down_proj.data_ptr()

    # Set output path for quantized weights
    moe_config.path = self.amx_weight_path

    # Create MoE module based on AMX method
    amx_method = os.environ.get("AMX_METHOD", "AMXINT4")
    if amx_method == "AMXINT4":
        self.moe = AMXInt4_MOE(moe_config)
    elif amx_method == "AMXINT8":
        self.moe = AMXInt8_MOE(moe_config)
    else:
        raise NotImplementedError(f"Unsupported AMX method: {amx_method}")

    # Submit quantization and save task, then block until it has finished
    self.cpu_infer.submit(self.moe.load_weights_task(expert_map.data_ptr()))
    self.cpu_infer.sync()
def load_model_config(input_path: str, input_type: str = None) -> Dict:
    """Load the MoE-related model configuration from config.json.

    Args:
        input_path: Path to directory containing config.json
        input_type: Input weight type (fp8/fp16/bf16/awq); "fp8" additionally
            requires a valid ``quantization_config.weight_block_size``

    Returns:
        Dictionary with ``num_experts``, ``num_experts_per_tok``,
        ``hidden_size`` and ``moe_intermediate_size``; for FP8 input also
        ``fp8_weight_block_size``.

    Raises:
        FileNotFoundError: If config.json does not exist in ``input_path``.
        ValueError: If a required field is missing, or the FP8 quantization
            config is absent or malformed.
    """
    config_path = os.path.join(input_path, "config.json")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config.json not found in {input_path}")

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    def first_present(*keys):
        # An explicit JSON null counts as "absent" so the fallback key is
        # still consulted.
        for key in keys:
            value = config.get(key)
            if value is not None:
                return value
        return None

    # Extract required fields with fallbacks
    model_config = {
        "num_experts": first_present("n_routed_experts", "num_experts"),
        "num_experts_per_tok": config.get("num_experts_per_tok", 2),
        "hidden_size": config.get("hidden_size"),
        "moe_intermediate_size": first_present("moe_intermediate_size", "intermediate_size"),
    }

    # Validate required fields
    missing_fields = [k for k, v in model_config.items() if v is None]
    if missing_fields:
        raise ValueError(f"Missing required config fields: {missing_fields}")

    # Only FP8 input needs the block-wise quantization metadata.
    if input_type != "fp8":
        return model_config

    quant_config = config.get("quantization_config")
    if quant_config is None:
        raise ValueError(
            "FP8 input type specified but 'quantization_config' not found in config.json. "
            "Expected quantization_config with weight_block_size field."
        )

    weight_block_size = quant_config.get("weight_block_size")
    if weight_block_size is None:
        raise ValueError(
            "FP8 quantization_config found but 'weight_block_size' field is missing. "
            "Expected format: 'weight_block_size': [128, 128]"
        )

    # bool is a subclass of int, so it is excluded explicitly.
    well_formed = (
        isinstance(weight_block_size, list)
        and len(weight_block_size) == 2
        and all(isinstance(dim, int) and not isinstance(dim, bool) and dim > 0 for dim in weight_block_size)
    )
    if not well_formed:
        raise ValueError(
            f"Invalid weight_block_size format: {weight_block_size}. "
            "Expected a list of two integers, e.g., [128, 128]"
        )

    model_config["fp8_weight_block_size"] = weight_block_size
    print("FP8 quantization config detected:")
    print(f" format: {quant_config.get('fmt', 'unknown')}")
    print(f" weight_block_size: {weight_block_size}")

    return model_config
def __init__(
    self,
    input_path: str,
    output_path: str,
    model_config: Dict,
    cpuinfer_threads: int = 60,
    subpool_count: int = 2,
    input_type: str = None,
):
    """Store the conversion settings and index the input safetensors files.

    Args:
        input_path: Directory containing the input safetensors files
        output_path: Directory the converted weights are written to
        model_config: Model configuration; must contain ``num_experts``,
            ``num_experts_per_tok``, ``hidden_size`` and ``moe_intermediate_size``
        cpuinfer_threads: Number of CPU inference threads
        subpool_count: Number of NUMA subpools the threads are split across
        input_type: Input weight type, or None when the converter does not need it

    Raises:
        ValueError: If ``cpuinfer_threads`` or ``subpool_count`` is not positive.
        KeyError: If ``model_config`` lacks one of the required fields.
    """
    # Fail before opening any input file: the per-subpool thread split
    # divides by the subpool count, so a non-positive value would otherwise
    # only surface later as a ZeroDivisionError.
    if cpuinfer_threads <= 0:
        raise ValueError(f"cpuinfer_threads must be positive, got {cpuinfer_threads}")
    if subpool_count <= 0:
        raise ValueError(f"subpool_count must be positive, got {subpool_count}")

    self.input_path = input_path
    self.output_path = output_path
    self.model_config = model_config
    self.cpuinfer_threads = cpuinfer_threads
    self.subpool_count = subpool_count
    self.input_type = input_type
    self.tensor_file_map: Dict[str, str] = {}  # key -> filename
    self.file_handle_map: Dict[str, object] = {}  # filename -> file

    # Extract commonly used config values for convenience
    self.num_experts = model_config["num_experts"]
    self.num_experts_per_tok = model_config["num_experts_per_tok"]
    self.hidden_size = model_config["hidden_size"]
    self.moe_intermediate_size = model_config["moe_intermediate_size"]

    # Load input safetensors files
    self._load_input_files()
class OnlineQuantConverter(ConverterBase):
    """Convert FP8/FP16/BF16 weights to quantized format using AMXMoEWrapper.

    Performs online quantization (FP8/FP16/BF16 -> INT4/INT8) using AMXMoEWrapper
    with NUMA-aware memory management and automatic weight saving. Each layer is
    quantized into a temporary ``_layer_{idx}`` folder under the output path,
    read back into tensors, and the temporary folder is removed.
    """

    # quant_method -> prefix of the .kt files looked up under each NUMA folder.
    _QUANT_FILE_PREFIX = {"int4": "INT4", "int8": "INT8"}
    # (projection name used in file names, segment used in the output keys)
    _PROJ_MAPPINGS = (("down", "ffn_down_exps"), ("gate", "ffn_gate_exps"), ("up", "ffn_up_exps"))
    _PROJECTIONS = ("gate", "up", "down")

    def __init__(
        self,
        input_path: str,
        output_path: str,
        model_config: Dict,
        cpuinfer_threads: int = 60,
        subpool_count: int = 2,
        input_type: str = None,
        quant_method: str = "int4",
    ):
        """Set up the converter.

        Args:
            input_path: Directory containing the input safetensors files
            output_path: Directory the converted weights are written to
            model_config: Model configuration dictionary
            cpuinfer_threads: Number of CPU inference threads
            subpool_count: Number of NUMA subpools
            input_type: Input weight type ("fp8", "fp16" or "bf16")
            quant_method: Output quantization method ("int4" or "int8")
        """
        super().__init__(input_path, output_path, model_config, cpuinfer_threads, subpool_count, input_type)
        self.quant_method = quant_method

        # For FP8, get block size from model_config
        if input_type == "fp8":
            self.fp8_block_size = model_config.get("fp8_weight_block_size", [128, 128])
        else:
            self.fp8_block_size = None

    def _amx_prefix(self) -> str:
        """Return the file-name prefix for the configured quant_method (INT4 when unknown)."""
        return self._QUANT_FILE_PREFIX.get(self.quant_method, "INT4")

    def _dequantize_fp8_blockwise(self, fp8_weight: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
        """Dequantize FP8 weight with block-wise scaling.

        Args:
            fp8_weight: FP8 weight tensor of shape [H, W]
            scale_inv: Scale inverse tensor, one entry per block, of shape
                [ceil(H / block_h), ceil(W / block_w)]

        Returns:
            Contiguous BF16 weight tensor of shape [H, W]

        Raises:
            ValueError: If the block size cannot be determined from the shapes
                or the configured FP8 block size.
        """
        H, W = fp8_weight.shape
        num_blocks_h, num_blocks_w = scale_inv.shape

        configured = getattr(self, "fp8_block_size", None)
        if (
            configured
            and len(configured) == 2
            and -(-H // configured[0]) == num_blocks_h
            and -(-W // configured[1]) == num_blocks_w
        ):
            # The configured block size explains the scale shape; this also
            # covers a ragged last block (H or W not a multiple of the block).
            block_h, block_w = configured
        elif H % num_blocks_h == 0 and W % num_blocks_w == 0:
            # Infer block size from shapes
            block_h = H // num_blocks_h
            block_w = W // num_blocks_w
        else:
            raise ValueError(
                f"Cannot determine FP8 block size for weight {(H, W)} with scale_inv "
                f"{(num_blocks_h, num_blocks_w)} and configured block size {configured}"
            )

        # Expand every per-block scale over its block and crop the ragged edge.
        scale = scale_inv.to(torch.float32)
        scale = scale.repeat_interleave(block_h, dim=0).repeat_interleave(block_w, dim=1)
        scale = scale[:H, :W]

        # Multiply in float32 and cast explicitly: mixing a BF16 tensor with a
        # float32 scale would promote the product to float32, not BF16.
        dequantized = fp8_weight.to(torch.float32) * scale
        return dequantized.to(torch.bfloat16).contiguous()

    def _load_binary_tensor(self, file_path: str) -> torch.Tensor:
        """Load .kt format binary tensor file

        Args:
            file_path: Path to .kt binary file

        Returns:
            torch.Tensor: Flat tensor; float32 when the file name contains
            'scale', int8 otherwise

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # Decide on the file name only; a directory that happens to contain
        # "scale" must not turn quantized weights into float32.
        if "scale" in os.path.basename(file_path):
            # Scale tensors are typically float32
            dtype = np.float32
        else:
            # Quant tensors are typically int8
            dtype = np.int8

        # np.fromfile returns a writable array the tensor can share memory with.
        return torch.from_numpy(np.fromfile(file_path, dtype=dtype))

    def _load_layer_tensors_from_disk(self, layer_idx: int) -> Dict[str, torch.Tensor]:
        """Load all quantized tensors from _layer_{layer_idx} folder

        Args:
            layer_idx: Layer index

        Returns:
            Dict[str, torch.Tensor]: Dictionary with keys in format:
                'blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.{weight|scale}'

        Raises:
            FileNotFoundError: If the layer folder does not exist.
            ValueError: If more than one file matches a quant or scale pattern.
        """
        layer_path = os.path.join(self.output_path, f"_layer_{layer_idx}")
        if not os.path.exists(layer_path):
            raise FileNotFoundError(f"Layer folder not found: {layer_path}")

        tensors = {}
        amx_method = self._amx_prefix()

        # Iterate through all NUMA folders
        for numa_idx in range(self.subpool_count):
            numa_folder = os.path.join(layer_path, f"_numa_{numa_idx}")
            if not os.path.exists(numa_folder):
                continue

            for expert_id in range(self.num_experts):
                for proj_name, proj_key in self._PROJ_MAPPINGS:
                    stem = f"{amx_method}_{proj_name}_{expert_id}_*Byte"
                    quant_files = glob.glob(os.path.join(numa_folder, f"{stem}_quant_.kt"))
                    scale_files = glob.glob(os.path.join(numa_folder, f"{stem}_scale_.kt"))

                    # Build keys (following merge_small_tensor.py format)
                    key_prefix = f"blk.{layer_idx}.{proj_key}.{expert_id}.numa.{numa_idx}"

                    # Load quant tensor
                    if quant_files:
                        if len(quant_files) > 1:
                            raise ValueError(f"Multiple quant files found: {quant_files}")
                        tensors[f"{key_prefix}.weight"] = self._load_binary_tensor(quant_files[0])

                    # Load scale tensor
                    if scale_files:
                        if len(scale_files) > 1:
                            raise ValueError(f"Multiple scale files found: {scale_files}")
                        tensors[f"{key_prefix}.scale"] = self._load_binary_tensor(scale_files[0])

        return tensors

    def _remove_layer_folder(self, layer_idx: int):
        """Remove _layer_{layer_idx} folder and all its contents

        Args:
            layer_idx: Layer index
        """
        import shutil

        layer_path = os.path.join(self.output_path, f"_layer_{layer_idx}")
        if os.path.exists(layer_path):
            shutil.rmtree(layer_path)
            print(f" Removed temporary folder: {layer_path}")

    def _load_expert_weight(self, prefix: str, proj: str) -> torch.Tensor:
        """Load one expert projection as a BF16 tensor, dequantizing FP8 input."""
        weight = self._load_tensor(f"{prefix}.{proj}_proj.weight")
        if self.input_type == "fp8":
            scale_inv = self._load_tensor(f"{prefix}.{proj}_proj.weight_scale_inv")
            return self._dequantize_fp8_blockwise(weight, scale_inv)
        # FP16 is converted; BF16 input is returned unchanged by .to().
        return weight.to(torch.bfloat16)

    def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:
        """Convert all experts in a layer using online quantization via AMXMoEWrapper

        Args:
            layer_idx: Layer index
            expert_ids: IDs of the experts found for this layer

        Returns:
            The quantized tensors of the layer, keyed as produced by
            ``_load_layer_tensors_from_disk``.

        Raises:
            ValueError: If the input type is unsupported or the expert IDs are
                not exactly 0..num_experts-1.
            KeyError: If a weight (or FP8 scale) tensor of an expert is missing.
        """
        start_time = time.time()
        print(f"Converting layer {layer_idx} with {len(expert_ids)} experts via online quantization...")

        if self.input_type not in ("fp8", "fp16", "bf16"):
            raise ValueError(f"Unsupported input_type for INT4 conversion: {self.input_type}")

        # The wrapper is configured for num_experts and receives an identity
        # expert map, so position i of the stacked weights must be expert i.
        ordered_ids = sorted(expert_ids)
        if ordered_ids != list(range(self.num_experts)):
            raise ValueError(
                f"Layer {layer_idx} has {len(ordered_ids)} experts that do not match "
                f"the expected IDs 0..{self.num_experts - 1}"
            )

        is_fp8 = self.input_type == "fp8"
        stacks = {proj: [] for proj in self._PROJECTIONS}

        for expert_id in ordered_ids:
            prefix = f"model.layers.{layer_idx}.mlp.experts.{expert_id}"

            # Check everything this expert needs before loading any of it.
            for proj in self._PROJECTIONS:
                if f"{prefix}.{proj}_proj.weight" not in self.tensor_file_map:
                    raise KeyError(f"Missing {proj} weight for layer {layer_idx}, expert {expert_id}")
            if is_fp8:
                for proj in self._PROJECTIONS:
                    if f"{prefix}.{proj}_proj.weight_scale_inv" not in self.tensor_file_map:
                        raise KeyError(f"Missing {proj} weight_scale_inv for layer {layer_idx}, expert {expert_id}")

            for proj in self._PROJECTIONS:
                stacks[proj].append(self._load_expert_weight(prefix, proj))

        # Stack weights into single tensors: [num_experts, ...]
        gate_proj = torch.stack(stacks["gate"], dim=0).contiguous()
        up_proj = torch.stack(stacks["up"], dim=0).contiguous()
        down_proj = torch.stack(stacks["down"], dim=0).contiguous()

        print(f" Loaded weights shapes:")
        print(f" gate_proj: {gate_proj.shape}")
        print(f" up_proj: {up_proj.shape}")
        print(f" down_proj: {down_proj.shape}")

        # Create physical_to_logical_map: identity mapping where position i maps to expert i
        physical_to_logical_map = torch.arange(self.num_experts, dtype=torch.int64)

        # The wrapper selects its kernel from the AMX_METHOD environment
        # variable; keep it in step with quant_method so the files written
        # are the ones looked up afterwards.
        os.environ["AMX_METHOD"] = f"AMX{self._amx_prefix()}"

        # Create AMXMoEWrapper instance for this layer
        # num_gpu_experts=0 since we're converting all experts to CPU format
        wrapper = AMXMoEWrapper(
            layer_idx=layer_idx,
            num_experts=self.num_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            hidden_size=self.hidden_size,
            moe_intermediate_size=self.moe_intermediate_size,
            num_gpu_experts=0,  # All experts on CPU for conversion
            cpuinfer_threads=self.cpuinfer_threads,
            threadpool_count=self.subpool_count,  # wrapper's name for the NUMA subpool count
            amx_weight_path=self.output_path,  # Output path for quantized weights
            chunked_prefill_size=512,  # Arbitrary value, not critical for conversion
            cpu_save=True,  # Enable saving quantized weights to output
        )

        # Load and quantize weights from tensors
        # This triggers the quantization process and saves to disk
        wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)

        # Clean up to free memory; the wrapper keeps its own references to the
        # stacked tensors, so it has to go as well.
        del stacks, gate_proj, up_proj, down_proj, wrapper
        gc.collect()

        # Load quantized tensors from disk
        print(f" Loading quantized tensors from disk...")
        layer_tensors = self._load_layer_tensors_from_disk(layer_idx)
        print(f" Loaded {len(layer_tensors)} tensors")

        # Remove temporary layer folder
        self._remove_layer_folder(layer_idx)

        elapsed = time.time() - start_time
        print(f" Layer {layer_idx} quantized and saved in {elapsed:.2f}s")

        # Return loaded tensors
        return layer_tensors
(default: int4)", + help="Quantization method for output (default: int4)", + ) + parser.add_argument( + "--cpuinfer-threads", + type=int, + default=60, + help="Number of CPU inference threads (default: 60)", + ) + parser.add_argument( + "--subpool-count", + type=int, + default=2, + help="Number of NUMA subpools for thread distribution (default: 2)", ) parser.add_argument("--gpu", action="store_true", help="Use GPU for conversion if available") @@ -503,21 +809,38 @@ def main(): device = torch.device("cuda" if args.gpu and torch.cuda.is_available() else "cpu") # Validate inputs - if not os.path.exists(args.input): - print(f"Error: Input path does not exist: {args.input}") + if not os.path.exists(args.input_path): + print(f"Error: Input path does not exist: {args.input_path}") return 1 try: + # Load model configuration from config.json + print("Loading model configuration...") + model_config = load_model_config(args.input_path, args.input_type) + print(f"Model config: {model_config}") + print(f" num_experts: {model_config['num_experts']}") + print(f" num_experts_per_tok: {model_config['num_experts_per_tok']}") + print(f" hidden_size: {model_config['hidden_size']}") + print(f" moe_intermediate_size: {model_config['moe_intermediate_size']}") + print(f"CPU inference config:") + print(f" cpuinfer_threads: {args.cpuinfer_threads}") + print(f" subpool_count: {args.subpool_count}") + print() # Create converter by quantization method quant_method = args.quant_method.lower() if quant_method == "awq": - converter = AWQToColumnMajorConverter(args.input, args.output) - elif quant_method == "int4" and args.bf16_path: - converter = Int4ToColumnMajorConverter(args.input, args.output, args.bf16_path) - elif quant_method == "int8" and args.bf16_path: - converter = Int8ToColumnMajorConverter(args.input, args.output, args.bf16_path) + converter = AWQToColumnMajorConverter( + args.input_path, args.output, model_config, args.cpuinfer_threads, args.subpool_count + ) + elif quant_method in 
["int4", "int8"] and args.input_type in ["fp8", "fp16", "bf16"]: + # Use OnlineQuantConverter for both INT4 and INT8 quantization + converter = OnlineQuantConverter( + args.input_path, args.output, model_config, args.cpuinfer_threads, args.subpool_count, args.input_type, quant_method + ) else: - raise ValueError(f"Unsupported quant_method: {args.quant_method} or missing bf16_path for int4/int8") + raise ValueError( + f"Unsupported quant_method: {args.quant_method} or incompatible input_type: {args.input_type}" + ) # Run conversion converter.convert()