mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-05-05 05:31:13 +00:00
675 lines
22 KiB
Python
675 lines
22 KiB
Python
from __future__ import annotations
|
|
|
|
import gc
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
|
from exllamav2.util import list_live_tensors
|
|
import sys
|
|
|
|
class AdaptiveQuantizer:
    """
    Quantizer for one group of weight rows.

    find_params() grid-searches a scale multiplier (between min_p and max_p, on a
    grid of p_grid + 1 points) that minimizes the quantization error reported by
    the ext_c.quantize_err kernel, and stores the chosen (quantized) scales.
    """

    # Error norm and scale-multiplier search range used by ext_c.quantize_err
    norm: float = 3.5
    max_p: float = 1.2
    min_p: float = 0.70
    p_grid: int = 96

    bits: int
    scale_bits: int
    scale_range: float = 1.0

    scale: torch.Tensor         # final per-column scales (float)
    qscale: torch.Tensor        # quantized per-column scales (short)
    qscale_max: float           # group scale, so scale ≈ (qscale / scale_maxq)^2 * qscale_max

    maxq: float                 # largest quantized weight value, 2^bits - 1
    scale_maxq: float           # largest quantized scale value, 2^scale_bits - 1
    qzero: float                # symmetric zero point, (maxq + 1) / 2

    def __init__(self,
                 bits: int = 4,
                 scale_bits: int = 4):

        self.bits = bits
        self.scale_bits = scale_bits
        self.maxq = 2 ** bits - 1
        self.qzero = (self.maxq + 1) / 2
        # NOTE: the original assigned this twice with the same expression; once suffices
        self.scale_maxq = 2 ** scale_bits - 1


    def find_params(self, x):
        """
        Compute and store scale/qscale/qscale_max for weight group x
        (rows of the group along dim 0, one scale per column).
        """

        xmax, _ = torch.max(torch.abs(x), dim = 0)
        xmax += 1e-12  # avoid zero scales for all-zero columns

        base_scale = xmax / (self.maxq / 2)
        qscale_max_t = torch.max(base_scale) * self.scale_range

        # Quantize the scales themselves: store sqrt of the normalized scale
        # so the dequantized scale is (q / (scale_maxq + 1))^2 * qscale_max
        scale_tp = base_scale / qscale_max_t
        scale_tp = torch.sqrt(scale_tp)
        scale_tp *= (self.scale_maxq + 1)
        qscale_t = torch.clamp(torch.round(scale_tp), 1, self.scale_maxq + 1)
        qscale_tw = qscale_t / (self.scale_maxq + 1)
        qscale_tw = qscale_tw ** 2
        qscale_tw *= qscale_max_t

        # Evaluate quantization error over the p-grid (128 columns per CUDA block,
        # as expected by the kernel), then pick the multiplier with least error
        q = torch.zeros((self.p_grid + 1, 128), dtype = torch.float, device = x.device)
        ext_c.quantize_err(x, q, qscale_tw, self.qzero, self.maxq, self.norm, self.min_p, self.max_p, self.p_grid)

        q = torch.sum(q, dim = 1)
        best_pi = torch.argmin(q)
        best_pif = best_pi / self.p_grid
        best_p = self.max_p * best_pif + self.min_p * (1 - best_pif)

        self.qscale = qscale_t.to(torch.short)
        self.scale = qscale_tw * best_p
        self.qscale_max = qscale_max_t * best_p

        # Make sure scales are rounded correctly for sanity test
        prescale = torch.tensor([1 / 256], dtype = torch.half, device = self.scale.device)
        self.scale = ((self.qscale * self.qscale).to(torch.half) * (self.qscale_max.half() * prescale)).float()
|
|
|
|
|
|
class AdaptiveGPTQ:
|
|
|
|
percdamp: float = 0.12
|
|
|
|
layer: nn.Linear
|
|
device: torch.device
|
|
|
|
group_size: int | dict
|
|
bits: list
|
|
bits_groups: list
|
|
scale_bits: int
|
|
hot_bits: int
|
|
|
|
columns: int
|
|
rows: int
|
|
hessian: torch.tensor
|
|
total_groups: int
|
|
|
|
perm: torch.Tensor | None
|
|
perm_cpu: torch.Tensor | None
|
|
invperm: torch.Tensor | None
|
|
|
|
# g_idx: torch.tensor = None
|
|
scale: torch.Tensor | None
|
|
qscale: torch.Tensor | None
|
|
qscale_max: torch.Tensor | None
|
|
qweight: torch.Tensor | None
|
|
qgroups: torch.Tensor | None
|
|
|
|
quant: torch.Tensor | None
|
|
weights: torch.Tensor | None
|
|
hessian: torch.Tensor | None
|
|
hessian_inv: torch.Tensor | None
|
|
num_samples: int = 0
|
|
num_batches: int = 0
|
|
|
|
quant_device: int = 0
|
|
hessian_device: int = 0
|
|
|
|
|
|
def __init__(self,
|
|
layer: nn.Linear):
|
|
|
|
self.layer = layer
|
|
self.device = layer.weight.device
|
|
|
|
self.rows = self.layer.weight.data.shape[1]
|
|
self.columns = self.layer.weight.data.shape[0]
|
|
|
|
# self.weights = self.layer.weight.data.T.clone().float().contiguous()
|
|
self.weights = None
|
|
self.hessian = None
|
|
self.num_samples = 0
|
|
self.num_batches = 0
|
|
|
|
self.perm = None
|
|
self.perm_cpu = None
|
|
self.inv_perm = None
|
|
|
|
self.scale = None
|
|
self.qscale = None
|
|
self.qscale_max = None
|
|
self.qweight = None
|
|
self.qgroups = None
|
|
|
|
self.quant_device = 0
|
|
self.hessian_device = 0
|
|
|
|
|
|
def drop_buffers(self):
|
|
|
|
self.perm = None
|
|
self.perm_cpu = None
|
|
self.invperm = None
|
|
# self.g_idx = None
|
|
self.scale = None
|
|
self.qscale = None
|
|
self.qscale_max = None
|
|
self.qweight = None
|
|
self.qgroups = None
|
|
self.quant = None
|
|
self.weights = None
|
|
self.hessian = None
|
|
self.hessian_inv = None
|
|
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
def configure(self,
|
|
group_size: dict,
|
|
bits = None,
|
|
bits_prop = None,
|
|
scale_bits: int = 4
|
|
):
|
|
|
|
self.group_size = group_size
|
|
self.scale_bits = scale_bits
|
|
self.bits = bits
|
|
|
|
assert isinstance(bits, list)
|
|
assert isinstance(bits_prop, list)
|
|
assert sum(bits_prop) == 1
|
|
|
|
groups = 0
|
|
remaining_rows = self.rows
|
|
self.bits_groups = []
|
|
for b, p in zip(self.bits, bits_prop):
|
|
assert p > 0
|
|
gsz = self.group_size[b]
|
|
g = math.ceil(min(self.rows * p, remaining_rows) / gsz)
|
|
groups += g
|
|
remaining_rows -= g * gsz
|
|
self.bits_groups.append(g)
|
|
|
|
assert remaining_rows <= 0
|
|
|
|
self.total_groups = groups
|
|
|
|
|
|
def add_batch(self, inputs):
|
|
|
|
with torch.inference_mode():
|
|
|
|
# dim = inputs.shape[-1]
|
|
# inputs = inputs.view((-1, dim)).float().T.to("cuda:0")
|
|
# ns = 1
|
|
#
|
|
# if self.hessian is None:
|
|
# self.hessian = torch.zeros((dim, dim), device = self.device, dtype = torch.float)
|
|
# else:
|
|
# self.hessian.mul_(self.num_samples / (self.num_samples + ns))
|
|
# self.num_samples += ns
|
|
# inputs.mul_(math.sqrt(2 / self.num_samples))
|
|
# self.hessian.addmm_(inputs, inputs.T)
|
|
|
|
# self.num_batches += 1
|
|
# num_samples = len(inputs)
|
|
# # inputs = torch.cat(inputs, dim = 0)
|
|
# inputs = inputs.view((-1, inputs.shape[-1])).float().T.to("cuda:0")
|
|
# inputs *= math.sqrt(2 / num_samples)
|
|
# self.hessian += inputs.matmul(inputs.T)
|
|
|
|
if self.hessian is None:
|
|
self.hessian = torch.zeros((self.rows, self.rows), device=self.device, dtype=torch.float)
|
|
|
|
self.num_batches += 1
|
|
inputs = inputs.view((-1, inputs.shape[-1])).float().to("cuda:0")
|
|
self.hessian.addmm_(inputs.T, inputs)
|
|
|
|
|
|
def prepare(self, no_h_inv = False):
|
|
|
|
with torch.inference_mode():
|
|
|
|
self.hessian /= self.num_batches
|
|
diagonal = torch.diag(self.hessian)
|
|
|
|
# Prepare weights
|
|
|
|
self.weights = self.layer.weight.data.cpu().T.clone().float().contiguous()
|
|
|
|
# Zero weights that have no impact. Disabling this since it feels a little drastic based on just the calibration
|
|
# data. It likely never triggers, anyway.
|
|
|
|
# dead = diagonal == 0.0
|
|
# self.hessian[dead, dead] = 1
|
|
# self.weights[dead, :] = 0
|
|
|
|
# Activation order
|
|
|
|
self.perm = torch.argsort(diagonal, descending = True)
|
|
self.perm_cpu = self.perm.cpu()
|
|
self.weights = self.weights[self.perm_cpu, :]
|
|
|
|
if self.hessian.numel() > 6e8:
|
|
hessian_cpu = self.hessian.cpu()
|
|
self.hessian = None
|
|
hessian = hessian_cpu[self.perm_cpu][:, self.perm_cpu]
|
|
hessian = hessian.to("cuda:0")
|
|
else:
|
|
hessian = self.hessian[self.perm][:, self.perm]
|
|
self.hessian = None
|
|
|
|
# In case numerical errors have caused some asymmetry in H, assume it's close to symmetrical and force it.
|
|
# (Doesn't seem to be needed)
|
|
|
|
# torch.cuda.empty_cache()
|
|
# hessian = (hessian + hessian.T) * 0.5
|
|
# torch.cuda.empty_cache()
|
|
|
|
# Damping
|
|
|
|
diagonal = torch.diag(hessian)
|
|
damp = torch.clamp(self.percdamp * torch.mean(diagonal), min = 1e-5)
|
|
|
|
# Inverse of H
|
|
|
|
attempts = 0
|
|
while not no_h_inv:
|
|
|
|
try:
|
|
|
|
d = torch.arange(self.rows, device = self.device)
|
|
hessian[d, d] += damp
|
|
|
|
current_device = self.hessian_device
|
|
max_devices = torch.cuda.device_count()
|
|
|
|
done = False
|
|
fail_device = False
|
|
while not done:
|
|
try:
|
|
hessian = hessian.to(torch.device(current_device))
|
|
|
|
hessian_inv = torch.linalg.cholesky(hessian)
|
|
hessian_inv = torch.cholesky_inverse(hessian_inv)
|
|
|
|
# The Cholesky inverse will sometimes fail to compute due to accumulated rounding errors when H
|
|
# is very large (e.g. 70B MLP down proj) and a lot of calibration data is used (e.g. 100 rows of
|
|
# 4096 tokens). This won't always throw an exception and sometimes just results in a NaN tensor.
|
|
|
|
if torch.any(torch.isnan(hessian_inv)): raise RuntimeError
|
|
|
|
# Test inversion
|
|
|
|
hessian_inv = torch.linalg.cholesky(hessian_inv, upper = True)
|
|
hessian_inv = hessian_inv.contiguous()
|
|
|
|
done = True
|
|
break
|
|
|
|
except torch.cuda.OutOfMemoryError as e:
|
|
current_device += 1
|
|
print(f" !! Out of memory (H), moving to device {current_device}")
|
|
if current_device == max_devices:
|
|
raise e
|
|
self.hessian_device = current_device
|
|
|
|
if done: break
|
|
|
|
except RuntimeError as runtime_error:
|
|
|
|
if "out of memory" in str(runtime_error):
|
|
raise runtime_error
|
|
|
|
# If inverting failed, assume there were non-positive eigenvalues, so apply more damping to shift
|
|
# the eigenvalues in a positive direction.
|
|
|
|
print(" !! Warning: Applied additional damping")
|
|
|
|
attempts += 1
|
|
if attempts == 10:
|
|
raise ValueError("Hessian is not invertible")
|
|
|
|
# Swap H to system RAM
|
|
|
|
self.hessian_inv = None if no_h_inv else hessian_inv.cpu()
|
|
self.hessian = None
|
|
|
|
|
|
def reuse_h(self, other):
|
|
|
|
with torch.inference_mode():
|
|
|
|
# Prepare weights
|
|
|
|
self.weights = self.layer.weight.data.cpu().T.clone().float().contiguous()
|
|
|
|
self.hessian_inv = other.hessian_inv
|
|
self.hessian = None
|
|
self.perm = other.perm
|
|
self.perm_cpu = other.perm_cpu
|
|
self.weights = self.weights[self.perm_cpu, :]
|
|
|
|
|
|
def quantize_rtn_inplace(self, keep_qweight = False, apply = False):
|
|
assert apply and keep_qweight
|
|
|
|
with torch.inference_mode():
|
|
|
|
current_device = self.quant_device
|
|
max_devices = torch.cuda.device_count()
|
|
|
|
weights_cpu = self.weights.cpu().contiguous()
|
|
|
|
done = False
|
|
while not done:
|
|
|
|
try:
|
|
|
|
weights = weights_cpu.to(torch.device(current_device))
|
|
self.qweight = torch.zeros_like(weights, dtype = torch.short, device = torch.device(current_device))
|
|
|
|
num_groups = 0
|
|
for bits_idx in range(len(self.bits)):
|
|
num_groups += self.bits_groups[bits_idx]
|
|
|
|
scale = []
|
|
qscale = []
|
|
qscale_max = torch.empty((num_groups,), dtype = torch.float, device = torch.device(current_device))
|
|
qgroups = []
|
|
|
|
group_idx = 0
|
|
group_idx_list = []
|
|
|
|
b = 0
|
|
for bits_idx, bits in enumerate(self.bits):
|
|
quantizer = AdaptiveQuantizer(bits = bits, scale_bits = self.scale_bits)
|
|
|
|
for group in range(self.bits_groups[bits_idx]):
|
|
a = b
|
|
b = min(a + self.group_size[bits], self.rows)
|
|
|
|
qgroups.append(bits)
|
|
qgroups.append(0)
|
|
|
|
quantizer.find_params(weights[a : b, :])
|
|
scale.append(quantizer.scale)
|
|
qscale.append(quantizer.qscale)
|
|
qscale_max[group_idx] = quantizer.qscale_max
|
|
|
|
ext_c.quantize_range_inplace(weights,
|
|
quantizer.scale,
|
|
self.qweight,
|
|
quantizer.qzero,
|
|
quantizer.maxq,
|
|
a,
|
|
b)
|
|
|
|
group_idx_list += [group_idx] * (b - a)
|
|
group_idx += 1
|
|
|
|
done = True
|
|
|
|
except torch.cuda.OutOfMemoryError as e:
|
|
current_device += 1
|
|
print(f" !! Out of memory (Q), moving to device {current_device}")
|
|
if current_device == max_devices:
|
|
raise e
|
|
self.quant_device = current_device
|
|
|
|
# Create g_idx to store inverse activation order
|
|
|
|
self.invperm = torch.argsort(self.perm).to("cuda:0")
|
|
|
|
# Store scales
|
|
|
|
self.scale = torch.stack(scale, dim = 0).to("cuda:0")
|
|
self.qscale = torch.stack(qscale, dim = 0).to("cuda:0")
|
|
self.qscale_max = qscale_max.to(torch.float16).to("cuda:0")
|
|
self.qgroups = torch.tensor(qgroups, dtype = torch.short).to("cuda:0")
|
|
|
|
# I love Python
|
|
|
|
scale = None
|
|
qscale = None
|
|
qscale_max = None
|
|
qgroups = None
|
|
group_idx_list = None
|
|
|
|
qc = weights.cpu()
|
|
qc = qc.to(torch.half)
|
|
invperm = self.invperm.cpu()
|
|
q = qc[invperm, :].T
|
|
q = q.reshape(weights.T.shape)
|
|
|
|
dev = weights.device
|
|
weights = None
|
|
torch.cuda.synchronize()
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
|
|
q = q.to(dev)
|
|
self.layer.weight.data = q
|
|
self.weights = q.T
|
|
|
|
|
|
def quantize(self, keep_qweight = False, apply = False):
|
|
|
|
with torch.inference_mode():
|
|
|
|
current_device = self.quant_device
|
|
max_devices = torch.cuda.device_count()
|
|
|
|
weights_cpu = self.weights.cpu()
|
|
|
|
done = False
|
|
while not done:
|
|
|
|
try:
|
|
|
|
hessian_inv_cuda = self.hessian_inv.to(torch.device(current_device))
|
|
|
|
# if apply:
|
|
# weights = self.weights
|
|
# self.layer.weight.data = torch.zeros((1, 1), dtype = torch.float32, device = weights.device)
|
|
# else:
|
|
# weights = self.weights.clone()
|
|
weights = weights_cpu.to(torch.device(current_device))
|
|
|
|
self.quant = torch.zeros_like(weights_cpu, device = torch.device(current_device))
|
|
|
|
if keep_qweight:
|
|
self.qweight = torch.zeros_like(weights, dtype = torch.short)
|
|
|
|
# Quantize groups
|
|
|
|
num_groups = 0
|
|
for bits_idx in range(len(self.bits)):
|
|
num_groups += self.bits_groups[bits_idx]
|
|
|
|
scale = []
|
|
qscale = []
|
|
qscale_max = torch.empty((num_groups,), dtype = torch.float, device = torch.device(current_device))
|
|
qgroups = []
|
|
|
|
error = weights.clone()
|
|
group_idx = 0
|
|
group_idx_list = []
|
|
|
|
b = 0
|
|
for bits_idx, bits in enumerate(self.bits):
|
|
quantizer = AdaptiveQuantizer(bits = bits, scale_bits = self.scale_bits)
|
|
|
|
for group in range(self.bits_groups[bits_idx]):
|
|
a = b
|
|
b = min(a + self.group_size[bits], self.rows)
|
|
|
|
qgroups.append(bits)
|
|
qgroups.append(0)
|
|
|
|
quantizer.find_params(weights[a : b, :])
|
|
scale.append(quantizer.scale)
|
|
qscale.append(quantizer.qscale)
|
|
qscale_max[group_idx] = quantizer.qscale_max
|
|
|
|
ext_c.quantize_range(self.quant,
|
|
quantizer.scale,
|
|
self.qweight if keep_qweight else none_tensor,
|
|
quantizer.qzero,
|
|
quantizer.maxq,
|
|
hessian_inv_cuda,
|
|
weights,
|
|
error,
|
|
a,
|
|
b)
|
|
|
|
group_idx_list += [group_idx] * (b - a)
|
|
group_idx += 1
|
|
|
|
done = True
|
|
|
|
except torch.cuda.OutOfMemoryError as e:
|
|
current_device += 1
|
|
print(f" !! Out of memory (Q), moving to device {current_device}")
|
|
if current_device == max_devices:
|
|
raise e
|
|
self.quant_device = current_device
|
|
|
|
# Create g_idx to store inverse activation order
|
|
|
|
# self.g_idx = torch.tensor(group_idx_list, dtype = torch.int32, device = self.device)
|
|
# self.g_idx = torch.tensor(group_idx_list, dtype = torch.int32)
|
|
|
|
self.quant = self.quant.to("cuda:0")
|
|
self.invperm = torch.argsort(self.perm).to("cuda:0")
|
|
# self.g_idx = self.g_idx[self.invperm]
|
|
|
|
# Store scales
|
|
|
|
self.scale = torch.stack(scale, dim = 0).to("cuda:0")
|
|
self.qscale = torch.stack(qscale, dim = 0).to("cuda:0")
|
|
self.qscale_max = qscale_max.to(torch.float16).to("cuda:0")
|
|
self.qgroups = torch.tensor(qgroups, dtype = torch.short).to("cuda:0")
|
|
|
|
# I love Python
|
|
|
|
weights = None
|
|
error = None
|
|
scale = None
|
|
qscale = None
|
|
qscale_max = None
|
|
qgroups = None
|
|
group_idx_list = None
|
|
|
|
# Apply
|
|
|
|
if apply:
|
|
self.apply_quant()
|
|
|
|
|
|
def quant_error(self):
|
|
|
|
with torch.inference_mode():
|
|
|
|
q = self.quant[self.invperm, :]
|
|
diff = torch.abs(q - self.layer.weight.data.T)
|
|
mat_error_1 = (diff > 0.01).sum().item() / diff.numel()
|
|
mat_error_5 = (diff > 0.05).sum().item() / diff.numel()
|
|
mat_error_10 = (diff > 0.10).sum().item() / diff.numel()
|
|
return mat_error_1, mat_error_5, mat_error_10
|
|
|
|
|
|
def apply_quant(self):
|
|
|
|
self.hessian = None
|
|
|
|
qc = self.quant.cpu()
|
|
invperm = self.invperm.cpu()
|
|
q = qc[invperm, :].T
|
|
q = q.reshape(self.quant.T.shape)
|
|
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
|
|
q = q.to(torch.device(self.quant_device))
|
|
self.layer.weight.data = q
|
|
|
|
|
|
def apply_temp(self):
|
|
|
|
q = self.quant[self.invperm, :].T
|
|
temp_layer = nn.Linear(self.layer.in_features, self.layer.out_features, False, device = "meta", dtype = torch.float16)
|
|
temp_layer.weight = nn.Parameter(q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data))
|
|
return temp_layer
|
|
|
|
|
|
def pack(self, key, qparams):
|
|
|
|
self.qgroups = self.qgroups.to("cuda:0")
|
|
# self.qscale_max = self.qscale_max.to("cude:0")
|
|
|
|
assert qparams.scale_bits in [4]
|
|
# assert self.columns % 32 == 0
|
|
|
|
output = {}
|
|
output[key + ".q_invperm"] = self.invperm.to(torch.int)
|
|
output[key + ".q_scale_max"] = self.qscale_max
|
|
output[key + ".q_groups"] = self.qgroups
|
|
if self.layer.bias is not None:
|
|
output[key + ".bias"] = self.layer.bias.data
|
|
|
|
columns = self.columns
|
|
rem_rows = self.rows
|
|
padding = -columns % 32
|
|
|
|
if padding != 0:
|
|
print(f" !! Note: Padding quantized tensor {key}")
|
|
qst = F.pad(self.qscale, (0, padding)).contiguous()
|
|
qwt = F.pad(self.qweight, (0, padding)).contiguous()
|
|
else:
|
|
qst = self.qscale
|
|
qwt = self.qweight
|
|
|
|
qst_packed = torch.zeros((qst.shape[0], qst.shape[1] * qparams.scale_bits // 32), dtype = torch.int32, device = self.device)
|
|
if qparams.scale_bits == 4: ext_c.pack_rows_4(qst, qst_packed)
|
|
output[key + ".q_scale"] = qst_packed
|
|
|
|
qwt_packed = []
|
|
|
|
i = 0
|
|
row = 0
|
|
out_row = 0
|
|
while i < self.qscale.shape[0]:
|
|
|
|
bits = self.qgroups[i * 2].item()
|
|
self.qgroups[i * 2 + 1] = out_row
|
|
i += 1
|
|
|
|
rows = min(self.group_size[bits], rem_rows)
|
|
wpqr = 32 / bits
|
|
qrows = rows / wpqr
|
|
assert i == self.qgroups.shape[-1] or qrows == int(qrows)
|
|
qrows = math.ceil(qrows)
|
|
|
|
g_qwt = qwt[row:row+rows, :].contiguous()
|
|
g_qwt_packed = torch.zeros((qrows, columns + padding), dtype = torch.int32, device = g_qwt.device)
|
|
|
|
if padding > 0: g_qwt[:, -padding:] = 2 ** (bits - 1)
|
|
|
|
ext_c.pack_columns(g_qwt, g_qwt_packed, bits)
|
|
qwt_packed.append(g_qwt_packed)
|
|
|
|
# print(row, rows, bits)
|
|
|
|
row += rows
|
|
out_row += qrows
|
|
rem_rows -= rows
|
|
|
|
|
|
qwt_packed = torch.cat(qwt_packed, dim = 0)
|
|
output[key + ".q_weight"] = qwt_packed
|
|
|
|
return output
|
|
|
|
|
|
|