mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Conversion: Add fallback quant method for layers with all-zero H, and tolerate matrices with rows/columns of zeros
This commit is contained in:
@@ -467,13 +467,19 @@ def main(args, job_state):
|
||||
assert isinstance(linear.inner, LinearEXL3)
|
||||
linear.inner.swap_cpu()
|
||||
|
||||
flags = "o" if quant_args_local["apply_out_scales"] else "."
|
||||
proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
|
||||
flags_local = "o" if quant_args_local["apply_out_scales"] else "."
|
||||
flags_local += "f" if quant_args_local["q_fallback"] else "."
|
||||
proxy_err_str_local = (
|
||||
"(zero) " if quant_args_local["zeros"] else
|
||||
"(big) " if proxy_err >= 9.9 else
|
||||
f"{proxy_err:8.6f}" if proxy_err >= 0.0 else
|
||||
"(OoM) "
|
||||
)
|
||||
print(
|
||||
f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
|
||||
f" bpw: {quant_args_local['K']:5.2f}"
|
||||
f" proxy_err: {proxy_err_str}"
|
||||
f" {flags}"
|
||||
f" proxy_err: {proxy_err_str_local}"
|
||||
f" {flags_local}"
|
||||
f" g_sc: {quant_args_local['g_scale']:.6f}"
|
||||
)
|
||||
with progress_lock:
|
||||
@@ -539,7 +545,13 @@ def main(args, job_state):
|
||||
assert isinstance(linear.inner, LinearEXL3)
|
||||
linear.inner.swap_cpu()
|
||||
flags = "o" if quant_args["apply_out_scales"] else "."
|
||||
proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
|
||||
flags += "f" if quant_args["q_fallback"] else "."
|
||||
proxy_err_str = (
|
||||
"(zero) " if quant_args["zeros"] else
|
||||
"(big) " if proxy_err >= 9.9 else
|
||||
f"{proxy_err:8.6f}" if proxy_err >= 0.0 else
|
||||
"(OoM) "
|
||||
)
|
||||
print(
|
||||
f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
|
||||
f" bpw: {quant_args['K']:5.2f}"
|
||||
|
||||
@@ -37,7 +37,7 @@ class Linear(Module):
|
||||
out_dtype: torch.dtype | None = None,
|
||||
allow_input_padding: bool = False,
|
||||
post_scale: float = 1.0,
|
||||
transposed_load: bool = True
|
||||
transposed_load: bool = True,
|
||||
):
|
||||
super().__init__(config, key, qmap)
|
||||
|
||||
@@ -349,6 +349,9 @@ class Linear(Module):
|
||||
key = self.key
|
||||
)
|
||||
|
||||
if quant_args["q_fallback"]:
|
||||
proxy_err = 0.0
|
||||
|
||||
if return_weight_q:
|
||||
return proxy_err, weight_q
|
||||
else:
|
||||
|
||||
@@ -443,6 +443,111 @@ def ldlq(
|
||||
return weight_q, encoded
|
||||
|
||||
|
||||
def fallback_quant(
|
||||
weight: torch.Tensor,
|
||||
q_device: torch.Tensor,
|
||||
quant_args: dict,
|
||||
pb: ProgressBar | None = None
|
||||
):
|
||||
"""
|
||||
Perform the same quantization as ldlq() but without an LDL decomposition
|
||||
|
||||
:param weight:
|
||||
Input weights, shape (k, n). If device is "cpu", result is collected on CPU as well, saving a bunch of
|
||||
VRAM but adding a little PCIe overhead and many sync points
|
||||
|
||||
:param q_device:
|
||||
Target device
|
||||
|
||||
:param quant_args:
|
||||
dict:
|
||||
- K: bitrate
|
||||
- buf_size_k: buffer size for faux-LDLQ, along k
|
||||
|
||||
:param pb:
|
||||
Optional ProgressPar to update, k // 16 steps
|
||||
|
||||
:return:
|
||||
tuple:
|
||||
- quantized weight, shape (k, n)
|
||||
- indices (unpacked), shape (k // 16, n // 16, 256), uint16_t
|
||||
"""
|
||||
|
||||
devices = quant_args["devices"]
|
||||
for device in devices:
|
||||
torch.cuda.synchronize(device)
|
||||
main_stream = get_quant_stream(devices[0])
|
||||
with torch.cuda.stream(main_stream):
|
||||
|
||||
devices = quant_args["devices"]
|
||||
device = weight.device
|
||||
assert device == torch.device(devices[0])
|
||||
|
||||
buffer_device = weight.device
|
||||
size_k, size_n = weight.shape # Row-major
|
||||
assert size_k % 16 == 0
|
||||
assert size_n % 128 == 0
|
||||
tiles_k = size_k // 16
|
||||
tiles_n = size_n // 16
|
||||
|
||||
buf_size_k = max(quant_args.get("buf_size_k", 128), 16)
|
||||
assert buf_size_k % 16 == 0
|
||||
assert size_n % buf_size_k == 0
|
||||
|
||||
p_row = 0
|
||||
|
||||
# Work buffers
|
||||
weight_q = torch.zeros((size_k, size_n), dtype = torch.float, device = buffer_device)
|
||||
encoded = torch.zeros((tiles_k, tiles_n, 256), dtype = torch.short, device = buffer_device)
|
||||
|
||||
for j in range(size_k, 0, -buf_size_k):
|
||||
i = j - buf_size_k
|
||||
|
||||
# Current span is rows i:j
|
||||
b_weight = weight[i:j].to(device)
|
||||
b_weight_q = weight_q[i:j] if device == buffer_device else \
|
||||
torch.zeros_like(weight_q[i:j], device = device)
|
||||
b_encoded = encoded[i // 16 : j // 16] if device == buffer_device else \
|
||||
torch.zeros_like(encoded[i // 16 : j // 16], device = device)
|
||||
|
||||
# Iterate over rows of blocks in current span
|
||||
for bj in range(buf_size_k, 0, -16):
|
||||
bi = bj - 16
|
||||
|
||||
# Input tiles for quantization
|
||||
rows = b_weight[bi:bj]
|
||||
tiles = rows.reshape(16, tiles_n, 16).permute(1, 0, 2).reshape(tiles_n, 256)
|
||||
|
||||
# Pre-permute to tensor core layout
|
||||
tiles = tiles[:, tensor_core_perm(device)]
|
||||
|
||||
# Quantize
|
||||
quant_w, quant_i = quantize_tiles_multigpu(tiles, quant_args)
|
||||
|
||||
# Undo permutation on reconstructed tiles, but keep indices in tensor core layout
|
||||
quant_w = quant_w[:, tensor_core_perm_i(device)]
|
||||
|
||||
# Store result
|
||||
quant_w = quant_w.reshape(tiles_n, 16, 16).permute(1, 0, 2).reshape(16, size_n)
|
||||
b_weight_q[bi:bj] = quant_w
|
||||
b_encoded[bi // 16 : bj // 16] = quant_i.unsqueeze(0)
|
||||
|
||||
# Update progress
|
||||
if pb:
|
||||
p_row += 1
|
||||
pb.update(p_row)
|
||||
|
||||
# Collect output
|
||||
if device != buffer_device:
|
||||
weight_q[i:j] = b_weight_q.to(buffer_device)
|
||||
encoded[i // 16 : j // 16] = b_encoded.to(buffer_device)
|
||||
|
||||
for device in devices:
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
return weight_q, encoded
|
||||
|
||||
|
||||
finalize_capture_H_mutex = threading.Lock()
|
||||
|
||||
def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
|
||||
@@ -455,13 +560,19 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
|
||||
|
||||
H = H_data["H"]
|
||||
if H_data["finalized"]:
|
||||
return H, H_data["L"], H_data["su"], H_data["diag"]
|
||||
return H_data["q_fallback"], H, H_data["L"], H_data["su"], H_data["diag"]
|
||||
|
||||
# Mean of samples summed up during forward pass
|
||||
H /= H_data["count"]
|
||||
# Switch to uncalibrated fallback if no input activations or diagonal is too small (few activations)
|
||||
count = H_data["count"]
|
||||
if count == 0:
|
||||
q_fallback = True
|
||||
else:
|
||||
H /= count
|
||||
diag_mean = torch.diag(H).mean()
|
||||
q_fallback = diag_mean.item() < 1e-20
|
||||
|
||||
# Regularize diagonal
|
||||
diag_mean = torch.diag(H).mean()
|
||||
idx = torch.arange(H.shape[0])
|
||||
H[idx, idx] += quant_args.get("sigma_reg", 0.025) * diag_mean
|
||||
|
||||
@@ -485,9 +596,13 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
|
||||
blockwise_preapply_had_l_(H, had_k)
|
||||
|
||||
# Get block LDL decomposition of H, zero diagonal
|
||||
L, H = block_ldl(H, 16, verbose)
|
||||
dr = torch.arange(k)
|
||||
L[dr, dr] = 0
|
||||
if q_fallback:
|
||||
L = None
|
||||
else:
|
||||
L, H = block_ldl(H, 16, verbose)
|
||||
dr = torch.arange(k)
|
||||
L[dr, dr] = 0
|
||||
|
||||
H_data["L"] = L
|
||||
|
||||
# H is no longer needed except to compute proxy error so move to CPU
|
||||
@@ -496,7 +611,8 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
|
||||
|
||||
H_data["finalized"] = True
|
||||
H_data["diag"] = diag
|
||||
return H, L, su, diag
|
||||
H_data["q_fallback"] = q_fallback
|
||||
return q_fallback, H, L, su, diag
|
||||
|
||||
|
||||
def pack_trellis(encoded: torch.Tensor, quant_args: dict) -> torch.Tensor:
|
||||
@@ -652,7 +768,8 @@ def regularize(
|
||||
verbose: bool,
|
||||
H_diag: torch.Tensor | None,
|
||||
pb: ProgressBar | None,
|
||||
skip_g_scale: bool = False
|
||||
skip_g_scale: bool = False,
|
||||
q_fallback: bool = False
|
||||
):
|
||||
force_out_scales = quant_args["apply_out_scales"]
|
||||
|
||||
@@ -669,7 +786,7 @@ def regularize(
|
||||
# the input to the linear layer is very irregular. After some testing, set the cutoff at 15% of the RMS sum
|
||||
# on 2% of the channels
|
||||
# TODO: More science
|
||||
if H_diag is not None:
|
||||
if not q_fallback and H_diag is not None:
|
||||
diag = H_diag.sqrt()
|
||||
diag, _ = torch.sort(diag, descending = True)
|
||||
cutoff = diag.shape[0] // 50
|
||||
@@ -685,10 +802,23 @@ def regularize(
|
||||
else:
|
||||
apply_out_scales = True if force_out_scales is None else force_out_scales
|
||||
|
||||
if q_fallback:
|
||||
apply_out_scales = force_out_scales
|
||||
|
||||
# Apply output scales
|
||||
out_channel_scales = block_rms(weight, dim = 0, keepdim = True)
|
||||
mean = out_channel_scales.mean().item()
|
||||
if mean > 1e-30:
|
||||
out_channel_scales /= mean
|
||||
quant_args["zeros"] = False
|
||||
else:
|
||||
quant_args["zeros"] = True
|
||||
if force_out_scales is not None:
|
||||
apply_out_scales = True
|
||||
zero_out_scales = out_channel_scales.abs() < 1e-30
|
||||
|
||||
if apply_out_scales:
|
||||
out_channel_scales = block_rms(weight, dim = 0, keepdim = True)
|
||||
out_channel_scales /= out_channel_scales.mean()
|
||||
out_channel_scales[zero_out_scales] = 0.1
|
||||
sv = (sv * out_channel_scales + 1e-10).float()
|
||||
if verbose:
|
||||
out_channel_std = out_channel_scales.std().item()
|
||||
@@ -698,11 +828,15 @@ def regularize(
|
||||
# Output sign flips (and scales)
|
||||
weight /= sv
|
||||
|
||||
# Force zero output channels to zero
|
||||
sv[zero_out_scales] = 0.0
|
||||
|
||||
# Output hadamard transform
|
||||
blockwise_preapply_had_r_(weight, had_n)
|
||||
|
||||
# Input sign flips and scales
|
||||
in_channel_scales = block_rms(weight, dim = 1, keepdim = True)
|
||||
in_channel_scales[in_channel_scales.abs() < 1e-30] = 0.1
|
||||
su = (su * in_channel_scales / (-codebook_scale) + 1e-10).float() # mustn't be inplace
|
||||
weight /= su
|
||||
blockwise_preapply_had_l_(weight, had_k)
|
||||
@@ -790,16 +924,22 @@ def quantize_exl3(
|
||||
k, n = weight.shape
|
||||
|
||||
# Get H, LDL decomp. and input/output sign flips
|
||||
H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose)
|
||||
if H.is_cuda: H = H.to(device)
|
||||
if L.is_cuda: L = L.to(device)
|
||||
if su.is_cuda: su = su.to(device)
|
||||
if H_diag.is_cuda: H_diag = H_diag.to(device)
|
||||
q_fallback, H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose)
|
||||
if H.is_cuda:
|
||||
H = H.to(device)
|
||||
if L is not None and L.is_cuda:
|
||||
L = L.to(device)
|
||||
if su.is_cuda:
|
||||
su = su.to(device)
|
||||
if H_diag.is_cuda:
|
||||
H_diag = H_diag.to(device)
|
||||
sv = (torch.randn(n, device = device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0)
|
||||
|
||||
# Move stored L to CPU (if not already), move working L to device
|
||||
H_data["L"] = H_data["L"].cpu()
|
||||
L = L.to(device)
|
||||
if H_data["L"] is not None:
|
||||
H_data["L"] = H_data["L"].cpu()
|
||||
if L is not None:
|
||||
L = L.to(device)
|
||||
|
||||
if swap_to_device is not None:
|
||||
weight = weight.to(swap_to_device)
|
||||
@@ -820,7 +960,8 @@ def quantize_exl3(
|
||||
quant_args,
|
||||
verbose,
|
||||
H_diag,
|
||||
pb
|
||||
pb,
|
||||
q_fallback = q_fallback
|
||||
)
|
||||
|
||||
if save_reg:
|
||||
@@ -840,40 +981,46 @@ def quantize_exl3(
|
||||
weight_r = weight_r.cpu()
|
||||
|
||||
# Quantize
|
||||
weight_q, encoded_q = ldlq(weight_r, L, quant_args, pb)
|
||||
del L
|
||||
if not q_fallback:
|
||||
weight_q, encoded_q = ldlq(weight_r, L, quant_args, pb) #zxc
|
||||
del L
|
||||
else:
|
||||
weight_q, encoded_q = fallback_quant(weight_r, device, quant_args, pb) # zxc
|
||||
|
||||
pb.update(tiles_k)
|
||||
|
||||
# Metrics
|
||||
try:
|
||||
def block_trace(A, B, block_size = 1024):
|
||||
total = 0.0
|
||||
for j_start in range(0, B.shape[1], block_size):
|
||||
j_end = min(j_start + block_size, B.shape[1])
|
||||
B_block = B[:, j_start:j_end]
|
||||
A_j_block = A[j_start:j_end, :]
|
||||
partial = torch.einsum("ik,ij,jk->", A, B_block, A_j_block)
|
||||
total += partial.item()
|
||||
return total
|
||||
E = weight_r - weight_q # may run on CPU
|
||||
W = weight_r
|
||||
Hd = H.to(device)
|
||||
weight_r = None
|
||||
E = E.to(device)
|
||||
num = block_trace(E, Hd)
|
||||
E = None
|
||||
W = W.to(device)
|
||||
den = block_trace(W, Hd)
|
||||
W = None
|
||||
Hd = None
|
||||
proxy_err = num / max(den, 1e-8)
|
||||
except torch.OutOfMemoryError:
|
||||
weight_r = None
|
||||
E = None
|
||||
W = None
|
||||
Hd = None
|
||||
proxy_err = -1.0
|
||||
if not q_fallback:
|
||||
try:
|
||||
def block_trace(A, B, block_size = 1024):
|
||||
total = 0.0
|
||||
for j_start in range(0, B.shape[1], block_size):
|
||||
j_end = min(j_start + block_size, B.shape[1])
|
||||
B_block = B[:, j_start:j_end]
|
||||
A_j_block = A[j_start:j_end, :]
|
||||
partial = torch.einsum("ik,ij,jk->", A, B_block, A_j_block)
|
||||
total += partial.item()
|
||||
return total
|
||||
E = weight_r - weight_q # may run on CPU
|
||||
W = weight_r
|
||||
Hd = H.to(device)
|
||||
weight_r = None
|
||||
E = E.to(device)
|
||||
num = block_trace(E, Hd)
|
||||
E = None
|
||||
W = W.to(device)
|
||||
den = block_trace(W, Hd)
|
||||
W = None
|
||||
Hd = None
|
||||
proxy_err = num / max(den, 1e-8)
|
||||
except torch.OutOfMemoryError:
|
||||
weight_r = None
|
||||
E = None
|
||||
W = None
|
||||
Hd = None
|
||||
proxy_err = -1.0
|
||||
else:
|
||||
proxy_err = 0.0
|
||||
|
||||
# free_mem()
|
||||
|
||||
@@ -918,6 +1065,7 @@ def quantize_exl3(
|
||||
quant_args.update({
|
||||
"apply_out_scales": apply_out_scales,
|
||||
"g_scale": g_scale,
|
||||
"q_fallback": q_fallback,
|
||||
})
|
||||
|
||||
return weight_q, proxy_err, out_tensors
|
||||
@@ -170,13 +170,13 @@ def save_tensor_image(
|
||||
t = t.detach().to("cpu", copy = True).float()
|
||||
|
||||
k = 3
|
||||
mu, sigma = t.mean(), t.std()
|
||||
lo, hi = mu - k * sigma, mu + k * sigma
|
||||
_, sigma = t.mean(), t.std()
|
||||
lo, hi = -k * sigma, k * sigma
|
||||
t.clamp_(lo, hi)
|
||||
t -= lo
|
||||
t /= (hi - lo + 1e-8)
|
||||
|
||||
rgba = cm.get_cmap("gnuplot2")(t.numpy())
|
||||
rgba = cm.get_cmap("berlin")(t.numpy())
|
||||
rgb8 = (rgba[..., :3] * 255).astype("uint8")
|
||||
im = Image.fromarray(rgb8)
|
||||
im.save(path)
|
||||
|
||||
Reference in New Issue
Block a user