import torch

try:
    from exllamav3 import Config, Model, Tokenizer, Cache
    from exllamav3.modules import Linear
except ModuleNotFoundError:
    # exllamav3 is optional; swallow the import error so this module can
    # still be imported on systems without the backend installed.
    pass

def get_tensor_size(tensors):
    # Total storage of a dict of tensors, in bits (element_size() is bytes)
    return 8 * sum(t.element_size() * t.numel() for t in tensors.values())

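# Illustrative check of the unit convention above (a sketch, not part of the
# harness): float16 has element_size() == 2 bytes, so 1024 elements come to
# 1024 * 2 * 8 = 16384 bits.
def _demo_tensor_size():
    tensors = {"weight": torch.empty(1024, dtype = torch.float16)}
    assert get_tensor_size(tensors) == 16384
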
def get_storage_info(model):
    # Walk the module tree and accumulate, in bits, the storage used by all
    # Linear modules, tracking the output head separately since it may be
    # quantized at a different width than the rest of the model.
    sum_bits = 0
    sum_numel = 0
    head_bpw = 0
    head_numel = 0
    for module in model:
        if module.key.endswith("lm_head"):
            head_bpw = get_tensor_size(module.get_tensors()) / module.weights_numel()
            head_numel = module.weights_numel()
        elif isinstance(module, Linear):
            sum_bits += get_tensor_size(module.get_tensors())
            sum_numel += module.weights_numel()
    # Weight footprint = head at its own bits-per-weight + all other layers
    vram_bits = head_numel * head_bpw + sum_bits
    # Returns (average layer bpw, head bpw, total weight storage in bits)
    return sum_bits / sum_numel, head_bpw, vram_bits

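# Worked example of the accounting above (all numbers made up): 7e9 linear
# weights at an average of 4.0 bpw plus a 0.5e9-weight head at 6.0 bpw.
def _demo_storage_accounting():
    sum_bits = 7e9 * 4.0        # linear layers: numel * average bpw
    head_bits = 0.5e9 * 6.0     # output head, tracked separately
    vram_bits = head_bits + sum_bits
    print(f"{vram_bits / 8 / 1024 ** 3:.2f} GiB")  # -> 3.61 GiB
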
def load_exllamav3(model_dir: str | list):
    # Accepts either a model directory, or a [model_dir, override_tensors]
    # pair where override_tensors are extra tensor files layered over the
    # base checkpoint before loading.
    if isinstance(model_dir, list):
        model_dir, override_tensors = model_dir
        config = Config.from_directory(model_dir)
        config.stc.add_tensor_files(override_tensors)
    else:
        config = Config.from_directory(model_dir)
    model = Model.from_config(config)
    model.load(max_output_size = 2048, max_output_factor = 7)
    bpw_layer, bpw_head, vram_bits = get_storage_info(model)
    return model, bpw_layer, bpw_head, vram_bits

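# Usage sketch with placeholder paths; the exact format that add_tensor_files
# expects for override files is assumed here, not confirmed.
def _demo_load(override_files):
    model, bpw_layer, bpw_head, vram_bits = load_exllamav3("/path/to/model")
    model_ovr, *_ = load_exllamav3(["/path/to/model", override_files])
    print(f"{bpw_layer:.2f} bpw layers, {bpw_head:.2f} bpw head, "
          f"{vram_bits / 8 / 1024 ** 3:.2f} GiB of weights")
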
def fwd_exllamav3(model_instance, input_ids: torch.Tensor):
    # Run the full sequence through the model in a single forward pass
    input_ids = input_ids.cpu()
    output = model_instance.forward(input_ids, {"attn_mode": "flash_attn_nc"})
    # Mask any padded vocab positions beyond the true vocabulary size so
    # they cannot be sampled or affect softmax scores
    output[..., model_instance.config.vocab_size:] = float("-inf")
    return output

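# End-to-end sketch tying the helpers together. Tokenizer.from_config and
# tokenizer.encode are assumed to behave as in similar exllama APIs; verify
# against the installed exllamav3 version before relying on this.
def _demo_forward(model_dir: str, prompt: str):
    model, *_ = load_exllamav3(model_dir)
    tokenizer = Tokenizer.from_config(model.config)
    input_ids = tokenizer.encode(prompt)
    logits = fwd_exllamav3(model, input_ids)
    # Greedy next token at the last position; padded vocab entries were
    # masked to -inf above, so they can never be selected
    return logits[..., -1, :].argmax(dim = -1)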