Files
exllamav3/util/size_estimation.py
2025-11-13 13:49:39 +01:00

62 lines
2.0 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model
import argparse
from exllamav3.loader.safetensors import SafetensorsCollection, VariantSafetensorsCollection
import yaml
def tsize(t):
    """
    Return the storage footprint of a single tensor-like object.

    The result is the element count reported by ``t.nelement()`` multiplied by
    the per-element size reported by ``t.element_size()`` (bytes, assuming
    torch.Tensor semantics for both methods).
    """
    num_elements = t.nelement()
    bytes_per_element = t.element_size()
    return num_elements * bytes_per_element
def dsize(d):
    """
    Return the combined size of every tensor stored as a value in ``d``.

    Each value is measured with ``tsize``; the keys are ignored. An empty
    mapping yields 0.
    """
    # Only the values matter, so iterate them directly instead of unpacking items
    return sum(tsize(v) for v in d.values())
def _load_override_groups(spec_path):
    """
    Read a YAML override spec and group the overridden tensor keys by source directory.

    The spec must contain a ``sources`` list (entries with ``id`` and ``model_dir``)
    and an ``overrides`` list (entries with ``key`` and ``source``).

    :param spec_path:
        Path to the YAML file

    :return:
        dict mapping each source model directory to the list of tensor keys that
        should be taken from it. If a key appears more than once, the last source
        listed for it wins.

    :raises KeyError:
        If a required field is missing or an override names an undeclared source id
    """
    # Explicit encoding so the spec parses the same regardless of the platform locale
    with open(spec_path, "r", encoding = "utf-8") as f:
        comp = yaml.safe_load(f)

    sources = {s["id"]: s["model_dir"] for s in comp["sources"]}

    overrides = {}
    for o in comp["overrides"]:
        o_key = o["key"]
        o_source = o["source"]
        if o_source not in sources:
            raise KeyError(f"Override for '{o_key}' refers to undeclared source '{o_source}'")
        overrides[o_key] = sources[o_source]

    groups = {}
    for o_key, o_dir in overrides.items():
        groups.setdefault(o_dir, []).append(o_key)
    return groups


def main(args):
    """
    Print the estimated bitrate and VRAM figure for a model, optionally with tensor overrides.

    :param args:
        Parsed command-line arguments. ``args.in_dir`` is the model directory and
        ``args.override`` is an optional path to a YAML override spec.
    """

    # Config/model
    config = Config.from_directory(args.in_dir)
    model = Model.from_config(config)

    # Override tensors
    if args.override:
        groups = _load_override_groups(args.override)
        if groups:
            # Layer one collection per source directory on top of the config's own collection
            vstc = VariantSafetensorsCollection(config.stc)
            for o_dir, o_keys in groups.items():
                print(f" -- Overriding from: {o_dir}:")
                for o_key in o_keys:
                    print(f" {o_key}")
                vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
            config.stc = vstc

    # New bpw etc.
    bpw_layer, bpw_head, vram_bits = model.get_storage_info()
    bpw_layer = round(bpw_layer, 2)
    # Head bitrate is rounded to a whole number before being printed
    bpw_head = round(bpw_head)
    print(f" -- New estimated model bitrate: {bpw_layer:.2f} bpw / {bpw_head:.2f} bpw (head)")
    # vram_bits / 8 -> bytes, / 1024**3 -> GiB
    print(f" -- VRAM: {vram_bits / 8 / 1024**3:.0f} GiB")
# Command-line entry point: parse the arguments and hand them to main()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The input directory is mandatory; let argparse report a missing value instead of
    # passing None on to Config.from_directory()
    parser.add_argument("-i", "--in_dir", type = str, required = True, help = "Input model directory")
    parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None)
    _args = parser.parse_args()
    main(_args)