Tensor renaming kludge (Gemma3 has one _weight tensor)

This commit is contained in:
turboderp
2025-03-14 23:37:57 +01:00
parent e2fa480595
commit 07afc90788
2 changed files with 13 additions and 1 deletion

View File

@@ -120,6 +120,7 @@ class ExLlamaV2ArchParams:
arch_recognized = False
self.keymap = None
self.compile_fix_keymap = None
@dataclass
class Params:

View File

@@ -18,6 +18,7 @@ import os, glob, shutil, json
from safetensors import safe_open
from safetensors.torch import save_file
from exllamav2.conversion.bot_status import print_stage
from safetensors import SafetensorError
def _tsize(t):
@@ -155,8 +156,18 @@ def compile_model(job, save_fn, model):
key = extra_tensors[0]
extra_tensors = extra_tensors[1:]
file = cfg.tensor_file_map[key]
lkey = key
if cfg.arch.compile_fix_keymap is not None:
km = cfg.arch.compile_fix_keymap
for (a, b) in km:
if key.endswith(b):
lkey = key[:-len(b)] + a
break
with safe_open(file, framework = "pt") as f:
tensor = f.get_tensor(key)
tensor = f.get_tensor(lkey)
out_dict.update({key: tensor})
extra_tensors_size += _tsize(tensor)