Add quantization_config.json

turboderp
2025-04-08 15:27:17 +02:00
parent 11822a5505
commit b73e159bc8
6 changed files with 119 additions and 5 deletions


@@ -5,6 +5,7 @@ from ..loader.safetensors import SafetensorsCollection
from ..version import __version__
from safetensors.torch import save_file
from ..util.memory import free_mem
from .quant_config import update_config, create_quantization_config_json

def tsize(t):
    return t.nelement() * t.element_size()
@@ -128,10 +129,15 @@ def compile_model(args, model, config, tokenizer):
"cols": args["cal_cols"],
}
}
update_config(config_dict)
config_dict["quantization_config"] = qcfg
with open(os.path.join(out_dir, "config.json"), "w") as f:
f.write(json.dumps(config_dict, indent = 4))
# Add extra metadata to quant_config
print(f" -- Creating quantization_config.json")
create_quantization_config_json(out_dir)
print(f" -- Finished compiling model to {out_dir}")


@@ -0,0 +1,51 @@
from ..models import Config, Model
import os, json

def update_config(
    config_dict: dict
):
    """
    Make necessary updates to config.json
    """
    if "tied_word_embeddings" in config_dict:
        config_dict["tied_word_embeddings"] = True

def create_quantization_config_json(
    model_dir: str
):
    # Create model instance without loading
    config = Config.from_directory(model_dir)
    model = Model.from_config(config)

    # Create tensor map
    storage_dict = {}
    for module in model:
        # Only list leaf nodes
        if len(module.modules) > 0:
            continue
        module_dict = {}
        stored_tensors = config.stc.list_tensors(module.key)
        module_dict["stored_tensors"] = stored_tensors
        qformat = module.quant_format_id()
        if qformat == "EXL3":
            shape = stored_tensors[f"{module.key}.trellis"]["shape"]
            module_dict["quant_format"] = "exl3"
            module_dict["bits_per_weight"] = shape[-1] // 16
        storage_dict[module.key] = module_dict

    # Grab quantization_config from config.json
    with open(os.path.join(model_dir, "config.json"), "r") as f:
        config_dict = json.load(f)
    assert "quantization_config" in config_dict, f"{model_dir} does not appear to be a quantized model"
    quantization_config = config_dict["quantization_config"]

    # Update config with storage data
    quantization_config["tensor_storage"] = storage_dict

    # Write
    with open(os.path.join(model_dir, "quantization_config.json"), "w") as f:
        f.write(json.dumps(quantization_config, indent = 4))
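
For illustration, the resulting quantization_config.json is the "quantization_config" block copied from config.json plus a "tensor_storage" section built from the code above. The excerpt below is a rough sketch; the module key, tensor names, shapes, dtypes and byte counts are invented, and per the code a trellis whose last dimension is 64 corresponds to 64 // 16 = 4 bits per weight:

{
    "...": "fields copied from the quantization_config block in config.json",
    "tensor_storage": {
        "model.layers.0.mlp.down_proj": {
            "stored_tensors": {
                "model.layers.0.mlp.down_proj.su":      {"shape": [8192], "n_bytes": 16384, "dtype": "torch.float16"},
                "model.layers.0.mlp.down_proj.sv":      {"shape": [4096], "n_bytes": 8192, "dtype": "torch.float16"},
                "model.layers.0.mlp.down_proj.trellis": {"shape": [512, 256, 64], "n_bytes": 16777216, "dtype": "torch.int16"}
            },
            "quant_format": "exl3",
            "bits_per_weight": 4
        }
    }
}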


@@ -136,6 +136,29 @@ class SafetensorsCollection:
        return bytesize

    def list_tensors(
        self,
        prefix: str,
    ) -> dict:
        keys = [
            key for key in self.tensor_file_map.keys()
            if key == prefix or key.startswith(prefix + ".")
        ]
        results = {}
        for key in keys:
            filename = self.tensor_file_map[key]
            header = self.file_headers[filename]
            h = header[key]
            dtype, np_dtype, esize = convert_dtype(h["dtype"])
            beg, end = h["data_offsets"]
            results[key] = {
                "shape": h["shape"],
                "n_bytes": end - beg,
                "dtype": str(dtype),
            }
        return results

    def get_tensors(
        self,
        prefix: str,
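
A minimal usage sketch of the new method (the layer key below is hypothetical, and config.stc is assumed to be a SafetensorsCollection, as in quant_config.py above). Since list_tensors relies only on the already-parsed safetensors headers, no tensor data is read:

info = config.stc.list_tensors("model.layers.0.self_attn.q_proj")
for name, meta in info.items():
    # Each entry carries header metadata only: shape, size in bytes, dtype
    print(name, meta["shape"], meta["n_bytes"], meta["dtype"])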


@@ -71,11 +71,15 @@ class Linear(Module):
            return False

    def is_exl3_storage(self, key: str):
        return self.config.stc.has_tensor_group(
            key,
            [["sv", "svh"], ["su", "suh"], "trellis"]
        )

    def load_exl3(self, key: str) -> bool:
        if not self.is_exl3_storage(key):
            return False

        self.used_alt_key = key == self.alt_key
        scale = self.config.stc.get_tensor(key + ".scale", self.device, optional = True)
        su = self.config.stc.get_tensor(key + ".su", self.device, optional = True)
@@ -210,4 +214,12 @@ class Linear(Module):
            quant_args[self.qbits_key],
            surplus_bits,
            self
        )

    def quant_format_id(self):
        # alt_key is only used when loading unquantized model
        if self.is_exl3_storage(self.key):
            return "EXL3"
        else:
            return None


@@ -84,4 +84,7 @@ class Module(ABC):
    def register_submodule(self, module: Module | None):
        if module is not None:
            self.modules.append(module)

    def quant_format_id(self):
        return None

util/add_quant_config.py (new file)

@@ -0,0 +1,19 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav3.conversion.quant_config import create_quantization_config_json
import argparse

def main(args):
    filename = os.path.join(args.model_dir, "quantization_config.json")
    update = os.path.exists(filename)
    create_quantization_config_json(args.model_dir)
    if update:
        print(f"Updated {filename}")
    else:
        print(f"Created {filename}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
    _args = parser.parse_args()
    main(_args)
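
Usage example (the model path is hypothetical): python util/add_quant_config.py -m /models/Llama-3.1-8B-exl3 creates quantization_config.json in the given model directory, or updates it if one already exists.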