Add recompile utility script

This commit is contained in:
turboderp
2025-07-08 19:51:01 +02:00
parent 6341b119ef
commit 6f77a33415
4 changed files with 113 additions and 26 deletions

View File

@@ -20,9 +20,12 @@ def compile_model(args, model, config, tokenizer):
in_dir = args["in_dir"]
out_dir = args["out_dir"]
work_dir = args["work_dir"]
qtensors_dir = os.path.join(work_dir, "qtensors")
qtensors_stc = SafetensorsCollection(qtensors_dir)
if args.get("model_stc"):
qtensors_stc = config.stc
else:
work_dir = args["work_dir"]
qtensors_dir = os.path.join(work_dir, "qtensors")
qtensors_stc = SafetensorsCollection(qtensors_dir)
# Prepare output directory
if not os.path.exists(out_dir):
@@ -140,24 +143,29 @@ def compile_model(args, model, config, tokenizer):
print(f" -- Writing config.json")
with open(os.path.join(in_dir, "config.json"), "r") as f:
config_dict = json.load(f)
qcfg = {
"quant_method": "exl3",
"version": __version__,
"bits": args["bits"],
"head_bits": args["head_bits"],
"calibration": {
"rows": args["cal_rows"],
"cols": args["cal_cols"],
},
"out_scales": {True: "always", False: "never", None: "auto"}[args["apply_out_scales"]],
}
if any(args.get(x) for x in ["mcg_multiplier", "mul1_multiplier"]):
exp_qcfg = {}
if args.get("mcg_multiplier"):
exp_qcfg["mcg_multiplier"] = args.get("mcg_multiplier")
if args.get("mul1_multiplier"):
exp_qcfg["mul1_multiplier"] = args.get("mul1_multiplier")
qcfg["experimental_options"] = exp_qcfg
if "quantization_config" in config_dict:
qcfg = config_dict["quantization_config"]
qcfg["bits"] = args["bits"]
qcfg["head_bits"] = args["head_bits"]
else:
qcfg = {
"quant_method": "exl3",
"version": __version__,
"bits": args["bits"],
"head_bits": args["head_bits"],
"calibration": {
"rows": args["cal_rows"],
"cols": args["cal_cols"],
},
"out_scales": {True: "always", False: "never", None: "auto"}[args["apply_out_scales"]],
}
if any(args.get(x) for x in ["mcg_multiplier", "mul1_multiplier"]):
exp_qcfg = {}
if args.get("mcg_multiplier"):
exp_qcfg["mcg_multiplier"] = args.get("mcg_multiplier")
if args.get("mul1_multiplier"):
exp_qcfg["mul1_multiplier"] = args.get("mul1_multiplier")
qcfg["experimental_options"] = exp_qcfg
update_config(config_dict)
config_dict["quantization_config"] = qcfg
@@ -169,5 +177,3 @@ def compile_model(args, model, config, tokenizer):
create_quantization_config_json(out_dir)
print(f" -- Finished compiling model to {out_dir}")

View File

@@ -435,10 +435,16 @@ class Model:
head_numel = 0
for module in self:
if module.key.endswith("lm_head"):
head_bpw = get_tensor_size(module.get_tensors()) / module.weights_numel()
if module.device is not None:
head_bpw = get_tensor_size(module.get_tensors()) / module.weights_numel()
else:
head_bpw = sum(self.config.stc.get_tensor_sizes(module.key)) / module.weights_numel() * 8
head_numel = module.weights_numel()
elif isinstance(module, Linear):
sum_bits += get_tensor_size(module.get_tensors())
if module.device is not None:
sum_bits += get_tensor_size(module.get_tensors())
else:
sum_bits += sum(self.config.stc.get_tensor_sizes(module.key)) * 8
sum_numel += module.weights_numel()
vram_bits = head_numel * head_bpw + sum_bits
return sum_bits / sum_numel, head_bpw, vram_bits

View File

@@ -284,7 +284,7 @@ class Linear(Module):
if params["capture"][self.qmap]["first_key"] == self.key:
rows = np.prod(x.shape[:-1])
dim = x.shape[-1]
x = x.view((rows, dim)).to(torch.float, copy = True)
x = x.view((rows, dim)).to(torch.float, copy = True) # TODO: Why copy here?
params["capture"][self.qmap]["H"].addmm_(x.T, x)
params["capture"][self.qmap]["count"] += rows

75
util/recompile.py Normal file
View File

@@ -0,0 +1,75 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model
from exllamav3.conversion.compile import compile_model
import argparse
from exllamav3.loader.safetensors import SafetensorsCollection, VariantSafetensorsCollection
import yaml
def tsize(t):
return t.nelement() * t.element_size()
def dsize(d):
size = 0
for _, v in d.items(): size += tsize(v)
return size
def main(args):
# Config/model
config = Config.from_directory(args.in_dir)
model = Model.from_config(config)
# Tensor collection
stc = SafetensorsCollection(args.in_dir)
# Override tensors
if args.override:
with open(args.override, "r") as f:
comp = yaml.safe_load(f)
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
collections = {}
for o_key, o_dir in overrides.items():
if o_dir not in collections:
collections[o_dir] = []
collections[o_dir].append(o_key)
if len(collections):
vstc = VariantSafetensorsCollection(config.stc)
for o_dir, o_keys in collections.items():
print(f" -- Overriding from: {o_dir}:")
for o_key in o_keys:
print(f" {o_key}")
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
config.stc = vstc
# New bpw etc.
bpw_layer, bpw_head, vram_bits = model.get_storage_info()
bpw_layer = round(bpw_layer, 2)
bpw_head = round(bpw_head)
print(f" -- New estimated model bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)")
# Recompile model
compile_args = {
"bits": bpw_layer,
"head_bits": bpw_head,
"in_dir": args.in_dir,
"out_dir": args.out_dir,
"shard_size": args.shard_size,
"model_stc": True
}
compile_model(compile_args, model, config, None)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--in_dir", type = str, default = None, help = "Input model directory")
parser.add_argument("-o", "--out_dir", type = str, default = None, help = "Output directory")
parser.add_argument("-ss", "--shard_size", type = int, help = "Max shard size in MB, default: 8192", default = 8192)
parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None)
_args = parser.parse_args()
main(_args)