Loader: Add tensor override script

This commit is contained in:
turboderp
2025-07-08 17:13:39 +02:00
parent 86753399f5
commit 6341b119ef
7 changed files with 204 additions and 55 deletions

View File

@@ -7,10 +7,12 @@ from exllamav3.util.progress import ProgressBar
from exllamav3.util.memory import free_mem
from exllamav3.util.measures import cosine_error, sqnr
from exllamav3 import Config, Model, Tokenizer
from exllamav3.loader import SafetensorsCollection, VariantSafetensorsCollection
from datasets import load_dataset
import torch
import torch.nn.functional as F
import math
import yaml
@disk_lru_cache("get_dataset_text")
@@ -52,6 +54,26 @@ def main(args):
config_b.override_dynamic_seq_len(2048)
model_b = Model.from_config(config_b)
# Override tensors
if args.override:
with open(args.override, "r") as f:
comp = yaml.safe_load(f)
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
collections = {}
for o_key, o_dir in overrides.items():
if o_dir not in collections:
collections[o_dir] = []
collections[o_dir].append(o_key)
if len(collections):
vstc = VariantSafetensorsCollection(config_a.stc)
for o_dir, o_keys in collections.items():
print(f" -- Overriding from: {o_dir}:")
for o_key in o_keys:
print(f" {o_key}")
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
config_a.stc = vstc
# Dataset
eval_ids = get_test_tokens(tokenizer, args.rows)
state_a = eval_ids
@@ -215,6 +237,7 @@ if __name__ == "__main__":
parser.add_argument("-kb", "--keep_b", type = int, help = "Maintain B state for number of modules", default = 0)
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
parser.add_argument("-or", "--override", type = str, help = "Model A tensor override spec (YAML)", default = None)
_args = parser.parse_args()
main(_args)

34
examples/overrides.yaml Normal file
View File

@@ -0,0 +1,34 @@
# Tensor override spec consumed by the loader's --override / -or option.
#
# sources:   named model directories that tensors can be pulled from.
# overrides: fnmatch-style key patterns; any tensor whose key matches a
#            pattern is loaded from the referenced source instead of the
#            primary model directory.
sources:
  - id: A
    model_dir: /mnt/str/models/ernie-4.5-300b-a47b-base-pt/exl3/3.0bpw/
overrides:
  # Override all attention tensors in every layer
  - key: "*.self_attn.*"
    source: A
  # Override entire individual layers
  - key: "model.layers.0.*"
    source: A
  - key: "model.layers.1.*"
    source: A
  - key: "model.layers.2.*"
    source: A
  - key: "model.layers.7.*"
    source: A
  - key: "model.layers.42.*"
    source: A
  - key: "model.layers.38.*"
    source: A
  - key: "model.layers.41.*"
    source: A
  - key: "model.layers.37.*"
    source: A
  - key: "model.layers.36.*"
    source: A
  - key: "model.layers.34.*"
    source: A
  - key: "model.layers.39.*"
    source: A
  - key: "model.layers.33.*"
    source: A
  - key: "model.layers.40.*"
    source: A
  - key: "model.layers.35.*"
    source: A

View File

@@ -10,6 +10,7 @@ import mmap
from ..util import Timer, cuda_sync_active
from ..ext import exllamav3_ext as ext
from functools import lru_cache
from fnmatch import fnmatch
MAX_DEFERRED_LOAD_CHUNK = 2*1024**2
@@ -480,82 +481,151 @@ class SafetensorsCollection:
self.deferred_loads = []
# noinspection PyMissingConstructor
class VariantSafetensorsCollection(SafetensorsCollection):
    """
    A SafetensorsCollection facade that routes tensor lookups between a main
    collection and a prioritized list of override collections.

    Each override collection is registered together with a list of
    fnmatch-style key filters (e.g. "model.layers.0.*"). A tensor key is
    served by the first override collection whose filter matches it, falling
    back to the main collection otherwise.
    """

    def __init__(
        self,
        main: SafetensorsCollection,
        **kwargs
    ):
        # Deliberately skips super().__init__(): this class owns no files of
        # its own; every operation is forwarded to self.main or an override.
        self.main = main
        # List of (filters, stc) tuples searched in order; most recently
        # added collections are prepended so they take precedence.
        self.stcs = []
        # NOTE(review): extra kwargs are accepted for signature compatibility
        # but are not used by this class — confirm against callers.
        self.kwargs = kwargs

    def add_stc(self, filters, stc):
        """
        Register an override collection.

        :param filters:
            List of fnmatch-style patterns selecting the tensor keys this
            collection should serve.
        :param stc:
            SafetensorsCollection to serve matching keys from.
        """
        # Prepend so the newest registration wins on overlapping filters
        self.stcs = [(filters, stc)] + self.stcs

    def find_stc(self, key):
        """
        Return the collection responsible for the given tensor key: the first
        override whose filter matches, or the main collection.
        """
        for filters, stc in self.stcs:
            for f in filters:
                if fnmatch(key, f):
                    return stc
        return self.main

    def has_tensor(
        self,
        key: str,
    ):
        # Only the routed collection is consulted; an override shadows the
        # main collection even for existence checks.
        stc = self.find_stc(key)
        return stc.has_tensor(key)

    def has_tensor_group(
        self,
        key: str,
        subkeys: list,
    ):
        """
        True if every entry of subkeys resolves to an existing tensor under
        the given key prefix. Each entry may be a single string or a list of
        alternative strings, of which at least one must exist.
        """
        for subkey in subkeys:
            sk_exists = False
            for sk in [subkey] if isinstance(subkey, str) else subkey:
                k = f"{key}.{sk}"
                stc = self.find_stc(k)
                if k in stc.tensor_file_map:
                    sk_exists = True
                    break
            if not sk_exists:
                return False
        return True

    def get_tensor_sizes(
        self,
        prefix: str,
    ):
        # Keys are enumerated from the main collection (assumed to define the
        # model's full key set); each size is read from whichever collection
        # actually serves the key.
        keys = [
            key for key in self.main.tensor_file_map.keys()
            if key == prefix or key.startswith(prefix + ".")
        ]
        sizes = [self.get_tensor_size(key) for key in keys]
        return sizes

    def get_tensor_size(
        self,
        key: str,
        optional: bool = False
    ):
        stc = self.find_stc(key)
        return stc.get_tensor_size(key, optional)

    def list_tensors(
        self,
        prefix: str,
        only_serializable: bool = False
    ) -> dict:
        """
        Return a dict of metadata (shape, byte size, dtype) for all tensors
        under prefix, reading headers from the collection that serves each
        key. torch_dtype is included unless only_serializable is set.
        """
        keys = [
            key for key in self.main.tensor_file_map.keys()
            if key == prefix or key.startswith(prefix + ".")
        ]
        results = {}
        for key in keys:
            stc = self.find_stc(key)
            filename = stc.tensor_file_map[key]
            header = stc.file_headers[filename]
            h = header[key]
            dtype, np_dtype, esize = convert_dtype(h["dtype"])
            beg, end = h["data_offsets"]
            results[key] = {
                "shape": h["shape"],
                "n_bytes": end - beg,
                "dtype": str(dtype)
            }
            if not only_serializable:
                # torch.dtype objects are not JSON-serializable
                results[key]["torch_dtype"] = dtype
        return results

    def get_tensors(
        self,
        prefix: str,
        device: torch.device | None = None,
        allow_bf16: bool = False,
    ) -> dict:
        keys = [
            key for key in self.main.tensor_file_map.keys()
            if key == prefix or key.startswith(prefix + ".")
        ]
        result = {key: self.find_stc(key).get_tensor(key, device, allow_bf16 = allow_bf16) for key in keys}
        return result

    def get_tensor(
        self,
        key: str,
        *args,
        **kwargs,
    ) -> torch.Tensor | None:
        # Signature is a pass-through so positional/keyword options
        # (device, optional, allow_bf16, ...) match the routed collection's
        # get_tensor
        stc = self.find_stc(key)
        return stc.get_tensor(key, *args, **kwargs)

    def close(self):
        for stc in [s for _, s in self.stcs] + [self.main]:
            stc.close()

    def get_metrics(self):
        """
        Aggregate load metrics across the main and all override collections.

        :return:
            (bytes_loaded, time_elapsed, bandwidth in GB/s)
        """
        # Fixed: self.stcs is a list of (filters, stc) tuples, not a dict, so
        # the previous self.stcs.values() raised AttributeError; the main
        # collection was also excluded from the totals.
        res = [s.get_metrics() for _, s in self.stcs] + [self.main.get_metrics()]
        bytes_loaded = sum(r[0] for r in res)
        time_elapsed = sum(r[1] for r in res)
        # Guard against division by zero when nothing has been loaded yet
        bandwidth = bytes_loaded / (1024**3) / time_elapsed if time_elapsed else 0.0
        return bytes_loaded, time_elapsed, bandwidth

    def max_key_len(self):
        return self.main.max_key_len()

    def set_new_tensors(self, new_tensors):
        # Writing back tensors through the variant facade is unsupported
        raise NotImplementedError()

    def begin_deferred_load(self):
        for stc in [s for _, s in self.stcs] + [self.main]:
            stc.begin_deferred_load()

    def end_deferred_load(self):
        for stc in [s for _, s in self.stcs] + [self.main]:
            stc.end_deferred_load()

    def abort_deferred_load(self):
        for stc in [s for _, s in self.stcs] + [self.main]:
            stc.abort_deferred_load()

View File

@@ -1,7 +1,9 @@
from . import Model, Config, Cache, Tokenizer
from .loader import SafetensorsCollection, VariantSafetensorsCollection
from .cache import CacheLayer_fp16, CacheLayer_quant
from argparse import ArgumentParser
import torch
import yaml
def add_args(
parser: ArgumentParser,
@@ -23,6 +25,7 @@ def add_args(
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
parser.add_argument("-gs", "--gpu_split", type = str, help = "Maximum amount of VRAM to use per device, in GB.")
parser.add_argument("-lm", "--load_metrics", action = "store_true", help = "Show metrics from loader")
parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None)
if cache:
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
@@ -73,6 +76,28 @@ def init(
# Config
config = Config.from_directory(args.model_dir)
if override_dynamic_seq_len: config.override_dynamic_seq_len(override_dynamic_seq_len)
# Override tensors
if args.override:
with open(args.override, "r") as f:
comp = yaml.safe_load(f)
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
collections = {}
for o_key, o_dir in overrides.items():
if o_dir not in collections:
collections[o_dir] = []
collections[o_dir].append(o_key)
if len(collections):
vstc = VariantSafetensorsCollection(config.stc)
for o_dir, o_keys in collections.items():
printp(not quiet, f" -- Overriding from: {o_dir}:")
for o_key in o_keys:
printp(not quiet, f" {o_key}")
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
config.stc = vstc
# Model instance
model = Model.from_config(config)
# Cache

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
from torch import nn
import os, json
from ..util.rope import RopeSettings, RopeStyle
from ..loader import SafetensorsCollection, VariantSafetensorsCollection
from ..loader import SafetensorsCollection
from ..util.file import read_dict, no_value, no_default
import uuid
@@ -46,13 +46,8 @@ class Config(ABC):
f"Unexpected architecture {arch} in {self.config_filename}, should be {self.arch_string}."
self.architecture = arch
# Special mode to load tensors from across multiple variants of the same model
if kwargs.get("st_variants"):
self.stc = VariantSafetensorsCollection(kwargs.get("st_variants"))
# Collect all .safetensors files in directory
else:
self.stc = SafetensorsCollection(directory, load_method = kwargs.get("load_method"))
self.stc = SafetensorsCollection(directory, load_method = kwargs.get("load_method"))
# Standard params, vocab
self.bos_token_id = self.read_cfg(int, "bos_token_id", None)

View File

@@ -7,3 +7,4 @@ typing_extensions
safetensors>=0.3.2
ninja
pillow
pyyaml

View File

@@ -93,7 +93,8 @@ setup(
"rich",
"typing_extensions",
"ninja",
"safetensors>=0.3.2"
"safetensors>=0.3.2",
"pyyaml"
],
include_package_data=True,
package_data = {