Files
exllamav2/conversion/adaptivegptq.py
2024-05-18 06:43:27 +02:00

675 lines
22 KiB
Python

from __future__ import annotations
import gc
import torch
from torch import nn
import torch.nn.functional as F
import math
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2.util import list_live_tensors
import sys
class AdaptiveQuantizer:
    """
    Chooses quantization scales for one group of weights.

    Every column of a group gets its own scale, stored as an integer code
    ``qscale`` in ``[1, scale_maxq + 1]`` relative to a per-group maximum
    ``qscale_max``. The effective scale is
    ``(qscale / (scale_maxq + 1)) ** 2 * qscale_max``.
    """

    norm: float = 3.5       # passed to ext_c.quantize_err; NOTE(review): presumably the error-norm exponent
    max_p: float = 1.2      # upper end of the scale-multiplier search range
    min_p: float = 0.70     # lower end of the scale-multiplier search range
    p_grid: int = 96        # steps between min_p and max_p (p_grid + 1 candidates)

    bits: int
    scale_bits: int
    scale_range: float = 1.0  # multiplier applied to the group's largest base scale

    scale: torch.tensor
    qscale: torch.tensor
    qscale_max: float
    maxq: float
    scale_maxq: float
    qzero: float

    def __init__(self,
                 bits: int = 4,
                 scale_bits: int = 4):
        """
        :param bits: bits per quantized weight
        :param scale_bits: bits per quantized per-column scale
        """
        self.bits = bits
        self.scale_bits = scale_bits
        # Largest weight code, and the midpoint code used as the implicit zero point
        self.maxq = 2 ** bits - 1
        self.qzero = (self.maxq + 1) / 2
        self.scale_maxq = 2 ** scale_bits - 1

    @staticmethod
    def _rounded_scale(qscale, qscale_max, scale_maxq = 15):
        """
        Rebuild float scales from scale codes, rounding through fp16.

        :param qscale: integer scale codes (one per column)
        :param qscale_max: per-group maximum scale (0-d tensor)
        :param scale_maxq: largest scale code minus one, i.e. ``2 ** scale_bits - 1``
        :return: float32 tensor ``(qscale / (scale_maxq + 1)) ** 2 * qscale_max``
        """
        steps = scale_maxq + 1
        # 1 / steps**2 is a power of two, so it is exact in fp16 (1/256 for 4-bit scales)
        prescale = torch.tensor([1 / (steps * steps)], dtype = torch.half, device = qscale.device)
        return ((qscale * qscale).to(torch.half) * (qscale_max.half() * prescale)).float()

    def find_params(self, x):
        """
        Compute ``qscale``, ``qscale_max`` and ``scale`` for one group.

        :param x:
            Group of weights. Scales are per column, derived from the max
            absolute value over dim 0.
        """
        steps = self.scale_maxq + 1

        xmax, _ = torch.max(torch.abs(x), dim = 0)
        xmax += 1e-12  # keep all-zero columns from producing a zero scale

        base_scale = xmax / (self.maxq / 2)
        qscale_max_t = torch.max(base_scale) * self.scale_range

        # Quantize each column's scale, relative to the group maximum, on a square-root grid
        scale_tp = torch.sqrt(base_scale / qscale_max_t) * steps
        qscale_t = torch.clamp(torch.round(scale_tp), 1, steps)
        qscale_tw = (qscale_t / steps) ** 2 * qscale_max_t

        # Evaluate p_grid + 1 candidate multipliers p on the scales. The extension fills `err`;
        # summing over dim 1 gives one total per candidate, and row k is mapped below to
        # p = min_p + (max_p - min_p) * k / p_grid.
        # NOTE(review): the 128 partial sums per candidate are a layout fixed by the extension — confirm
        err = torch.zeros((self.p_grid + 1, 128), dtype = torch.float, device = x.device)
        ext_c.quantize_err(x, err, qscale_tw, self.qzero, self.maxq, self.norm, self.min_p, self.max_p, self.p_grid)
        err = torch.sum(err, dim = 1)
        best_pi = torch.argmin(err)
        best_pif = best_pi / self.p_grid
        best_p = self.max_p * best_pif + self.min_p * (1 - best_pif)

        self.qscale = qscale_t.to(torch.short)
        self.qscale_max = qscale_max_t * best_p

        # Rebuild the float scales from the stored codes with fp16 rounding, so they are
        # rounded correctly for the sanity test
        self.scale = self._rounded_scale(self.qscale, self.qscale_max, self.scale_maxq)
class AdaptiveGPTQ:
    """
    Mixed-bit-width GPTQ-style quantizer for a single linear layer.

    The weight matrix is handled transposed, as ``(rows, columns)`` =
    ``(in_features, out_features)``. Rows are reordered by descending diagonal
    of the calibration Hessian H and split into groups. Each group is quantized
    with one of the bit widths given to ``configure``.

    Each method reads state set by an earlier one, so they are used roughly in
    this order: ``configure`` -> ``add_batch`` (repeatedly) -> ``prepare`` (or
    ``reuse_h``) -> ``quantize`` / ``quantize_rtn_inplace`` -> ``pack``.
    """

    # Diagonal damping for H, as a fraction of its mean diagonal (see prepare)
    percdamp: float = 0.12

    layer: nn.Linear
    device: torch.device
    group_size: int | dict
    bits: list
    bits_groups: list
    scale_bits: int
    hot_bits: int
    columns: int
    rows: int
    total_groups: int

    perm: torch.Tensor | None
    perm_cpu: torch.Tensor | None
    invperm: torch.Tensor | None

    scale: torch.Tensor | None
    qscale: torch.Tensor | None
    qscale_max: torch.Tensor | None
    qweight: torch.Tensor | None
    qgroups: torch.Tensor | None

    quant: torch.Tensor | None
    weights: torch.Tensor | None
    hessian: torch.Tensor | None
    hessian_inv: torch.Tensor | None

    num_samples: int = 0
    num_batches: int = 0

    # CUDA device indices used for quantization / inversion; advanced on out-of-memory
    quant_device: int = 0
    hessian_device: int = 0

    def __init__(self,
                 layer: nn.Linear):
        """
        :param layer:
            Layer to quantize. Working buffers are created later by
            ``add_batch``, ``prepare`` / ``reuse_h`` and ``quantize``.
        """
        self.layer = layer
        self.device = layer.weight.device

        # Rows index input features, columns index output features
        self.rows = self.layer.weight.data.shape[1]
        self.columns = self.layer.weight.data.shape[0]

        self.weights = None
        self.hessian = None
        self.hessian_inv = None
        self.num_samples = 0
        self.num_batches = 0

        self.perm = None
        self.perm_cpu = None
        self.invperm = None

        self.scale = None
        self.qscale = None
        self.qscale_max = None
        self.qweight = None
        self.qgroups = None
        self.quant = None

        self.quant_device = 0
        self.hessian_device = 0

    def drop_buffers(self):
        """Release all working tensors, then run the GC and empty the CUDA cache."""
        self.perm = None
        self.perm_cpu = None
        self.invperm = None
        self.scale = None
        self.qscale = None
        self.qscale_max = None
        self.qweight = None
        self.qgroups = None
        self.quant = None
        self.weights = None
        self.hessian = None
        self.hessian_inv = None
        gc.collect()
        torch.cuda.empty_cache()

    def configure(self,
                  group_size: dict,
                  bits = None,
                  bits_prop = None,
                  scale_bits: int = 4
                  ):
        """
        Split the input rows into groups per bit width.

        :param group_size: maps each bit width to its group size in rows
        :param bits: list of bit widths, in the order their groups are laid out
        :param bits_prop: list with the proportion of rows for each entry of ``bits``; must sum to 1
        :param scale_bits: bits per stored scale
        """
        self.group_size = group_size
        self.scale_bits = scale_bits
        self.bits = bits

        assert isinstance(bits, list)
        assert isinstance(bits_prop, list)
        # Proportions are floats, so compare with a tolerance rather than exactly
        assert math.isclose(sum(bits_prop), 1.0, rel_tol = 0.0, abs_tol = 1e-6)

        groups = 0
        remaining_rows = self.rows
        self.bits_groups = []
        for b, p in zip(self.bits, bits_prop):
            assert p > 0
            gsz = self.group_size[b]
            g = math.ceil(min(self.rows * p, remaining_rows) / gsz)
            groups += g
            remaining_rows -= g * gsz
            self.bits_groups.append(g)

        # Every row must be covered by some group
        assert remaining_rows <= 0
        self.total_groups = groups

    def _num_groups(self):
        """Total number of groups over all configured bit widths."""
        return sum(self.bits_groups[i] for i in range(len(self.bits)))

    def add_batch(self, inputs):
        """
        Accumulate calibration inputs into the (unnormalized) Hessian, H += X^T X.

        :param inputs: layer inputs; leading dims are flattened, the last dim must equal ``self.rows``
        """
        with torch.inference_mode():
            if self.hessian is None:
                self.hessian = torch.zeros((self.rows, self.rows), device = self.device, dtype = torch.float)
            self.num_batches += 1
            x = inputs.view((-1, inputs.shape[-1])).float().to(self.hessian.device)
            self.hessian.addmm_(x.T, x)

    def prepare(self, no_h_inv = False):
        """
        Reorder weights and H into activation order and factor the inverse of H.

        Averages H over the added batches and sorts input rows by descending
        diagonal of H, permuting ``self.weights`` and H to match. It then stores
        the upper Cholesky factor of the damped H^-1 on the CPU in
        ``self.hessian_inv``.

        :param no_h_inv: skip the inversion and leave ``self.hessian_inv`` as None
        :raises ValueError: if H still cannot be inverted after 10 damping attempts
        :raises RuntimeError: (including torch.cuda.OutOfMemoryError) on out-of-memory
            on every CUDA device
        """
        with torch.inference_mode():
            self.hessian /= self.num_batches
            diagonal = torch.diag(self.hessian)

            # Working copy of the weights, transposed to (rows, columns)
            self.weights = self.layer.weight.data.cpu().T.clone().float().contiguous()

            # Zeroing weights whose H diagonal is 0 is deliberately disabled: it would rest on the
            # calibration data alone and likely never triggers anyway.

            # Activation order: visit input features with the largest H diagonal first
            self.perm = torch.argsort(diagonal, descending = True)
            self.perm_cpu = self.perm.cpu()
            self.weights = self.weights[self.perm_cpu, :]

            # Permute H. Above 6e8 elements (~2.4 GB in float32) do it on the CPU to limit VRAM use.
            if self.hessian.numel() > 6e8:
                hessian_cpu = self.hessian.cpu()
                self.hessian = None
                hessian = hessian_cpu[self.perm_cpu][:, self.perm_cpu]
                hessian = hessian.to("cuda:0")
            else:
                hessian = self.hessian[self.perm][:, self.perm]
                self.hessian = None

            # Forcing exact symmetry, H = (H + H^T) / 2, was tried and found unnecessary.

            # Damping added to the diagonal once per inversion attempt. Kept as a Python scalar so
            # it can be added to H on whichever device H ends up on.
            damp = torch.clamp(self.percdamp * torch.mean(torch.diag(hessian)), min = 1e-5).item()

            attempts = 0
            while not no_h_inv:
                try:
                    # H may have moved to another device after an OOM fallback; index on its device
                    d = torch.arange(self.rows, device = hessian.device)
                    hessian[d, d] += damp

                    current_device = self.hessian_device
                    max_devices = torch.cuda.device_count()
                    done = False
                    while not done:
                        try:
                            hessian = hessian.to(torch.device(current_device))
                            hessian_inv = torch.linalg.cholesky(hessian)
                            hessian_inv = torch.cholesky_inverse(hessian_inv)

                            # Cholesky inversion can fail from accumulated rounding error when H is very
                            # large (e.g. 70B MLP down proj) and much calibration data is used (e.g. 100
                            # rows of 4096 tokens). It doesn't always raise and may just yield NaNs.
                            if torch.any(torch.isnan(hessian_inv)): raise RuntimeError

                            # Upper Cholesky factor of H^-1; this also fails if H^-1 is not positive definite
                            hessian_inv = torch.linalg.cholesky(hessian_inv, upper = True)
                            hessian_inv = hessian_inv.contiguous()
                            done = True

                        except torch.cuda.OutOfMemoryError as e:
                            current_device += 1
                            print(f" !! Out of memory (H), moving to device {current_device}")
                            if current_device == max_devices:
                                raise e
                            self.hessian_device = current_device
                    break

                except RuntimeError as runtime_error:
                    if "out of memory" in str(runtime_error):
                        raise runtime_error

                    # Inversion failed: assume non-positive eigenvalues and retry with more damping,
                    # which shifts the eigenvalues in the positive direction
                    print(" !! Warning: Applied additional damping")
                    attempts += 1
                    if attempts == 10:
                        raise ValueError("Hessian is not invertible")

            # Keep the factor in system RAM until quantization needs it
            self.hessian_inv = None if no_h_inv else hessian_inv.cpu()
            self.hessian = None

    def reuse_h(self, other):
        """
        Share another instance's inverse-Hessian factor and permutation instead of calling ``prepare``.

        :param other: an AdaptiveGPTQ that has run ``prepare``; presumably its layer sees the same inputs
        """
        with torch.inference_mode():
            self.weights = self.layer.weight.data.cpu().T.clone().float().contiguous()
            self.hessian_inv = other.hessian_inv
            self.hessian = None
            self.perm = other.perm
            self.perm_cpu = other.perm_cpu
            self.weights = self.weights[self.perm_cpu, :]

    def quantize_rtn_inplace(self, keep_qweight = False, apply = False):
        """
        Quantize the weights group by group without Hessian error compensation.

        No inverse Hessian is passed to the extension here. The quantized working
        weights are written back into the layer, and ``self.weights`` becomes their
        transpose in original row order.

        :param keep_qweight: must be True; integer codes go to ``self.qweight``
        :param apply: must be True
        :raises torch.cuda.OutOfMemoryError: if every CUDA device runs out of memory
        """
        assert apply and keep_qweight

        with torch.inference_mode():
            current_device = self.quant_device
            max_devices = torch.cuda.device_count()
            weights_cpu = self.weights.cpu().contiguous()
            num_groups = self._num_groups()

            done = False
            while not done:
                try:
                    device = torch.device(current_device)
                    weights = weights_cpu.to(device)
                    self.qweight = torch.zeros_like(weights, dtype = torch.short, device = device)

                    scale = []
                    qscale = []
                    qscale_max = torch.empty((num_groups,), dtype = torch.float, device = device)
                    qgroups = []
                    group_idx = 0

                    b = 0
                    for bits_idx, bits in enumerate(self.bits):
                        quantizer = AdaptiveQuantizer(bits = bits, scale_bits = self.scale_bits)
                        for _ in range(self.bits_groups[bits_idx]):
                            a = b
                            b = min(a + self.group_size[bits], self.rows)

                            # (bits, packed row offset); the offset is filled in by pack()
                            qgroups.append(bits)
                            qgroups.append(0)

                            quantizer.find_params(weights[a : b, :])
                            scale.append(quantizer.scale)
                            qscale.append(quantizer.qscale)
                            qscale_max[group_idx] = quantizer.qscale_max

                            ext_c.quantize_range_inplace(weights,
                                                         quantizer.scale,
                                                         self.qweight,
                                                         quantizer.qzero,
                                                         quantizer.maxq,
                                                         a,
                                                         b)
                            group_idx += 1
                    done = True

                except torch.cuda.OutOfMemoryError as e:
                    current_device += 1
                    print(f" !! Out of memory (Q), moving to device {current_device}")
                    if current_device == max_devices:
                        raise e
                    self.quant_device = current_device

            # Inverse permutation restores the original input-feature order
            self.invperm = torch.argsort(self.perm).to("cuda:0")

            self.scale = torch.stack(scale, dim = 0).to("cuda:0")
            self.qscale = torch.stack(qscale, dim = 0).to("cuda:0")
            self.qscale_max = qscale_max.to(torch.float16).to("cuda:0")
            self.qgroups = torch.tensor(qgroups, dtype = torch.short).to("cuda:0")

            # Drop local references so the temporaries can be freed
            scale = None
            qscale = None
            qscale_max = None
            qgroups = None

            # Un-permute the rows and transpose back to the layer's (out, in) layout in fp16
            qc = weights.cpu().to(torch.half)
            invperm = self.invperm.cpu()
            q = qc[invperm, :].T
            q = q.reshape(weights.T.shape)
            dev = weights.device
            weights = None
            torch.cuda.synchronize()
            gc.collect()
            torch.cuda.empty_cache()
            q = q.to(dev)
            self.layer.weight.data = q
            self.weights = q.T

    def quantize(self, keep_qweight = False, apply = False):
        """
        Quantize ``self.weights`` group by group using the inverse-Hessian factor.

        Requires ``configure`` and ``prepare`` (or ``reuse_h``). Sets ``self.quant``
        (the quantized weights as floats, in permuted row order), ``self.scale``,
        ``self.qscale``, ``self.qscale_max``, ``self.qgroups`` and
        ``self.invperm``, all on cuda:0.

        :param keep_qweight: also store integer codes in ``self.qweight`` (needed by ``pack``)
        :param apply: write the result back into the layer via ``apply_quant``
        :raises torch.cuda.OutOfMemoryError: if every CUDA device runs out of memory
        """
        with torch.inference_mode():
            current_device = self.quant_device
            max_devices = torch.cuda.device_count()
            weights_cpu = self.weights.cpu()
            num_groups = self._num_groups()

            done = False
            while not done:
                try:
                    device = torch.device(current_device)
                    hessian_inv_cuda = self.hessian_inv.to(device)
                    weights = weights_cpu.to(device)
                    self.quant = torch.zeros_like(weights_cpu, device = device)
                    if keep_qweight:
                        self.qweight = torch.zeros_like(weights, dtype = torch.short)

                    scale = []
                    qscale = []
                    qscale_max = torch.empty((num_groups,), dtype = torch.float, device = device)
                    qgroups = []
                    # NOTE(review): scratch buffer handed to ext_c.quantize_range; presumably receives
                    # the running quantization error
                    error = weights.clone()
                    group_idx = 0

                    b = 0
                    for bits_idx, bits in enumerate(self.bits):
                        quantizer = AdaptiveQuantizer(bits = bits, scale_bits = self.scale_bits)
                        for _ in range(self.bits_groups[bits_idx]):
                            a = b
                            b = min(a + self.group_size[bits], self.rows)

                            # (bits, packed row offset); the offset is filled in by pack()
                            qgroups.append(bits)
                            qgroups.append(0)

                            quantizer.find_params(weights[a : b, :])
                            scale.append(quantizer.scale)
                            qscale.append(quantizer.qscale)
                            qscale_max[group_idx] = quantizer.qscale_max

                            ext_c.quantize_range(self.quant,
                                                 quantizer.scale,
                                                 self.qweight if keep_qweight else none_tensor,
                                                 quantizer.qzero,
                                                 quantizer.maxq,
                                                 hessian_inv_cuda,
                                                 weights,
                                                 error,
                                                 a,
                                                 b)
                            group_idx += 1
                    done = True

                except torch.cuda.OutOfMemoryError as e:
                    current_device += 1
                    print(f" !! Out of memory (Q), moving to device {current_device}")
                    if current_device == max_devices:
                        raise e
                    self.quant_device = current_device

            self.quant = self.quant.to("cuda:0")
            # Inverse permutation restores the original input-feature order
            self.invperm = torch.argsort(self.perm).to("cuda:0")

            self.scale = torch.stack(scale, dim = 0).to("cuda:0")
            self.qscale = torch.stack(qscale, dim = 0).to("cuda:0")
            self.qscale_max = qscale_max.to(torch.float16).to("cuda:0")
            self.qgroups = torch.tensor(qgroups, dtype = torch.short).to("cuda:0")

            # Drop local references so the temporaries can be freed
            weights = None
            error = None
            scale = None
            qscale = None
            qscale_max = None
            qgroups = None

            if apply:
                self.apply_quant()

    def quant_error(self):
        """
        Compare the quantized weights with the layer's current weights.

        :return: fractions of elements whose absolute difference exceeds 0.01, 0.05 and 0.10
        """
        with torch.inference_mode():
            q = self.quant[self.invperm, :]
            diff = torch.abs(q - self.layer.weight.data.T)
            total = diff.numel()
            mat_error_1 = (diff > 0.01).sum().item() / total
            mat_error_5 = (diff > 0.05).sum().item() / total
            mat_error_10 = (diff > 0.10).sum().item() / total
            return mat_error_1, mat_error_5, mat_error_10

    def apply_quant(self):
        """Replace the layer's weight with ``self.quant``, un-permuted and transposed back, on ``quant_device``."""
        self.hessian = None
        qc = self.quant.cpu()
        invperm = self.invperm.cpu()
        q = qc[invperm, :].T
        q = q.reshape(self.quant.T.shape)
        gc.collect()
        torch.cuda.empty_cache()
        q = q.to(torch.device(self.quant_device))
        self.layer.weight.data = q

    def apply_temp(self):
        """Return a new bias-free nn.Linear holding the quantized weights in the layer's dtype; ``self.layer`` is untouched."""
        q = self.quant[self.invperm, :].T
        temp_layer = nn.Linear(self.layer.in_features, self.layer.out_features, False, device = "meta", dtype = torch.float16)
        temp_layer.weight = nn.Parameter(q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data))
        return temp_layer

    def pack(self, key, qparams):
        """
        Pack the quantized layer into tensors for serialization.

        Fills in the packed-row offset of every group in ``self.qgroups``.

        :param key: prefix for the output tensor names
        :param qparams: quantization parameters; only ``qparams.scale_bits`` is read (must be 4)
        :return: dict with ``key`` + ``.q_invperm``, ``.q_scale_max``, ``.q_groups``,
            ``.q_scale``, ``.q_weight`` and, if the layer has a bias, ``.bias``
        """
        self.qgroups = self.qgroups.to("cuda:0")

        assert qparams.scale_bits in [4]

        output = {}
        output[key + ".q_invperm"] = self.invperm.to(torch.int)
        output[key + ".q_scale_max"] = self.qscale_max
        output[key + ".q_groups"] = self.qgroups
        if self.layer.bias is not None:
            output[key + ".bias"] = self.layer.bias.data

        columns = self.columns
        rem_rows = self.rows
        num_groups = self.qscale.shape[0]

        # Pad the column (output feature) dimension to a multiple of 32
        padding = -columns % 32
        if padding != 0:
            print(f" !! Note: Padding quantized tensor {key}")
            qst = F.pad(self.qscale, (0, padding)).contiguous()
            qwt = F.pad(self.qweight, (0, padding)).contiguous()
        else:
            qst = self.qscale
            qwt = self.qweight

        qst_packed = torch.zeros((qst.shape[0], qst.shape[1] * qparams.scale_bits // 32), dtype = torch.int32, device = qst.device)
        if qparams.scale_bits == 4: ext_c.pack_rows_4(qst, qst_packed)
        output[key + ".q_scale"] = qst_packed

        qwt_packed = []
        row = 0
        out_row = 0
        for i in range(num_groups):
            bits = self.qgroups[i * 2].item()
            # Record where this group starts in q_weight (self.qgroups is the tensor stored in output)
            self.qgroups[i * 2 + 1] = out_row

            rows = min(self.group_size[bits], rem_rows)
            # An int32 holds 32 / bits codes, so `rows` rows pack into rows * bits / 32 int32 rows.
            # Only the last group may end on a partial packed row.
            wpqr = 32 / bits
            qrows = rows / wpqr
            assert i == num_groups - 1 or qrows == int(qrows)
            qrows = math.ceil(qrows)

            g_qwt = qwt[row:row+rows, :].contiguous()
            g_qwt_packed = torch.zeros((qrows, columns + padding), dtype = torch.int32, device = g_qwt.device)
            # Fill padded columns with the midpoint code 2 ** (bits - 1)
            if padding > 0: g_qwt[:, -padding:] = 2 ** (bits - 1)
            ext_c.pack_columns(g_qwt, g_qwt_packed, bits)
            qwt_packed.append(g_qwt_packed)

            row += rows
            out_row += qrows
            rem_rows -= rows

        output[key + ".q_weight"] = torch.cat(qwt_packed, dim = 0)
        return output