VRAM optimizations during quant

This commit is contained in:
turboderp
2024-02-15 20:03:47 +01:00
parent 75f969a6d3
commit 702dd9740a
2 changed files with 26 additions and 12 deletions

View File

@@ -325,7 +325,7 @@ class AdaptiveGPTQ:
self.weights = self.weights[self.perm, :]
def quantize(self, keep_qweight = False, apply = False, drop = False):
def quantize(self, keep_qweight = False, apply = False):
with torch.inference_mode():
@@ -401,19 +401,19 @@ class AdaptiveGPTQ:
self.qscale_max = qscale_max.to(torch.float16)
self.qgroups = torch.tensor(qgroups, dtype = torch.short)
# I love Python
weights = None
error = None
scale = None
qscale = None
qscale_max = None
qgroups = None
group_idx_list = None
# Apply
if apply:
if drop:
weights = None
error = None
scale = None
qscale = None
qscale_max = None
qgroups = None
group_idx_list = None
gc.collect()
torch.cuda.empty_cache()
self.apply_quant()
@@ -431,10 +431,16 @@ class AdaptiveGPTQ:
def apply_quant(self):
# Install the quantized tensor as the wrapped layer's weight. The row
# reordering, transpose and reshape are done on CPU-side tensors, and cached
# CUDA memory is released before the result is moved back to the device.
# Drop this object's reference to the Hessian; it is not read in this method.
self.hessian = None
# Bring the quantized tensor and the index tensor to the CPU.
qc = self.quant.cpu()
invperm = self.invperm.cpu()
# Reorder rows by invperm, then transpose.
# presumably invperm undoes the self.perm row reordering applied to
# self.weights elsewhere in the class -- TODO confirm
q = qc[invperm, :].T
# Give the result the same shape as the transpose of self.quant.
q = q.reshape(self.quant.T.shape)
# Run the garbage collector and release PyTorch's unused cached CUDA blocks
# before the upload below.
gc.collect()
torch.cuda.empty_cache()
# Move the result to the device self.quant is on and assign it to the layer.
q = q.to(self.quant.device)
self.layer.weight.data = q