# Files
# ktransformers/kt-kernel/scripts/compare_weights.py
# 2025-11-03 15:19:52 +08:00
#
# 530 lines
# 17 KiB
# Python
#!/usr/bin/env python3
"""
Compare two sets of quantized weights generated by convert_cpu_weights.py
This script supports comparing:
- Two safetensor format weights (merged)
- Two .kt format weights (layer folder structure)
- One safetensor and one .kt format (cross-format comparison)
Usage:
python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2
python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2 --tolerance 1e-5
"""
import argparse
import os
import glob
import numpy as np
import torch
from safetensors import safe_open
from typing import Dict, Tuple
from collections import defaultdict
def unpack_awq_int32_to_int8(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    """Unpack AWQ int32 packed format to int8.

    AWQ packs ``32 // bits`` quantized values into each 32-bit integer
    (8 values per int32 for the default 4-bit case). Values are extracted
    lowest bits first, so the i-th nibble of each int32 lands at output
    positions i, i + pack_num, i + 2*pack_num, ...

    Args:
        packed: Packed array; reinterpreted byte-wise as int32 if it has
            another dtype. Any rank is accepted; the result is flat.
        bits: Number of bits per element (default: 4).

    Returns:
        1-D int8 array with ``packed.size * (32 // bits)`` elements, each
        in ``[0, 2**bits - 1]`` (unsigned nibble values stored in int8).
    """
    if packed.dtype != np.int32:
        # Reinterpret raw bytes as int32 (no copy, no value conversion).
        packed = packed.view(np.int32)
    # Flatten so the strided assignment below works for any input rank;
    # a 2-D packed tensor would otherwise fail to broadcast into the
    # 1-D strided slice.
    packed = packed.reshape(-1)
    pack_num = 32 // bits  # 8 for INT4
    mask = (1 << bits) - 1  # 0x0F for 4-bit
    unpacked = np.empty(packed.size * pack_num, dtype=np.int8)
    for i in range(pack_num):
        shift = i * bits
        unpacked[i::pack_num] = ((packed >> shift) & mask).astype(np.int8)
    return unpacked
def normalize_tensor_dtype(tensor: np.ndarray, tensor_name: str, is_awq: bool = False) -> np.ndarray:
    """Normalize tensor to consistent dtype based on tensor type.

    Scale tensors are coerced to float32; weight/qzeros tensors to int8.
    For float32-typed weight data a sampling heuristic decides whether
    the floats *are* int8 values (value conversion) or merely carry int8
    bytes (byte reinterpretation).

    Args:
        tensor: Input tensor.
        tensor_name: Name of the tensor (used to determine its type via
            "scale"/"weight"/"qzeros" substrings).
        is_awq: Whether this is AWQ format (int32 weights need unpacking).

    Returns:
        Normalized tensor with consistent dtype. Unrecognized names are
        returned unchanged.
    """
    # Determine tensor type from name
    is_scale = "scale" in tensor_name
    is_weight = "weight" in tensor_name
    is_qzeros = "qzeros" in tensor_name

    if is_scale:
        # Scale should be float32; reinterpret the bytes otherwise.
        if tensor.dtype != np.float32:
            tensor = tensor.view(np.float32)
        return tensor
    elif is_weight or is_qzeros:
        # Weight/qzeros should be int8
        if is_awq and tensor.dtype == np.int32:
            # AWQ format: unpack int32 to int8
            tensor = unpack_awq_int32_to_int8(tensor)
        elif tensor.dtype == np.float32:
            # Two cases for float32:
            # Case 1: Values look like int8 values (e.g., [37., 73., -70.])
            #         -> use astype to convert values
            # Case 2: Values are large scientific notation (e.g., [2.6e34, ...])
            #         -> use view to reinterpret bytes
            # NOTE: use tensor.size, not len(tensor): len() raises on 0-d
            # arrays and only measures the first axis of multi-dim arrays,
            # which would shrink the 100-element sample below.
            if tensor.size > 0:
                sample_size = min(100, tensor.size)
                sample_values = tensor.flat[:sample_size]
                # If most values are in int8 range and have no decimal parts
                in_int8_range = np.all((sample_values >= -128) & (sample_values <= 127))
                is_integer_valued = np.all(sample_values == np.round(sample_values))
                if in_int8_range and is_integer_valued:
                    # Case 1: Direct value conversion
                    tensor = tensor.astype(np.int8)
                else:
                    # Case 2: Byte reinterpretation (4 bytes -> 4 int8s)
                    tensor = tensor.view(np.int8)
            else:
                tensor = tensor.astype(np.int8)
        elif tensor.dtype == np.int32:
            # Reinterpret int32 as int8 (4x more elements)
            tensor = tensor.view(np.int8)
        elif tensor.dtype != np.int8:
            # Other types: try to convert
            tensor = tensor.astype(np.int8)
        return tensor
    else:
        # Unknown type, return as-is
        return tensor
def load_kt_binary(file_path: str) -> np.ndarray:
    """Load a .kt format binary tensor file.

    The dtype is inferred from the *file name*: names containing "scale"
    load as float32, everything else as int8. Only the basename is
    checked so that a directory component containing "scale" (e.g.
    ".../scaled_weights/...") cannot misclassify quant tensors.

    Args:
        file_path: Path to .kt binary file.

    Returns:
        1-D numpy array with the loaded tensor (read-only view over the
        file's bytes).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    with open(file_path, "rb") as f:
        binary_data = f.read()
    # Determine dtype from the file name only, not the full path.
    if "scale" in os.path.basename(file_path):
        dtype = np.float32
    else:
        dtype = np.int8
    return np.frombuffer(binary_data, dtype=dtype)
def detect_weight_format(path: str) -> str:
    """Identify the on-disk layout of a weight directory.

    Args:
        path: Weight directory to inspect.

    Returns:
        "safetensor" if the directory holds *.safetensors files,
        "kt" if it holds _layer_* subfolders, otherwise "unknown".

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Path not found: {path}")
    # Merged safetensor files take precedence over layer folders.
    has_safetensors = bool(glob.glob(os.path.join(path, "*.safetensors")))
    if has_safetensors:
        return "safetensor"
    has_layer_dirs = bool(glob.glob(os.path.join(path, "_layer_*")))
    return "kt" if has_layer_dirs else "unknown"
def detect_awq_format(weights_sample: Dict[str, np.ndarray]) -> bool:
    """Heuristically decide whether a weight dict is AWQ-packed.

    AWQ exports are recognized by the presence of 'qzeros' tensors
    together with at least one int32-typed (packed) weight tensor.

    Args:
        weights_sample: Sample of loaded weights (name -> array).

    Returns:
        True if AWQ format is detected, False otherwise.
    """
    if not any("qzeros" in name for name in weights_sample):
        return False
    # qzeros alone is not enough: the weights must actually be packed.
    return any(
        "weight" in name and arr.dtype == np.int32
        for name, arr in weights_sample.items()
    )
def load_safetensor_weights(path: str) -> Dict[str, np.ndarray]:
    """Load MoE expert weights from merged safetensor files.

    Only tensors whose names contain ".ffn_" and "_exps." (the MoE
    expert projections) are loaded. A second pass normalizes dtypes so
    weights end up int8 and scales float32.

    Args:
        path: Directory containing one or more *.safetensors files.

    Returns:
        Mapping from tensor name to dtype-normalized numpy array.

    Raises:
        FileNotFoundError: If no safetensor files are found in ``path``.
    """
    files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not files:
        raise FileNotFoundError(f"No safetensor files found in {path}")
    print(f"Loading safetensor files from {path}")

    # First pass: collect every MoE expert tensor as a numpy array.
    weights: Dict[str, np.ndarray] = {}
    for st_file in files:
        with safe_open(st_file, framework="pt") as handle:
            for name in handle.keys():
                if ".ffn_" in name and "_exps." in name:
                    weights[name] = handle.get_tensor(name).cpu().numpy()

    # Detect AWQ format from the loaded sample.
    is_awq = detect_awq_format(weights)
    print(f" Format detected: {'AWQ' if is_awq else 'INT4/INT8'}")

    # Second pass: make dtypes consistent; report any tensor that changed.
    print(f" Normalizing dtypes...")
    for name in list(weights):
        before_dtype = weights[name].dtype
        before_shape = weights[name].shape
        weights[name] = normalize_tensor_dtype(weights[name], name, is_awq=is_awq)
        if weights[name].shape != before_shape or weights[name].dtype != before_dtype:
            print(f" {name}: {before_dtype}{before_shape} -> {weights[name].dtype}{weights[name].shape}")

    print(f" Loaded {len(weights)} tensors from safetensor format")
    return weights
def load_kt_weights(path: str) -> Dict[str, np.ndarray]:
    """Load expert weights from the .kt layer/NUMA folder layout.

    Expected layout: <path>/_layer_<L>/_numa_<N>/<file>.kt where each
    file name follows {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt.
    Keys are built to mirror the safetensor naming, with an extra
    ".numa.<N>" component so per-NUMA shards stay distinct.

    Args:
        path: Directory containing _layer_* folders.

    Returns:
        Mapping from tensor key to dtype-normalized numpy array.

    Raises:
        FileNotFoundError: If no _layer_* folders exist under ``path``.
    """
    layer_dirs = sorted(glob.glob(os.path.join(path, "_layer_*")))
    if not layer_dirs:
        raise FileNotFoundError(f"No _layer_* folders found in {path}")
    print(f"Loading .kt files from {path}")

    # Short projection names used in .kt filenames -> safetensor naming.
    proj_map = {"down": "ffn_down_exps", "gate": "ffn_gate_exps", "up": "ffn_up_exps"}

    weights: Dict[str, np.ndarray] = {}
    for layer_dir in layer_dirs:
        # Folder names encode the indices as the trailing "_<idx>" part.
        layer_idx = int(os.path.basename(layer_dir).split("_")[-1])
        for numa_dir in sorted(glob.glob(os.path.join(layer_dir, "_numa_*"))):
            numa_idx = int(os.path.basename(numa_dir).split("_")[-1])
            for kt_file in glob.glob(os.path.join(numa_dir, "*.kt")):
                parts = os.path.basename(kt_file).replace(".kt", "").split("_")
                if len(parts) < 5:
                    # Filename doesn't match the expected pattern; skip it.
                    continue
                proj_key = proj_map.get(parts[1], parts[1])
                expert = parts[2]
                # parts[4] is "quant" (weight payload) or "scale".
                suffix = "weight" if parts[4] == "quant" else "scale"
                key = f"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.{suffix}"
                weights[key] = load_kt_binary(kt_file)

    # Normalize dtypes (.kt format is never AWQ); report any changes.
    print(f" Normalizing dtypes...")
    for key in list(weights):
        prev_dtype = weights[key].dtype
        prev_shape = weights[key].shape
        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=False)
        if weights[key].shape != prev_shape or weights[key].dtype != prev_dtype:
            print(f" {key}: {prev_dtype}{prev_shape} -> {weights[key].dtype}{weights[key].shape}")

    print(f" Loaded {len(weights)} tensors from .kt format")
    return weights
def normalize_key(key: str) -> Tuple[int, str, int, str]:
    """Split a tensor key into (layer, projection, expert, tensor type).

    Handles both the merged form "blk.<L>.<proj>.<E>.<type>" and the
    NUMA-qualified form "blk.<L>.<proj>.<E>.numa.<N>.<type>".

    Args:
        key: Tensor key such as "blk.0.ffn_up_exps.5.weight" or
            "blk.0.ffn_up_exps.5.numa.0.weight".

    Returns:
        Tuple of (layer_idx, proj_name, expert_idx, tensor_type).
    """
    fields = key.split(".")
    layer_idx = int(fields[1])
    proj_name = fields[2]
    expert_idx = int(fields[3])
    # NUMA-qualified keys carry two extra components before the type.
    type_pos = 6 if "numa" in key else 4
    return (layer_idx, proj_name, expert_idx, fields[type_pos])
def compare_weights(
    weights1: Dict[str, np.ndarray], weights2: Dict[str, np.ndarray], tolerance: float = 1e-6
) -> Tuple[bool, Dict[str, Dict]]:
    """Compare two sets of weights tensor-by-tensor.

    Keys are grouped by their normalized (layer, projection, expert,
    type) form so that the .kt format's per-NUMA shards can be
    concatenated and compared against a merged safetensor tensor.

    Args:
        weights1: First set of weights (key -> numpy array).
        weights2: Second set of weights.
        tolerance: Used as both atol and rtol for np.allclose.

    Returns:
        Tuple of (all_match, differences_dict). differences_dict maps a
        normalized base key to a record describing the mismatch kind.
    """
    print("\n" + "=" * 80)
    print("WEIGHT COMPARISON")
    print("=" * 80)

    # Group keys by normalized form (ignoring numa index)
    def group_by_base_key(weights):
        groups = defaultdict(list)
        for key in weights.keys():
            try:
                layer, proj, expert, ttype = normalize_key(key)
                base_key = f"blk.{layer}.{proj}.{expert}.{ttype}"
                groups[base_key].append(key)
            except (ValueError, IndexError):
                # Keys that don't follow the "blk.*" layout are ignored.
                pass
        return groups

    groups1 = group_by_base_key(weights1)
    groups2 = group_by_base_key(weights2)
    all_base_keys = sorted(set(groups1.keys()) | set(groups2.keys()))

    all_match = True
    differences = {}
    total_comparisons = 0
    matching_comparisons = 0

    for base_key in all_base_keys:
        keys1 = groups1.get(base_key, [])
        keys2 = groups2.get(base_key, [])
        if not keys1:
            print(f"❌ Missing in weights1: {base_key}")
            differences[base_key] = {"status": "missing_in_weights1"}
            all_match = False
            continue
        if not keys2:
            print(f"❌ Missing in weights2: {base_key}")
            differences[base_key] = {"status": "missing_in_weights2"}
            all_match = False
            continue

        # For kt format, we may have multiple keys (one per NUMA node);
        # concatenate the shards (in sorted key order) before comparing.
        if len(keys1) > 1 or len(keys2) > 1:
            tensor1 = np.concatenate([weights1[k] for k in sorted(keys1)])
            tensor2 = np.concatenate([weights2[k] for k in sorted(keys2)])
        else:
            tensor1 = weights1[keys1[0]]
            tensor2 = weights2[keys2[0]]
        total_comparisons += 1

        # Debug aid: dtypes should already agree after normalization.
        if tensor1.dtype != tensor2.dtype:
            print(f"⚠️ Dtype mismatch for {base_key}: {tensor1.dtype} vs {tensor2.dtype}")
            print(f" This should have been normalized. Shape: {tensor1.shape} vs {tensor2.shape}")

        # Compare shapes
        if tensor1.shape != tensor2.shape:
            print(f"❌ Shape mismatch for {base_key}:")
            print(f" Shape1: {tensor1.shape} (dtype: {tensor1.dtype})")
            print(f" Shape2: {tensor2.shape} (dtype: {tensor2.dtype})")
            differences[base_key] = {
                "status": "shape_mismatch",
                "shape1": tensor1.shape,
                "shape2": tensor2.shape,
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare dtypes (should be consistent after normalization)
        if tensor1.dtype != tensor2.dtype:
            print(f"❌ Dtype mismatch for {base_key} after normalization:")
            print(f" Dtype1: {tensor1.dtype}")
            print(f" Dtype2: {tensor2.dtype}")
            differences[base_key] = {
                "status": "dtype_mismatch",
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare values
        if np.allclose(tensor1, tensor2, atol=tolerance, rtol=tolerance):
            matching_comparisons += 1
        else:
            # Compute diffs in float64: int8 subtraction would wrap around
            # (e.g. 127 - (-128)) and report a bogus difference.
            diff = np.abs(tensor1.astype(np.float64) - tensor2.astype(np.float64))
            max_diff = np.max(diff)
            mean_diff = np.mean(diff)
            print(f"❌ Value mismatch for {base_key}:")
            print(f" Max difference: {max_diff:.2e}")
            print(f" Mean difference: {mean_diff:.2e}")
            print(f" Tolerance: {tolerance:.2e}")
            differences[base_key] = {
                "status": "value_mismatch",
                "max_diff": float(max_diff),
                "mean_diff": float(mean_diff),
                "tolerance": tolerance,
            }
            all_match = False

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Total comparisons: {total_comparisons}")
    print(f"Matching: {matching_comparisons}")
    print(f"Mismatching: {total_comparisons - matching_comparisons}")
    print(f"Missing tensors: {len(differences) - (total_comparisons - matching_comparisons)}")
    if all_match:
        print("\n✅ All weights match!")
    else:
        print(f"\n❌ Found {len(differences)} differences")
    return all_match, differences
def main():
    """CLI entry point: parse args, load both weight sets, compare.

    Returns:
        0 when all weights match, 1 on any error or mismatch.
    """
    parser = argparse.ArgumentParser(description="Compare two sets of quantized weights")
    parser.add_argument("--path1", type=str, required=True, help="Path to first weight directory")
    parser.add_argument("--path2", type=str, required=True, help="Path to second weight directory")
    parser.add_argument(
        "--tolerance", type=float, default=1e-6, help="Numerical tolerance for comparison (default: 1e-6)"
    )
    args = parser.parse_args()

    # Validate paths before doing any work.
    for label, path in (("Path1", args.path1), ("Path2", args.path2)):
        if not os.path.exists(path):
            print(f"Error: {label} does not exist: {path}")
            return 1

    # Detect formats
    print("Detecting weight formats...")
    format1 = detect_weight_format(args.path1)
    format2 = detect_weight_format(args.path2)
    print(f"Path1 format: {format1}")
    print(f"Path2 format: {format2}")
    for fmt, path in ((format1, args.path1), (format2, args.path2)):
        if fmt == "unknown":
            print(f"Error: Unable to detect weight format in {path}")
            return 1

    # Load each side with the loader matching its detected format.
    print("\nLoading weights...")
    loader1 = load_safetensor_weights if format1 == "safetensor" else load_kt_weights
    loader2 = load_safetensor_weights if format2 == "safetensor" else load_kt_weights
    weights1 = loader1(args.path1)
    weights2 = loader2(args.path2)

    # Compare weights and map the result onto a process exit code.
    all_match, differences = compare_weights(weights1, weights2, args.tolerance)
    return 0 if all_match else 1
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())