# Mirror of https://github.com/kvcache-ai/ktransformers.git
# Synced 2026-04-20 14:29:22 +00:00
#!/usr/bin/env python3
"""
Compare two sets of quantized weights generated by convert_cpu_weights.py

This script supports comparing:
- Two safetensor format weights (merged)
- Two .kt format weights (layer folder structure)
- One safetensor and one .kt format (cross-format comparison)

Usage:
    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2
    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2 --tolerance 1e-5
"""
|
|
|
|
import argparse
|
|
import os
|
|
import glob
|
|
import numpy as np
|
|
import torch
|
|
from safetensors import safe_open
|
|
from typing import Dict, Tuple
|
|
from collections import defaultdict
|
|
|
|
|
|
def unpack_awq_int32_to_int8(packed: np.ndarray, bits: int = 4) -> np.ndarray:
    """Unpack AWQ int32 packed format to int8.

    AWQ packs 32 // bits values (8 for INT4) into each 32-bit integer,
    lowest bits first: element i occupies bits [i*bits, (i+1)*bits).

    Args:
        packed: Packed int32 array. May have any shape; it is flattened
            before unpacking (the original strided-assignment code raised
            on multi-dimensional input).
        bits: Number of bits per element (default: 4)

    Returns:
        Flat int8 array with packed.size * (32 // bits) elements. Values
        are the raw masked bit groups (0 .. 2**bits - 1); no sign
        extension is applied.
    """
    if packed.dtype != np.int32:
        # Reinterpret the raw bytes as int32 (e.g. data stored as uint32)
        packed = packed.view(np.int32)

    # Flatten so the strided assignment below works for any input shape
    packed = packed.reshape(-1)

    pack_num = 32 // bits  # 8 for INT4
    unpacked = np.empty(packed.size * pack_num, dtype=np.int8)

    mask = (1 << bits) - 1  # 0x0F for 4-bit
    for i in range(pack_num):
        # Arithmetic right shift on negatives is safe: the mask strips
        # any sign-extended high bits before the int8 cast.
        unpacked[i::pack_num] = ((packed >> (i * bits)) & mask).astype(np.int8)

    return unpacked
|
|
|
|
|
|
def normalize_tensor_dtype(tensor: np.ndarray, tensor_name: str, is_awq: bool = False) -> np.ndarray:
    """Coerce a tensor to the canonical dtype implied by its name.

    Scale tensors are normalized to float32; weight/qzeros tensors to
    int8 (unpacking AWQ int32 packing when requested). Tensors whose
    name matches neither category are returned untouched.

    Args:
        tensor: Input tensor
        tensor_name: Name of the tensor; the substrings "scale",
            "weight", and "qzeros" select the normalization rule
        is_awq: Whether this is AWQ format (requires unpacking)

    Returns:
        Normalized tensor with consistent dtype
    """
    if "scale" in tensor_name:
        # Scales must be float32; reinterpret the raw bytes otherwise.
        return tensor if tensor.dtype == np.float32 else tensor.view(np.float32)

    if "weight" not in tensor_name and "qzeros" not in tensor_name:
        # Unknown tensor kind: leave untouched.
        return tensor

    # Weight / qzeros path: target dtype is int8.
    if is_awq and tensor.dtype == np.int32:
        # AWQ format: unpack int32 to int8
        return unpack_awq_int32_to_int8(tensor)

    if tensor.dtype == np.float32:
        # float32 storage is ambiguous:
        #   (a) integer-looking values within [-128, 127] -> value cast
        #   (b) arbitrary bit patterns (huge magnitudes) -> byte reinterpretation
        if len(tensor) > 0:
            sample = tensor.flat[: min(100, len(tensor))]
            looks_like_int8 = np.all((sample >= -128) & (sample <= 127)) and np.all(
                sample == np.round(sample)
            )
            # (a) converts values; (b) yields 4 int8s per float32
            return tensor.astype(np.int8) if looks_like_int8 else tensor.view(np.int8)
        return tensor.astype(np.int8)

    if tensor.dtype == np.int32:
        # Reinterpret each int32 as four int8s.
        return tensor.view(np.int8)

    if tensor.dtype != np.int8:
        # Other dtypes: plain value conversion.
        return tensor.astype(np.int8)

    return tensor
|
|
|
|
|
|
def load_kt_binary(file_path: str) -> np.ndarray:
    """Load .kt format binary tensor file.

    The element dtype is inferred from the file name: names containing
    "scale" hold float32 scales, everything else int8 quantized data.

    Args:
        file_path: Path to .kt binary file

    Returns:
        1-D numpy array with the loaded tensor (read-only view over the
        file bytes)

    Raises:
        FileNotFoundError: If file_path does not exist
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as f:
        binary_data = f.read()

    # Inspect only the file NAME: checking the whole path meant a
    # directory component containing "scale" silently flipped every
    # tensor beneath it to float32.
    dtype = np.float32 if "scale" in os.path.basename(file_path) else np.int8

    return np.frombuffer(binary_data, dtype=dtype)
|
|
|
|
|
|
def detect_weight_format(path: str) -> str:
    """Classify the on-disk layout of a weight directory.

    Args:
        path: Path to weight directory

    Returns:
        'safetensor' if *.safetensors files are present, 'kt' if
        _layer_* folders are present, otherwise 'unknown'
        (safetensor wins when both exist)

    Raises:
        FileNotFoundError: If path does not exist
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Path not found: {path}")

    # Safetensor layout: merged *.safetensors files at the top level
    if glob.glob(os.path.join(path, "*.safetensors")):
        return "safetensor"

    # .kt layout: one _layer_<idx> folder per transformer layer
    if glob.glob(os.path.join(path, "_layer_*")):
        return "kt"

    return "unknown"
|
|
|
|
|
|
def detect_awq_format(weights_sample: Dict[str, np.ndarray]) -> bool:
    """Return True when the loaded weights look like AWQ output.

    AWQ format characteristics:
    - Has 'qzeros' tensors
    - Weight tensors are int32 dtype (packed)

    Args:
        weights_sample: Sample of loaded weights

    Returns:
        True if AWQ format detected
    """
    if not any("qzeros" in name for name in weights_sample):
        return False

    # qzeros alone is not conclusive; packed int32 weights confirm AWQ.
    return any(
        "weight" in name and arr.dtype == np.int32
        for name, arr in weights_sample.items()
    )
|
|
|
|
|
|
def load_safetensor_weights(path: str) -> Dict[str, np.ndarray]:
    """Load MoE expert weights from safetensor files under *path*.

    Only tensors whose names contain ".ffn_" and "_exps." (MoE expert
    projections) are loaded; dtypes are then normalized for comparison.

    Args:
        path: Path to directory containing safetensor files

    Returns:
        Dictionary mapping tensor names to numpy arrays (dtype normalized)

    Raises:
        FileNotFoundError: If no *.safetensors files are found
    """
    shard_files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    if not shard_files:
        raise FileNotFoundError(f"No safetensor files found in {path}")

    print(f"Loading safetensor files from {path}")

    # First pass: pull the MoE expert tensors out of every shard.
    weights: Dict[str, np.ndarray] = {}
    for shard in shard_files:
        with safe_open(shard, framework="pt") as handle:
            for name in handle.keys():
                if ".ffn_" in name and "_exps." in name:
                    weights[name] = handle.get_tensor(name).cpu().numpy()

    # Detect AWQ format (needed to pick the right normalization rule)
    is_awq = detect_awq_format(weights)
    print(f" Format detected: {'AWQ' if is_awq else 'INT4/INT8'}")

    # Second pass: bring each tensor to its canonical dtype, logging
    # any tensor whose dtype or shape changed in the process.
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        before = weights[key]
        weights[key] = normalize_tensor_dtype(before, key, is_awq=is_awq)
        after = weights[key]
        if after.shape != before.shape or after.dtype != before.dtype:
            print(f" {key}: {before.dtype}{before.shape} -> {after.dtype}{after.shape}")

    print(f" Loaded {len(weights)} tensors from safetensor format")
    return weights
|
|
|
|
|
|
def load_kt_weights(path: str) -> Dict[str, np.ndarray]:
    """Load all weights from the .kt layout (layer folder structure).

    Directory layout: <path>/_layer_<L>/_numa_<N>/<file>.kt where each
    file name encodes {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt.

    Args:
        path: Path to directory containing _layer_* folders

    Returns:
        Dictionary mapping tensor names (safetensor-style keys carrying
        a ".numa.<N>." component) to numpy arrays

    Raises:
        FileNotFoundError: If no _layer_* folders are found
    """
    layer_folders = sorted(glob.glob(os.path.join(path, "_layer_*")))
    if not layer_folders:
        raise FileNotFoundError(f"No _layer_* folders found in {path}")

    print(f"Loading .kt files from {path}")

    # Short projection tag in the file name -> safetensor-style name.
    proj_map = {"down": "ffn_down_exps", "gate": "ffn_gate_exps", "up": "ffn_up_exps"}

    weights: Dict[str, np.ndarray] = {}
    for layer_folder in layer_folders:
        # Layer index is the trailing number of the folder name
        layer_idx = int(os.path.basename(layer_folder).split("_")[-1])

        for numa_folder in sorted(glob.glob(os.path.join(layer_folder, "_numa_*"))):
            numa_idx = int(os.path.basename(numa_folder).split("_")[-1])

            for kt_file in glob.glob(os.path.join(numa_folder, "*.kt")):
                # File name format: {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt
                parts = os.path.basename(kt_file).replace(".kt", "").split("_")
                if len(parts) < 5:
                    # Not a recognized tensor file; skip it
                    continue

                proj_key = proj_map.get(parts[1], parts[1])
                expert = parts[2]
                # parts[4] is "quant" or "scale"; key mirrors the
                # safetensor naming so the two formats can be compared.
                suffix = "weight" if parts[4] == "quant" else "scale"
                key = f"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.{suffix}"
                weights[key] = load_kt_binary(kt_file)

    # Normalize dtypes (.kt format is never AWQ)
    print(f" Normalizing dtypes...")
    for key in list(weights.keys()):
        before = weights[key]
        weights[key] = normalize_tensor_dtype(before, key, is_awq=False)
        after = weights[key]
        if after.shape != before.shape or after.dtype != before.dtype:
            print(f" {key}: {before.dtype}{before.shape} -> {after.dtype}{after.shape}")

    print(f" Loaded {len(weights)} tensors from .kt format")
    return weights
|
|
|
|
|
|
def normalize_key(key: str) -> Tuple[int, str, int, str]:
    """Split a tensor key into (layer, projection, expert, type).

    Handles both key flavors:
      "blk.<L>.<proj>.<E>.weight"           (safetensor style)
      "blk.<L>.<proj>.<E>.numa.<N>.weight"  (.kt style)

    Args:
        key: Tensor key like "blk.0.ffn_up_exps.5.weight" or "blk.0.ffn_up_exps.5.numa.0.weight"

    Returns:
        Tuple of (layer_idx, proj_name, expert_idx, tensor_type) where
        tensor_type is weight, scale, or qzeros
    """
    fields = key.split(".")

    # The NUMA variant carries two extra ".numa.<idx>" components, so
    # the tensor type sits at index 6 instead of 4.
    type_pos = 6 if "numa" in key else 4

    return (int(fields[1]), fields[2], int(fields[3]), fields[type_pos])
|
|
|
|
|
|
def compare_weights(
    weights1: Dict[str, np.ndarray], weights2: Dict[str, np.ndarray], tolerance: float = 1e-6
) -> Tuple[bool, Dict[str, Dict]]:
    """Compare two sets of weights tensor-by-tensor.

    Keys are grouped by their NUMA-independent base key so a merged
    safetensor tensor can be compared against the concatenation of the
    per-NUMA .kt shards.

    Args:
        weights1: First set of weights
        weights2: Second set of weights
        tolerance: Numerical tolerance for comparison (used as both
            atol and rtol)

    Returns:
        Tuple of (all_match, differences_dict); differences_dict maps
        each mismatching base key to a record describing the problem
        (missing tensor, shape, dtype, or value difference)
    """
    print("\n" + "=" * 80)
    print("WEIGHT COMPARISON")
    print("=" * 80)

    # Group keys by normalized form (ignoring numa index)
    def group_by_base_key(weights):
        groups = defaultdict(list)
        for key in weights.keys():
            try:
                layer, proj, expert, ttype = normalize_key(key)
                groups[f"blk.{layer}.{proj}.{expert}.{ttype}"].append(key)
            except (ValueError, IndexError):
                # Key doesn't match either expected format; skip it.
                # (The previous bare `except:` also swallowed NameError,
                # KeyboardInterrupt, etc., hiding real bugs.)
                pass
        return groups

    groups1 = group_by_base_key(weights1)
    groups2 = group_by_base_key(weights2)

    all_base_keys = sorted(set(groups1.keys()) | set(groups2.keys()))

    all_match = True
    differences = {}

    total_comparisons = 0
    matching_comparisons = 0

    for base_key in all_base_keys:
        keys1 = groups1.get(base_key, [])
        keys2 = groups2.get(base_key, [])

        if not keys1:
            print(f"❌ Missing in weights1: {base_key}")
            differences[base_key] = {"status": "missing_in_weights1"}
            all_match = False
            continue

        if not keys2:
            print(f"❌ Missing in weights2: {base_key}")
            differences[base_key] = {"status": "missing_in_weights2"}
            all_match = False
            continue

        # For kt format, we may have multiple keys (one per NUMA node);
        # concatenate them in NUMA order for comparison.
        if len(keys1) > 1 or len(keys2) > 1:
            tensor1 = np.concatenate([weights1[k] for k in sorted(keys1)])
            tensor2 = np.concatenate([weights2[k] for k in sorted(keys2)])
        else:
            tensor1 = weights1[keys1[0]]
            tensor2 = weights2[keys2[0]]

        total_comparisons += 1

        # Debug: print dtype and shape info
        if tensor1.dtype != tensor2.dtype:
            print(f"⚠️ Dtype mismatch for {base_key}: {tensor1.dtype} vs {tensor2.dtype}")
            print(f" This should have been normalized. Shape: {tensor1.shape} vs {tensor2.shape}")

        # Compare shapes
        if tensor1.shape != tensor2.shape:
            print(f"❌ Shape mismatch for {base_key}:")
            print(f" Shape1: {tensor1.shape} (dtype: {tensor1.dtype})")
            print(f" Shape2: {tensor2.shape} (dtype: {tensor2.dtype})")
            differences[base_key] = {
                "status": "shape_mismatch",
                "shape1": tensor1.shape,
                "shape2": tensor2.shape,
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare dtypes (should be consistent after normalization)
        if tensor1.dtype != tensor2.dtype:
            print(f"❌ Dtype mismatch for {base_key} after normalization:")
            print(f" Dtype1: {tensor1.dtype}")
            print(f" Dtype2: {tensor2.dtype}")
            differences[base_key] = {
                "status": "dtype_mismatch",
                "dtype1": str(tensor1.dtype),
                "dtype2": str(tensor2.dtype),
            }
            all_match = False
            continue

        # Compare values (np.allclose promotes to float internally, so
        # it is safe on int8 inputs)
        if np.allclose(tensor1, tensor2, atol=tolerance, rtol=tolerance):
            matching_comparisons += 1
        else:
            # Compute diff statistics in float64: subtracting int8
            # arrays directly wraps around (e.g. 127 - (-128)), which
            # corrupted the reported max/mean difference.
            diff = np.abs(tensor1.astype(np.float64) - tensor2.astype(np.float64))
            max_diff = np.max(diff)
            mean_diff = np.mean(diff)

            print(f"❌ Value mismatch for {base_key}:")
            print(f" Max difference: {max_diff:.2e}")
            print(f" Mean difference: {mean_diff:.2e}")
            print(f" Tolerance: {tolerance:.2e}")

            differences[base_key] = {
                "status": "value_mismatch",
                "max_diff": float(max_diff),
                "mean_diff": float(mean_diff),
                "tolerance": tolerance,
            }
            all_match = False

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    print(f"Total comparisons: {total_comparisons}")
    print(f"Matching: {matching_comparisons}")
    print(f"Mismatching: {total_comparisons - matching_comparisons}")
    # differences holds missing + shape/dtype/value mismatches; the
    # subtraction leaves only the missing-tensor entries.
    print(f"Missing tensors: {len(differences) - (total_comparisons - matching_comparisons)}")

    if all_match:
        print("\n✅ All weights match!")
    else:
        print(f"\n❌ Found {len(differences)} differences")

    return all_match, differences
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, load both weight sets, compare.

    Returns:
        Process exit code: 0 when all weights match, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Compare two sets of quantized weights")
    parser.add_argument("--path1", type=str, required=True, help="Path to first weight directory")
    parser.add_argument("--path2", type=str, required=True, help="Path to second weight directory")
    parser.add_argument(
        "--tolerance", type=float, default=1e-6, help="Numerical tolerance for comparison (default: 1e-6)"
    )
    args = parser.parse_args()

    # Validate paths
    for label, p in (("Path1", args.path1), ("Path2", args.path2)):
        if not os.path.exists(p):
            print(f"Error: {label} does not exist: {p}")
            return 1

    # Detect formats
    print("Detecting weight formats...")
    format1 = detect_weight_format(args.path1)
    format2 = detect_weight_format(args.path2)

    print(f"Path1 format: {format1}")
    print(f"Path2 format: {format2}")

    for fmt, p in ((format1, args.path1), (format2, args.path2)):
        if fmt == "unknown":
            print(f"Error: Unable to detect weight format in {p}")
            return 1

    # Load weights based on format
    print("\nLoading weights...")
    weights1 = load_safetensor_weights(args.path1) if format1 == "safetensor" else load_kt_weights(args.path1)
    weights2 = load_safetensor_weights(args.path2) if format2 == "safetensor" else load_kt_weights(args.path2)

    # Compare weights
    all_match, _differences = compare_weights(weights1, weights2, args.tolerance)

    return 0 if all_match else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    # `raise SystemExit(...)` works even where the site-module `exit()`
    # convenience is unavailable (python -S, frozen interpreters).
    raise SystemExit(main())
|