mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Fix VRAM usage estimate for linear layer spanning multiple shards
This commit is contained in:
@@ -6,20 +6,6 @@ from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def _tsize(st, key):
|
||||
|
||||
tslice = st.get_slice(key)
|
||||
shape = tslice.get_shape()
|
||||
numel = 1
|
||||
for x in shape: numel *= x
|
||||
dtype = tslice.get_dtype()
|
||||
if dtype == "I32": return numel * 4
|
||||
elif dtype == "I16": return numel * 2
|
||||
elif dtype == "F16": return numel * 2
|
||||
elif dtype == "F32": return numel * 4
|
||||
else: raise ValueError("Unexpected datatype: " + key)
|
||||
|
||||
|
||||
class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
|
||||
in_features: int
|
||||
@@ -29,7 +15,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
linear: nn.Linear or None = None
|
||||
q_handle: int or None = None
|
||||
q_tensors: dict or None = None
|
||||
footprint: int
|
||||
|
||||
name: str = "Linear"
|
||||
|
||||
@@ -74,48 +59,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
|
||||
return self.linear.weight.data
|
||||
|
||||
|
||||
def weight_footprint(self):
    """Return (and cache) the VRAM footprint in bytes of this layer's weights.

    Inspects the safetensors file(s) referenced by the model's tensor file map
    without loading tensor data. Supports three storage formats, detected by
    which key is present in the file map:

      - Torch linear layer:  "<key>.weight"
      - EXL2 quantized:      "<key>.q_weight" (+ q_invperm, q_scale,
                             q_scale_max, q_groups)
      - GPTQ quantized:      "<key>.qweight" (+ qzeros, scales, optional g_idx)

    Quantized tensors each get +128 bytes of padding/overhead per tensor.
    The result is cached in self.footprint (-1 means "not yet computed").

    Raises ValueError if no recognized tensor key is found.
    """
    if self.footprint == -1:

        # Torch linear layer
        if self.key + ".weight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".weight"]
            with safe_open(filename, framework="pt", device="cpu") as st:
                self.footprint = 0
                self.footprint += _tsize(st, self.key + ".weight")

        # EXL2
        elif self.key + ".q_weight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".q_weight"]
            with safe_open(filename, framework="pt", device="cpu") as st:
                self.footprint = 0
                # Fixed: ".q_invperm" was previously listed twice, double-counting it.
                for suffix in (".q_weight", ".q_invperm", ".q_scale", ".q_scale_max", ".q_groups"):
                    self.footprint += _tsize(st, self.key + suffix) + 128

        # GPTQ
        elif self.key + ".qweight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".qweight"]
            with safe_open(filename, framework="pt", device="cpu") as st:
                # Fixed: this branch previously accumulated onto the -1 sentinel
                # instead of resetting to 0 first, underestimating by one byte.
                self.footprint = 0
                for suffix in (".qweight", ".qzeros", ".scales"):
                    self.footprint += _tsize(st, self.key + suffix) + 128
                # g_idx is optional in GPTQ checkpoints.
                if self.key + ".g_idx" in self.model.config.tensor_file_map:
                    self.footprint += _tsize(st, self.key + ".g_idx") + 128

        else:
            raise ValueError("Can't find tensors in model files.")

    return self.footprint
|
||||
|
||||
|
||||
def scratch_space_fixed(self):
|
||||
|
||||
return self.temp_dq_size() + \
|
||||
|
||||
@@ -3,21 +3,39 @@ import torch.nn as nn
|
||||
from exllamav2.config import ExLlamaV2Config
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
def _torch_device(idx):
|
||||
if idx == -1: return "cpu"
|
||||
return f"cuda:{idx}"
|
||||
|
||||
|
||||
def _tsize(st, key):
|
||||
|
||||
tslice = st.get_slice(key)
|
||||
shape = tslice.get_shape()
|
||||
numel = 1
|
||||
for x in shape: numel *= x
|
||||
dtype = tslice.get_dtype()
|
||||
if dtype == "I32": return numel * 4
|
||||
elif dtype == "I16": return numel * 2
|
||||
elif dtype == "F16": return numel * 2
|
||||
elif dtype == "F32": return numel * 4
|
||||
else: raise ValueError("Unexpected datatype: " + key)
|
||||
|
||||
|
||||
class ExLlamaV2Module:
|
||||
|
||||
model = None
|
||||
config: ExLlamaV2Config
|
||||
key: str
|
||||
device_idx: int
|
||||
footprint: int
|
||||
|
||||
def __init__(self, model, key):
    """Base-module constructor: remember the owning model and this module's
    tensor-key prefix.

    footprint is a lazy cache; -1 means "not yet measured" and is replaced
    by weight_footprint() on first use.
    """
    self.footprint = -1
    self.key = key
    self.model = model
|
||||
|
||||
|
||||
def device(self):
|
||||
@@ -25,11 +43,12 @@ class ExLlamaV2Module:
|
||||
return _torch_device(self.device_idx)
|
||||
|
||||
|
||||
def load_multi(self, keys):
|
||||
def load_multi(self, keys, measure = False):
|
||||
|
||||
tensors = {}
|
||||
submap = {}
|
||||
submap_i = {}
|
||||
size = 0
|
||||
|
||||
for k in keys:
|
||||
ck = self.key + "." + k
|
||||
@@ -44,9 +63,12 @@ class ExLlamaV2Module:
|
||||
for v, ks in submap_i.items():
|
||||
with safe_open(v, framework="pt", device="cpu") as st:
|
||||
for k in ks:
|
||||
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
|
||||
if measure:
|
||||
size += _tsize(st, self.key + "." + k)
|
||||
else:
|
||||
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
|
||||
|
||||
return tensors
|
||||
return size if measure else tensors
|
||||
|
||||
|
||||
def load_weight(self):
|
||||
@@ -72,6 +94,32 @@ class ExLlamaV2Module:
|
||||
return nn.Parameter(tensor)
|
||||
|
||||
|
||||
def weight_footprint(self):
    """Return (and cache) the byte footprint of this module's weights.

    Detects the storage format from which key exists in the model's tensor
    file map and measures the relevant tensors via load_multi(..., measure =
    True), which sums sizes across all shards a layer may span without
    loading tensor data. The result is cached in self.footprint (-1 means
    "not yet computed").

    Raises ValueError if no recognized tensor key is found for this module.
    """
    if self.footprint == -1:

        # EXL2
        if self.key + ".q_weight" in self.model.config.tensor_file_map:
            # Fixed: "q_perm" was previously listed twice, which double-counts
            # its size in the measured footprint.
            self.footprint = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"], measure = True)

        # GPTQ
        elif self.key + ".qweight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["qweight", "qzeros", "scales", "g_idx"], measure = True)

        # Torch
        elif self.key + ".weight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["weight"], measure = True)

        # Error
        else: raise ValueError("Unknown tensor type: " + self.key)

    return self.footprint
|
||||
|
||||
|
||||
def set_device_idx(self, idx):
    """Record which device this module should live on.

    An idx of -1 selects the CPU; any other value selects cuda:<idx>
    (see _torch_device).
    """
    self.device_idx = idx
|
||||
|
||||
Reference in New Issue
Block a user