exllamav2/conversion/compile.py

from exllamav2.model import (
    ExLlamaV2Embedding,
    ExLlamaV2Attention,
    ExLlamaV2MLP,
    ExLlamaV2MoEMLP,
    ExLlamaV2Linear,
    ExLlamaV2RMSNorm,
    ExLlamaV2LayerNorm
)

import os, glob, shutil
from safetensors import safe_open
from safetensors.torch import save_file

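# Size in bytes of a single tensor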
def _tsize(t):
    return t.nelement() * t.element_size()

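# Combined size in bytes of all tensors in a dict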
def _dsize(d):
    size = 0
    for _, v in d.items(): size += _tsize(v)
    return size

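# Collect an unquantized module's weight (and bias, if present) as a tensor dict
# keyed by the module's original parameter names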
def get_f_module(job, module):
    mod_dict = {}
    module.load()
    w = module.get_weight()
    if isinstance(w, tuple):
        mod_dict[module.key + ".weight"] = w[0]
        mod_dict[module.key + ".bias"] = w[1]
    else:
        mod_dict[module.key + ".weight"] = w
    return mod_dict

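# Load a quantized module's tensors from the per-module .safetensors file
# produced earlier in the conversion job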
def get_q_module(job, module):
    mod_dict = {}
    filename = os.path.join(job["out_dir"], "out_tensor/" + module.key + ".safetensors")
    with safe_open(filename, framework = "pt", device = "cpu") as f:
        for k in f.keys():
            mod_dict[k] = f.get_tensor(k)
    return mod_dict

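# Assemble the final model: walk the module list, collect quantized and
# unquantized tensors, and write them out as one or more .safetensors shards
# sized according to job["shard_size"] (megabytes)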
def compile_model(job, save_fn, model):
    out_dict = {}
    current_size = 0
    file_index = 1
    index = 0
    shard_bytes = job["shard_size"] * 1024 ** 2

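    # Accumulate tensors module by module, flushing a shard to disk whenever the
    # running size exceeds the shard limit (and again at the end of the module list)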
    while index < len(model.modules):
        module = model.modules[index]

        # Embedding weights are kept unquantized
        if isinstance(module, ExLlamaV2Embedding):
            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

        # Attention: layernorm kept unquantized, projections taken from quantized output tensors
        if isinstance(module, ExLlamaV2Attention):
            d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)

        if isinstance(module, ExLlamaV2MLP):
            d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d)
            d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d)

        if isinstance(module, ExLlamaV2MoEMLP):
            d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
            d = get_f_module(job, module.gate); out_dict.update(d); current_size += _dsize(d)
            for i in range(model.config.num_experts):
                d = get_q_module(job, module.w1[i]); out_dict.update(d); current_size += _dsize(d)
                d = get_q_module(job, module.w3[i]); out_dict.update(d); current_size += _dsize(d)
                d = get_q_module(job, module.w2[i]); out_dict.update(d); current_size += _dsize(d)

        if isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

        if isinstance(module, ExLlamaV2Linear):
            # The only standalone linear module expected here is the output head
            assert module.key == "lm_head"
            d = get_q_module(job, module); out_dict.update(d); current_size += _dsize(d)

        index += 1

        # Save shard

        if current_size > shard_bytes or index == len(model.modules):

            save_dict = {}
            dont_save_dict = {}
            this_shard_size = 0

            # Pack whole tensors into the shard until the limit is reached; spill the rest
            for k, v in out_dict.items():
                tsize = _tsize(v)
                if this_shard_size + tsize <= shard_bytes:
                    this_shard_size += tsize
                    current_size -= tsize
                    save_dict[k] = v.contiguous()
                else:
                    dont_save_dict[k] = v

            if len(save_dict) == 0:
                print(f" ## Error: Unable to fit output tensor in single shard.")
                os._exit(0)

            while True:

                print(f" -- Writing shard {file_index}...")

                out_dir = job["out_dir"]
                if job["compile_full"] is not None: out_dir = job["compile_full"]
                if not os.path.exists(out_dir):
                    print(f" -- Creating directory {out_dir}")
                    os.makedirs(out_dir)

                out_filename = os.path.join(out_dir, f"output_temp_{file_index}.safetensors")
                save_file(save_dict, out_filename)

                file_index += 1
                out_dict = dont_save_dict

                # After the last module, keep writing until every remaining tensor is flushed
                if index == len(model.modules) and len(out_dict) > 0:
                    save_dict = dont_save_dict
                    dont_save_dict = {}
                    continue

                break

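    # All shards written; rename the temp files to their final names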
    num_files = file_index - 1
    if num_files == 1:
        final_filename = os.path.join(out_dir, "output.safetensors")
        os.rename(out_filename, final_filename)
        filesize = os.path.getsize(final_filename) // (1024 ** 2)
        print(f" -- {final_filename} ({filesize:,} MB)")
    else:
        print(f" -- Saved model weights:")
        for i in range(num_files):
            temp_filename = os.path.join(out_dir, f"output_temp_{i + 1}.safetensors")
            final_filename = os.path.join(out_dir, f"output-{i + 1:05}-of-{num_files:05}.safetensors")
            os.rename(temp_filename, final_filename)
            filesize = os.path.getsize(final_filename) // (1024 ** 2)
            print(f" -- {final_filename} ({filesize:,} MB)")

    # Copy all non-tensor files from the model's directory if compiling a full model

    if job["compile_full"] is not None:

        print(f" -- Copying non-tensor files to output directory {out_dir}")

        input_dir = model.config.model_dir

        all_files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
        tensor_files = glob.glob(os.path.join(input_dir, "*.safetensors"))
        tensor_files_set = set(tensor_files)
        bin_files = glob.glob(os.path.join(input_dir, "*.bin"))
        if len(bin_files) > 0:
            print(f" !! Ignoring *.bin files in source dir")
            tensor_files_set.update(bin_files)
        non_tensor_files = [f for f in all_files if os.path.join(input_dir, f) not in tensor_files_set]

        for f in non_tensor_files:
            print(f" -- {f}")
            source_file_path = os.path.join(input_dir, f)
            target_file_path = os.path.join(out_dir, f)
            shutil.copy(source_file_path, target_file_path)