mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 06:19:10 +00:00
Loader: Add tensor override script
This commit is contained in:
@@ -7,10 +7,12 @@ from exllamav3.util.progress import ProgressBar
|
||||
from exllamav3.util.memory import free_mem
|
||||
from exllamav3.util.measures import cosine_error, sqnr
|
||||
from exllamav3 import Config, Model, Tokenizer
|
||||
from exllamav3.loader import SafetensorsCollection, VariantSafetensorsCollection
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import yaml
|
||||
|
||||
|
||||
@disk_lru_cache("get_dataset_text")
|
||||
@@ -52,6 +54,26 @@ def main(args):
|
||||
config_b.override_dynamic_seq_len(2048)
|
||||
model_b = Model.from_config(config_b)
|
||||
|
||||
# Override tensors
|
||||
if args.override:
|
||||
with open(args.override, "r") as f:
|
||||
comp = yaml.safe_load(f)
|
||||
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
|
||||
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
|
||||
collections = {}
|
||||
for o_key, o_dir in overrides.items():
|
||||
if o_dir not in collections:
|
||||
collections[o_dir] = []
|
||||
collections[o_dir].append(o_key)
|
||||
if len(collections):
|
||||
vstc = VariantSafetensorsCollection(config_a.stc)
|
||||
for o_dir, o_keys in collections.items():
|
||||
print(f" -- Overriding from: {o_dir}:")
|
||||
for o_key in o_keys:
|
||||
print(f" {o_key}")
|
||||
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
|
||||
config_a.stc = vstc
|
||||
|
||||
# Dataset
|
||||
eval_ids = get_test_tokens(tokenizer, args.rows)
|
||||
state_a = eval_ids
|
||||
@@ -215,6 +237,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-kb", "--keep_b", type = int, help = "Maintain B state for number of modules", default = 0)
|
||||
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
|
||||
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
|
||||
parser.add_argument("-or", "--override", type = str, help = "Model A tensor override spec (YAML)", default = None)
|
||||
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
34
examples/overrides.yaml
Normal file
34
examples/overrides.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
sources:
|
||||
- id: A
|
||||
model_dir: /mnt/str/models/ernie-4.5-300b-a47b-base-pt/exl3/3.0bpw/
|
||||
overrides:
|
||||
- key: "*.self_attn.*"
|
||||
source: A
|
||||
- key: "model.layers.0.*"
|
||||
source: A
|
||||
- key: "model.layers.1.*"
|
||||
source: A
|
||||
- key: "model.layers.2.*"
|
||||
source: A
|
||||
- key: "model.layers.7.*"
|
||||
source: A
|
||||
- key: "model.layers.42.*"
|
||||
source: A
|
||||
- key: "model.layers.38.*"
|
||||
source: A
|
||||
- key: "model.layers.41.*"
|
||||
source: A
|
||||
- key: "model.layers.37.*"
|
||||
source: A
|
||||
- key: "model.layers.36.*"
|
||||
source: A
|
||||
- key: "model.layers.34.*"
|
||||
source: A
|
||||
- key: "model.layers.39.*"
|
||||
source: A
|
||||
- key: "model.layers.33.*"
|
||||
source: A
|
||||
- key: "model.layers.40.*"
|
||||
source: A
|
||||
- key: "model.layers.35.*"
|
||||
source: A
|
||||
@@ -10,6 +10,7 @@ import mmap
|
||||
from ..util import Timer, cuda_sync_active
|
||||
from ..ext import exllamav3_ext as ext
|
||||
from functools import lru_cache
|
||||
from fnmatch import fnmatch
|
||||
|
||||
MAX_DEFERRED_LOAD_CHUNK = 2*1024**2
|
||||
|
||||
@@ -480,82 +481,151 @@ class SafetensorsCollection:
|
||||
self.deferred_loads = []
|
||||
|
||||
|
||||
class VariantSafetensorsCollection:
|
||||
# noinspection PyMissingConstructor
|
||||
class VariantSafetensorsCollection(SafetensorsCollection):
|
||||
|
||||
def __init__(
    self,
    main: SafetensorsCollection,
):
    """Wrap `main` with an (initially empty) ordered list of override
    collections; each entry is a (filters, stc) tuple added via add_stc.
    NOTE(review): reconstructed from a diff view — deliberately does not
    call super().__init__() (see the PyMissingConstructor suppression)."""
    self.main = main
    self.stcs = []
|
||||
|
||||
|
||||
def update_map(
|
||||
self,
|
||||
tensor_map: dict[str, str]
|
||||
):
|
||||
self.tensor_map = tensor_map
|
||||
self.tensor_map_sort = sorted(tensor_map.items(), key = lambda kv: len(kv[0]), reverse = True)
|
||||
all_dirs = list(set(tensor_map.values()))
|
||||
def add_stc(self, filters, stc):
    """Register an override collection for the given fnmatch-style key
    filters. The newest registration is searched first by find_stc, so it
    takes priority over earlier ones. Rebinds (does not mutate) self.stcs."""
    self.stcs = [(filters, stc), *self.stcs]
|
||||
|
||||
for d in all_dirs:
|
||||
if d not in self.stcs:
|
||||
self.stcs[d] = SafetensorsCollection(directory = d, **self.kwargs)
|
||||
|
||||
def find_stc(self, key):
    """Return the first override collection whose filter patterns match
    `key` (fnmatch semantics, most recently added first), falling back to
    the main collection when nothing matches."""
    match = next(
        (
            stc
            for filters, stc in self.stcs
            if any(fnmatch(key, pattern) for pattern in filters)
        ),
        None,
    )
    return self.main if match is None else match
|
||||
|
||||
|
||||
def has_tensor(
    self,
    key: str,
):
    """True if `key` exists in whichever collection serves it (the
    override selected by find_stc, or main)."""
    return self.find_stc(key).has_tensor(key)
|
||||
|
||||
|
||||
def has_tensor_group(
    self,
    key: str,
    subkeys: list,
):
    """True when every entry of `subkeys` is present under `key`.

    An entry may be a single subkey string, or a list of alternative
    subkeys of which any one suffices. Each candidate key is routed
    through find_stc so override collections are consulted."""
    return all(
        any(
            (full := f"{key}.{sk}") in self.find_stc(full).tensor_file_map
            for sk in ([subkey] if isinstance(subkey, str) else subkey)
        )
        for subkey in subkeys
    )
|
||||
|
||||
|
||||
def get_tensor_sizes(
    self,
    prefix: str,
):
    """Sizes of all tensors whose key equals `prefix` or lives under it.

    Keys are enumerated from the main collection's tensor map; each size
    lookup is delegated (and may be served by an override collection)."""
    dotted = prefix + "."
    return [
        self.get_tensor_size(key)
        for key in self.main.tensor_file_map
        if key == prefix or key.startswith(dotted)
    ]
|
||||
|
||||
|
||||
def get_tensor_size(
    self,
    key: str,
    optional: bool = False
):
    """Delegate the size lookup for `key` to the collection that serves
    it; `optional` is forwarded unchanged."""
    return self.find_stc(key).get_tensor_size(key, optional)
|
||||
|
||||
|
||||
def list_tensors(
    self,
    prefix: str,
    only_serializable: bool = False
) -> dict:
    """Describe every tensor under `prefix`: shape, byte count and dtype.

    Keys come from the main collection's map; header data is read from
    whichever collection find_stc routes each key to. When
    `only_serializable` is False the torch dtype object is also attached."""
    dotted = prefix + "."
    results = {}
    for key in self.main.tensor_file_map:
        if key != prefix and not key.startswith(dotted):
            continue
        stc = self.find_stc(key)
        # Header entry for this tensor in its (possibly overridden) file.
        h = stc.file_headers[stc.tensor_file_map[key]][key]
        dtype, np_dtype, esize = convert_dtype(h["dtype"])
        beg, end = h["data_offsets"]
        entry = {
            "shape": h["shape"],
            "n_bytes": end - beg,
            "dtype": str(dtype),
        }
        if not only_serializable:
            entry["torch_dtype"] = dtype
        results[key] = entry
    return results
|
||||
|
||||
|
||||
def get_tensors(
|
||||
self,
|
||||
prefix: str,
|
||||
device: torch.device | None = None,
|
||||
allow_bf16: bool = False,
|
||||
) -> dict:
|
||||
keys = [
|
||||
key for key in self.main.tensor_file_map.keys()
|
||||
if key == prefix or key.startswith(prefix + ".")
|
||||
]
|
||||
result = {key: self.find_stc(key).get_tensor(key, device, allow_bf16 = allow_bf16) for key in keys}
|
||||
return result
|
||||
|
||||
|
||||
def get_tensor(
|
||||
self,
|
||||
key: str,
|
||||
device: torch.device | None = None,
|
||||
optional: bool = False,
|
||||
allow_bf16: bool = False
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | None:
|
||||
|
||||
file = None
|
||||
for k, v in self.tensor_map_sort:
|
||||
if key.startswith(k):
|
||||
file = v
|
||||
break
|
||||
if file is None:
|
||||
if not optional:
|
||||
raise ValueError(f"No prefix found in variants map with the matching key: {key}")
|
||||
else:
|
||||
return None
|
||||
|
||||
return self.stcs[file].get_tensor(key, device, optional, allow_bf16)
|
||||
stc = self.find_stc(key)
|
||||
return stc.get_tensor(key, *args, **kwargs)
|
||||
|
||||
|
||||
def close(self):
    """Close every override collection first, then the main collection."""
    for _, stc in self.stcs:
        stc.close()
    self.main.close()
|
||||
|
||||
|
||||
def get_metrics(self):
    """Aggregate loader metrics across all collections.

    Returns (bytes_loaded, time_elapsed, bandwidth) where bandwidth is in
    GB/s; 0.0 when no load time has been recorded (avoids division by
    zero on a freshly-opened collection)."""
    # self.stcs holds (filters, stc) tuples in this class, not a dict —
    # the previous `self.stcs.values()` would raise AttributeError. The
    # main collection is included, matching close()/begin_deferred_load().
    res = [stc.get_metrics() for _, stc in self.stcs] + [self.main.get_metrics()]
    bytes_loaded = sum(r[0] for r in res)
    time_elapsed = sum(r[1] for r in res)
    bandwidth = bytes_loaded / (1024**3) / time_elapsed if time_elapsed else 0.0
    return bytes_loaded, time_elapsed, bandwidth
|
||||
def max_key_len(self):
    """Delegate to the main collection — override collections serve the
    same key namespace, so main's longest key is authoritative here."""
    return self.main.max_key_len()
|
||||
|
||||
|
||||
def set_new_tensors(self, new_tensors):
    """Unsupported on variant collections (read-only composition of
    multiple sources); always raises."""
    raise NotImplementedError()
|
||||
|
||||
|
||||
def begin_deferred_load(self):
    """Enter deferred-load mode on every override collection, then main."""
    for _, stc in self.stcs:
        stc.begin_deferred_load()
    self.main.begin_deferred_load()
|
||||
|
||||
|
||||
def end_deferred_load(self):
    """Flush deferred loads on every override collection, then main."""
    for _, stc in self.stcs:
        stc.end_deferred_load()
    self.main.end_deferred_load()
|
||||
|
||||
|
||||
def abort_deferred_load(self):
    """Cancel any pending deferred loads on every collection, main last."""
    for _, stc in self.stcs:
        stc.abort_deferred_load()
    self.main.abort_deferred_load()
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from . import Model, Config, Cache, Tokenizer
|
||||
from .loader import SafetensorsCollection, VariantSafetensorsCollection
|
||||
from .cache import CacheLayer_fp16, CacheLayer_quant
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
def add_args(
|
||||
parser: ArgumentParser,
|
||||
@@ -23,6 +25,7 @@ def add_args(
|
||||
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
|
||||
parser.add_argument("-gs", "--gpu_split", type = str, help = "Maximum amount of VRAM to use per device, in GB.")
|
||||
parser.add_argument("-lm", "--load_metrics", action = "store_true", help = "Show metrics from loader")
|
||||
parser.add_argument("-or", "--override", type = str, help = "Tensor override spec (YAML)", default = None)
|
||||
|
||||
if cache:
|
||||
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
|
||||
@@ -73,6 +76,28 @@ def init(
|
||||
# Config
|
||||
config = Config.from_directory(args.model_dir)
|
||||
if override_dynamic_seq_len: config.override_dynamic_seq_len(override_dynamic_seq_len)
|
||||
|
||||
# Override tensors
|
||||
if args.override:
|
||||
with open(args.override, "r") as f:
|
||||
comp = yaml.safe_load(f)
|
||||
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
|
||||
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
|
||||
collections = {}
|
||||
for o_key, o_dir in overrides.items():
|
||||
if o_dir not in collections:
|
||||
collections[o_dir] = []
|
||||
collections[o_dir].append(o_key)
|
||||
if len(collections):
|
||||
vstc = VariantSafetensorsCollection(config.stc)
|
||||
for o_dir, o_keys in collections.items():
|
||||
printp(not quiet, f" -- Overriding from: {o_dir}:")
|
||||
for o_key in o_keys:
|
||||
printp(not quiet, f" {o_key}")
|
||||
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
|
||||
config.stc = vstc
|
||||
|
||||
# Model instance
|
||||
model = Model.from_config(config)
|
||||
|
||||
# Cache
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import os, json
|
||||
from ..util.rope import RopeSettings, RopeStyle
|
||||
from ..loader import SafetensorsCollection, VariantSafetensorsCollection
|
||||
from ..loader import SafetensorsCollection
|
||||
from ..util.file import read_dict, no_value, no_default
|
||||
import uuid
|
||||
|
||||
@@ -46,13 +46,8 @@ class Config(ABC):
|
||||
f"Unexpected architecture {arch} in {self.config_filename}, should be {self.arch_string}."
|
||||
self.architecture = arch
|
||||
|
||||
# Special mode to load tensors from across multiple variants of the same model
|
||||
if kwargs.get("st_variants"):
|
||||
self.stc = VariantSafetensorsCollection(kwargs.get("st_variants"))
|
||||
|
||||
# Collect all .safetensors files in directory
|
||||
else:
|
||||
self.stc = SafetensorsCollection(directory, load_method = kwargs.get("load_method"))
|
||||
self.stc = SafetensorsCollection(directory, load_method = kwargs.get("load_method"))
|
||||
|
||||
# Standard params, vocab
|
||||
self.bos_token_id = self.read_cfg(int, "bos_token_id", None)
|
||||
|
||||
@@ -7,3 +7,4 @@ typing_extensions
|
||||
safetensors>=0.3.2
|
||||
ninja
|
||||
pillow
|
||||
pyyaml
|
||||
Reference in New Issue
Block a user