Mirror of https://github.com/turboderp-org/exllamav3.git (synced 2026-04-20 14:29:51 +00:00)
Loader: Add tensor override script
This commit is contained in:
@@ -7,10 +7,12 @@ from exllamav3.util.progress import ProgressBar
|
||||
from exllamav3.util.memory import free_mem
|
||||
from exllamav3.util.measures import cosine_error, sqnr
|
||||
from exllamav3 import Config, Model, Tokenizer
|
||||
from exllamav3.loader import SafetensorsCollection, VariantSafetensorsCollection
|
||||
from datasets import load_dataset
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import yaml
|
||||
|
||||
|
||||
@disk_lru_cache("get_dataset_text")
|
||||
@@ -52,6 +54,26 @@ def main(args):
|
||||
config_b.override_dynamic_seq_len(2048)
|
||||
model_b = Model.from_config(config_b)
|
||||
|
||||
# Override tensors
|
||||
if args.override:
|
||||
with open(args.override, "r") as f:
|
||||
comp = yaml.safe_load(f)
|
||||
sources = {s["id"]: s["model_dir"] for s in comp["sources"]}
|
||||
overrides = {o["key"]: sources[o["source"]] for o in comp["overrides"]}
|
||||
collections = {}
|
||||
for o_key, o_dir in overrides.items():
|
||||
if o_dir not in collections:
|
||||
collections[o_dir] = []
|
||||
collections[o_dir].append(o_key)
|
||||
if len(collections):
|
||||
vstc = VariantSafetensorsCollection(config_a.stc)
|
||||
for o_dir, o_keys in collections.items():
|
||||
print(f" -- Overriding from: {o_dir}:")
|
||||
for o_key in o_keys:
|
||||
print(f" {o_key}")
|
||||
vstc.add_stc(o_keys, SafetensorsCollection(o_dir))
|
||||
config_a.stc = vstc
|
||||
|
||||
# Dataset
|
||||
eval_ids = get_test_tokens(tokenizer, args.rows)
|
||||
state_a = eval_ids
|
||||
@@ -215,6 +237,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-kb", "--keep_b", type = int, help = "Maintain B state for number of modules", default = 0)
|
||||
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
|
||||
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
|
||||
parser.add_argument("-or", "--override", type = str, help = "Model A tensor override spec (YAML)", default = None)
|
||||
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
Reference in New Issue
Block a user