Fix VRAM usage estimate for linear layer spanning multiple shards

This commit is contained in:
turboderp
2023-09-11 19:20:24 +02:00
parent c5c90a8b4b
commit 7704a6877b
2 changed files with 51 additions and 60 deletions

View File

@@ -6,20 +6,6 @@ from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from safetensors import safe_open
def _tsize(st, key):
tslice = st.get_slice(key)
shape = tslice.get_shape()
numel = 1
for x in shape: numel *= x
dtype = tslice.get_dtype()
if dtype == "I32": return numel * 4
elif dtype == "I16": return numel * 2
elif dtype == "F16": return numel * 2
elif dtype == "F32": return numel * 4
else: raise ValueError("Unexpected datatype: " + key)
class ExLlamaV2Linear(ExLlamaV2Module):
in_features: int
@@ -29,7 +15,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
linear: nn.Linear or None = None
q_handle: int or None = None
q_tensors: dict or None = None
footprint: int
name: str = "Linear"
@@ -74,48 +59,6 @@ class ExLlamaV2Linear(ExLlamaV2Module):
return self.linear.weight.data
def weight_footprint(self):
    """Return (and cache) the estimated VRAM footprint in bytes of this layer's weights.

    Inspects tensor metadata in the safetensors file without materializing the
    tensors. Each quantized tensor adds a flat 128 bytes
    (NOTE(review): presumably allocator/alignment slack -- confirm).

    Raises ValueError if no known weight tensors are found for this key.
    """
    if self.footprint == -1:

        # Torch linear layer
        if self.key + ".weight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".weight"]
            with safe_open(filename, framework = "pt", device = "cpu") as st:
                self.footprint = 0
                self.footprint += _tsize(st, self.key + ".weight")

        # EXL2
        elif self.key + ".q_weight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".q_weight"]
            with safe_open(filename, framework = "pt", device = "cpu") as st:
                self.footprint = 0
                self.footprint += _tsize(st, self.key + ".q_weight") + 128
                self.footprint += _tsize(st, self.key + ".q_invperm") + 128
                self.footprint += _tsize(st, self.key + ".q_scale") + 128
                self.footprint += _tsize(st, self.key + ".q_scale_max") + 128
                self.footprint += _tsize(st, self.key + ".q_groups") + 128
                # Fix: q_invperm was counted twice here while q_perm was never
                # counted; count q_perm instead.
                self.footprint += _tsize(st, self.key + ".q_perm") + 128

        # GPTQ
        elif self.key + ".qweight" in self.model.config.tensor_file_map:
            filename = self.model.config.tensor_file_map[self.key + ".qweight"]
            with safe_open(filename, framework = "pt", device = "cpu") as st:
                # Fix: reset from the -1 "unmeasured" sentinel before summing;
                # the original accumulated on top of -1 in this branch.
                self.footprint = 0
                self.footprint += _tsize(st, self.key + ".qweight") + 128
                self.footprint += _tsize(st, self.key + ".qzeros") + 128
                self.footprint += _tsize(st, self.key + ".scales") + 128
                if self.key + ".g_idx" in self.model.config.tensor_file_map:
                    self.footprint += _tsize(st, self.key + ".g_idx") + 128

        else:
            raise ValueError("Can't find tensors in model files.")

    return self.footprint
def scratch_space_fixed(self):
return self.temp_dq_size() + \

View File

@@ -3,21 +3,39 @@ import torch.nn as nn
from exllamav2.config import ExLlamaV2Config
from safetensors import safe_open
def _torch_device(idx):
if idx == -1: return "cpu"
return f"cuda:{idx}"
def _tsize(st, key):
tslice = st.get_slice(key)
shape = tslice.get_shape()
numel = 1
for x in shape: numel *= x
dtype = tslice.get_dtype()
if dtype == "I32": return numel * 4
elif dtype == "I16": return numel * 2
elif dtype == "F16": return numel * 2
elif dtype == "F32": return numel * 4
else: raise ValueError("Unexpected datatype: " + key)
class ExLlamaV2Module:
model = None
config: ExLlamaV2Config
key: str
device_idx: int
footprint: int
def __init__(self, model, key):
    """Base-module constructor.

    Stores a reference to the owning model and this module's tensor key.
    `footprint` starts at -1, marking the weight footprint as not yet
    measured (it is computed lazily by weight_footprint()).
    """
    self.model = model
    self.key = key
    self.footprint = -1
def device(self):
@@ -25,11 +43,12 @@ class ExLlamaV2Module:
return _torch_device(self.device_idx)
def load_multi(self, keys):
def load_multi(self, keys, measure = False):
tensors = {}
submap = {}
submap_i = {}
size = 0
for k in keys:
ck = self.key + "." + k
@@ -44,9 +63,12 @@ class ExLlamaV2Module:
for v, ks in submap_i.items():
with safe_open(v, framework="pt", device="cpu") as st:
for k in ks:
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
if measure:
size += _tsize(st, self.key + "." + k)
else:
tensors[k] = st.get_tensor(self.key + "." + k).to(self.device())
return tensors
return size if measure else tensors
def load_weight(self):
@@ -72,6 +94,32 @@ class ExLlamaV2Module:
return nn.Parameter(tensor)
def weight_footprint(self):
    """Return (and cache) the estimated VRAM footprint in bytes of this module's weights.

    Sums tensor sizes via load_multi(..., measure = True), which walks every
    shard containing this module's tensors, so a layer spanning multiple
    .safetensors files is counted in full.

    Raises ValueError if no known weight tensors exist for this key.
    """
    if self.footprint == -1:

        # EXL2
        if self.key + ".q_weight" in self.model.config.tensor_file_map:
            # Fix: "q_perm" was listed twice, double-counting that tensor.
            self.footprint = self.load_multi(["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm"], measure = True)

        # GPTQ
        elif self.key + ".qweight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["qweight", "qzeros", "scales", "g_idx"], measure = True)

        # Torch
        elif self.key + ".weight" in self.model.config.tensor_file_map:
            self.footprint = self.load_multi(["weight"], measure = True)

        # Error
        else: raise ValueError("Unknown tensor type: " + self.key)

    return self.footprint
def set_device_idx(self, idx):
    """Record which device this module lives on (-1 selects the CPU,
    see _torch_device)."""
    self.device_idx = idx