import torch

try:
    # any_precision is an optional dependency; the loader below only works when
    # the package is installed.
    from any_precision.modules.AnyPrecisionForCausalLM import AnyPrecisionForCausalLM
    from any_precision.modules.AnyPrecisionLinear import AnyPrecisionLinear
except ImportError:  # also covers ModuleNotFoundError, a subclass
    pass


def get_tensors_size(tensors):
    # Total size in bits of every tensor in a dict (element_size() is in bytes,
    # hence the factor of 8); None entries are skipped.
    return 8 * sum(t.element_size() * t.numel() for t in tensors.values() if t is not None)


def get_tensor_size(tensor):
    # Size in bits of a single tensor.
    return 8 * tensor.element_size() * tensor.numel()

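
# Quick sanity check (illustrative, not part of the original module):
#
#   >>> get_tensor_size(torch.zeros(10, dtype = torch.float16))
#   160  # 10 elements * 2 bytes * 8 bits
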

def scan_gpu_tensors(obj, seen = None):
    # Recursively walk an object graph and total the VRAM held by CUDA tensors.
    # Returns bytes (unlike the helpers above, which return bits). Objects
    # reachable through multiple paths are only counted once, via the seen set.
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    seen.add(obj_id)
    total_size = 0
    # If it's a GPU tensor, add its memory usage.
    if isinstance(obj, torch.Tensor) and obj.is_cuda:
        total_size += obj.element_size() * obj.nelement()
    else:
        if isinstance(obj, dict):
            for key, value in obj.items():
                total_size += scan_gpu_tensors(key, seen)
                total_size += scan_gpu_tensors(value, seen)
            return total_size
        if isinstance(obj, (list, tuple, set)):
            for item in obj:
                total_size += scan_gpu_tensors(item, seen)
            return total_size
        # Fall through to object attributes, stored either in __dict__ or __slots__.
        if hasattr(obj, '__dict__'):
            total_size += scan_gpu_tensors(vars(obj), seen)
        if hasattr(obj, '__slots__'):
            for slot in obj.__slots__:
                try:
                    attr = getattr(obj, slot)
                    total_size += scan_gpu_tensors(attr, seen)
                except AttributeError:
                    continue
    return total_size

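
# Minimal sketch of the deduplication behavior (assumes a CUDA device; the
# tensor is illustrative):
#
#   >>> t = torch.zeros(1024, 1024, dtype = torch.float16, device = "cuda:0")
#   >>> scan_gpu_tensors({"a": t, "b": [t]})
#   2097152  # t occupies 2 MiB and is counted once
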

def get_storage_info(model):
    # Returns (bits per weight of the quantized/body layers, bits per weight of
    # the output head, total weight bits resident in VRAM).
    sum_bits = 0
    sum_numel = 0
    head_bpw = 0
    head_numel = 0
    for name, module in model.named_modules():
        if isinstance(module, AnyPrecisionLinear):
            # Count the packed weights plus the lookup table for the module's
            # active precision (lut1 .. lut8), if present.
            mods = {"qweight": module.qweight}
            lut = getattr(module, f"lut{module.precision}", None)
            if lut is not None:
                mods["lut"] = lut
            sum_bits += get_tensors_size(mods)
            sum_numel += module.in_features * module.out_features
        elif isinstance(module, torch.nn.Linear):
            # Heuristic: a Linear at least as wide as ~90% of the vocabulary is
            # taken to be the output head; everything else counts as body weight.
            if module.out_features >= model.config.vocab_size * 0.9:
                head_bpw = module.weight.element_size() * 8
                head_numel = module.weight.numel()
            else:
                sum_bits += get_tensor_size(module.weight)
                sum_numel += module.weight.numel()
    vram_bits = head_numel * head_bpw + sum_bits
    return sum_bits / sum_numel, head_bpw, vram_bits

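
# Interpreting the result (illustrative, not from a real run):
#
#   bpw_layer, bpw_head, vram_bits = get_storage_info(model)
#   vram_gib = vram_bits / 8 / 1024**3  # total weight footprint in GiB
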

@torch.inference_mode()
def load_anyprecision(model_dir: str, auto = False, bf16 = False):
    # Note: the auto and bf16 arguments are currently unused.
    model = AnyPrecisionForCausalLM.from_quantized(model_dir)
    bpw_layer, bpw_head, vram_bits = get_storage_info(model)
    return model, bpw_layer, bpw_head, vram_bits


@torch.inference_mode()
def fwd_anyprecision(model_instance, input_ids: torch.Tensor):
    # Single forward pass; inputs are moved to cuda:0, where the model is
    # expected to reside.
    input_ids = input_ids.to("cuda:0")
    output = model_instance(input_ids)
    return output.logits
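
# End-to-end sketch (hypothetical model path; assumes any_precision is installed
# and a quantized checkpoint exists at that location):
#
#   model, bpw_layer, bpw_head, vram_bits = load_anyprecision("/models/my-anyprecision-model")
#   input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
#   logits = fwd_anyprecision(model, input_ids)
#   print(f"{bpw_layer:.2f} bpw layers, {bpw_head} bpw head, {vram_bits / 8 / 1024**3:.2f} GiB weights")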