Conversion: Add fallback quant method for layers with all-zero H, and tolerate matrices with rows/columns of zeros

This commit is contained in:
turboderp
2026-03-05 00:29:59 +01:00
parent c21108341a
commit 144d826dda
4 changed files with 221 additions and 58 deletions

View File

@@ -467,13 +467,19 @@ def main(args, job_state):
assert isinstance(linear.inner, LinearEXL3)
linear.inner.swap_cpu()
flags = "o" if quant_args_local["apply_out_scales"] else "."
proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
flags_local = "o" if quant_args_local["apply_out_scales"] else "."
flags_local += "f" if quant_args_local["q_fallback"] else "."
proxy_err_str_local = (
"(zero) " if quant_args_local["zeros"] else
"(big) " if proxy_err >= 9.9 else
f"{proxy_err:8.6f}" if proxy_err >= 0.0 else
"(OoM) "
)
print(
f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
f" bpw: {quant_args_local['K']:5.2f}"
f" proxy_err: {proxy_err_str}"
f" {flags}"
f" proxy_err: {proxy_err_str_local}"
f" {flags_local}"
f" g_sc: {quant_args_local['g_scale']:.6f}"
)
with progress_lock:
@@ -539,7 +545,13 @@ def main(args, job_state):
assert isinstance(linear.inner, LinearEXL3)
linear.inner.swap_cpu()
flags = "o" if quant_args["apply_out_scales"] else "."
proxy_err_str = f"{proxy_err:8.6f}" if proxy_err >= 0.0 else "(OoM) "
flags += "f" if quant_args["q_fallback"] else "."
proxy_err_str = (
"(zero) " if quant_args["zeros"] else
"(big) " if proxy_err >= 9.9 else
f"{proxy_err:8.6f}" if proxy_err >= 0.0 else
"(OoM) "
)
print(
f" -- Quantized: {linear.key:{config.stc.max_key_len() + 8}}"
f" bpw: {quant_args['K']:5.2f}"

View File

@@ -37,7 +37,7 @@ class Linear(Module):
out_dtype: torch.dtype | None = None,
allow_input_padding: bool = False,
post_scale: float = 1.0,
transposed_load: bool = True
transposed_load: bool = True,
):
super().__init__(config, key, qmap)
@@ -349,6 +349,9 @@ class Linear(Module):
key = self.key
)
if quant_args["q_fallback"]:
proxy_err = 0.0
if return_weight_q:
return proxy_err, weight_q
else:

View File

@@ -443,6 +443,111 @@ def ldlq(
return weight_q, encoded
def fallback_quant(
weight: torch.Tensor,
q_device: torch.Tensor,
quant_args: dict,
pb: ProgressBar | None = None
):
"""
Perform the same quantization as ldlq() but without an LDL decomposition — an uncalibrated
fallback for layers whose captured H is unusable (all-zero / too few activations).

:param weight:
Input weights, shape (k, n). If device is "cpu", result is collected on CPU as well, saving a bunch of
VRAM but adding a little PCIe overhead and many sync points
NOTE(review): the body asserts weight.device == quant_args["devices"][0], which looks
incompatible with a CPU-resident weight — confirm the CPU path is actually reachable
:param q_device:
Target device
NOTE(review): unused in the body (the working device is taken from weight.device) and the
annotation says torch.Tensor rather than a device type — confirm intent
:param quant_args:
dict:
- K: bitrate
- buf_size_k: buffer size for faux-LDLQ, along k
- devices: list of CUDA devices; devices[0] is the main device
:param pb:
Optional ProgressBar to update, k // 16 steps
:return:
tuple:
- quantized weight, shape (k, n)
- indices (unpacked), shape (k // 16, n // 16, 256), uint16_t
"""
devices = quant_args["devices"]
# Drain any pending work on all participating devices before switching streams
for device in devices:
torch.cuda.synchronize(device)
main_stream = get_quant_stream(devices[0])
with torch.cuda.stream(main_stream):
devices = quant_args["devices"]
device = weight.device
assert device == torch.device(devices[0])
buffer_device = weight.device
size_k, size_n = weight.shape  # Row-major
# Quantization operates on 16x16 tiles
assert size_k % 16 == 0
assert size_n % 128 == 0
tiles_k = size_k // 16
tiles_n = size_n // 16
buf_size_k = max(quant_args.get("buf_size_k", 128), 16)
assert buf_size_k % 16 == 0
# NOTE(review): buffering iterates along k (range(size_k, 0, -buf_size_k)); this check on
# size_n looks like it was meant to be size_k % buf_size_k == 0 — confirm
assert size_n % buf_size_k == 0
# Progress counter, one step per 16-row block processed
p_row = 0
# Work buffers
weight_q = torch.zeros((size_k, size_n), dtype = torch.float, device = buffer_device)
encoded = torch.zeros((tiles_k, tiles_n, 256), dtype = torch.short, device = buffer_device)
# Walk spans of buf_size_k rows from the bottom of the matrix upward (same order as ldlq)
for j in range(size_k, 0, -buf_size_k):
i = j - buf_size_k
# Current span is rows i:j
b_weight = weight[i:j].to(device)
# Reuse the output buffers directly when they already live on the working device,
# otherwise stage per-span scratch buffers on the device
b_weight_q = weight_q[i:j] if device == buffer_device else \
torch.zeros_like(weight_q[i:j], device = device)
b_encoded = encoded[i // 16 : j // 16] if device == buffer_device else \
torch.zeros_like(encoded[i // 16 : j // 16], device = device)
# Iterate over rows of blocks in current span
for bj in range(buf_size_k, 0, -16):
bi = bj - 16
# Input tiles for quantization
rows = b_weight[bi:bj]
# (16, size_n) -> (tiles_n, 256): one flattened 16x16 tile per row
tiles = rows.reshape(16, tiles_n, 16).permute(1, 0, 2).reshape(tiles_n, 256)
# Pre-permute to tensor core layout
tiles = tiles[:, tensor_core_perm(device)]
# Quantize
quant_w, quant_i = quantize_tiles_multigpu(tiles, quant_args)
# Undo permutation on reconstructed tiles, but keep indices in tensor core layout
quant_w = quant_w[:, tensor_core_perm_i(device)]
# Store result
quant_w = quant_w.reshape(tiles_n, 16, 16).permute(1, 0, 2).reshape(16, size_n)
b_weight_q[bi:bj] = quant_w
b_encoded[bi // 16 : bj // 16] = quant_i.unsqueeze(0)
# Update progress
if pb:
p_row += 1
pb.update(p_row)
# Collect output
if device != buffer_device:
weight_q[i:j] = b_weight_q.to(buffer_device)
encoded[i // 16 : j // 16] = b_encoded.to(buffer_device)
# Ensure all queued quantization work is complete before returning results
for device in devices:
torch.cuda.synchronize(device)
return weight_q, encoded
finalize_capture_H_mutex = threading.Lock()
def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
@@ -455,13 +560,19 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
H = H_data["H"]
if H_data["finalized"]:
return H, H_data["L"], H_data["su"], H_data["diag"]
return H_data["q_fallback"], H, H_data["L"], H_data["su"], H_data["diag"]
# Mean of samples summed up during forward pass
H /= H_data["count"]
# Switch to uncalibrated fallback if no input activations or diagonal is too small (few activations)
count = H_data["count"]
if count == 0:
q_fallback = True
else:
H /= count
diag_mean = torch.diag(H).mean()
q_fallback = diag_mean.item() < 1e-20
# Regularize diagonal
diag_mean = torch.diag(H).mean()
idx = torch.arange(H.shape[0])
H[idx, idx] += quant_args.get("sigma_reg", 0.025) * diag_mean
@@ -485,9 +596,13 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
blockwise_preapply_had_l_(H, had_k)
# Get block LDL decomposition of H, zero diagonal
L, H = block_ldl(H, 16, verbose)
dr = torch.arange(k)
L[dr, dr] = 0
if q_fallback:
L = None
else:
L, H = block_ldl(H, 16, verbose)
dr = torch.arange(k)
L[dr, dr] = 0
H_data["L"] = L
# H is no longer needed except to compute proxy error so move to CPU
@@ -496,7 +611,8 @@ def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool):
H_data["finalized"] = True
H_data["diag"] = diag
return H, L, su, diag
H_data["q_fallback"] = q_fallback
return q_fallback, H, L, su, diag
def pack_trellis(encoded: torch.Tensor, quant_args: dict) -> torch.Tensor:
@@ -652,7 +768,8 @@ def regularize(
verbose: bool,
H_diag: torch.Tensor | None,
pb: ProgressBar | None,
skip_g_scale: bool = False
skip_g_scale: bool = False,
q_fallback: bool = False
):
force_out_scales = quant_args["apply_out_scales"]
@@ -669,7 +786,7 @@ def regularize(
# the input to the linear layer is very irregular. After some testing, set the cutoff at 15% of the RMS sum
# on 2% of the channels
# TODO: More science
if H_diag is not None:
if not q_fallback and H_diag is not None:
diag = H_diag.sqrt()
diag, _ = torch.sort(diag, descending = True)
cutoff = diag.shape[0] // 50
@@ -685,10 +802,23 @@ def regularize(
else:
apply_out_scales = True if force_out_scales is None else force_out_scales
if q_fallback:
apply_out_scales = force_out_scales
# Apply output scales
out_channel_scales = block_rms(weight, dim = 0, keepdim = True)
mean = out_channel_scales.mean().item()
if mean > 1e-30:
out_channel_scales /= mean
quant_args["zeros"] = False
else:
quant_args["zeros"] = True
if force_out_scales is not None:
apply_out_scales = True
zero_out_scales = out_channel_scales.abs() < 1e-30
if apply_out_scales:
out_channel_scales = block_rms(weight, dim = 0, keepdim = True)
out_channel_scales /= out_channel_scales.mean()
out_channel_scales[zero_out_scales] = 0.1
sv = (sv * out_channel_scales + 1e-10).float()
if verbose:
out_channel_std = out_channel_scales.std().item()
@@ -698,11 +828,15 @@ def regularize(
# Output sign flips (and scales)
weight /= sv
# Force zero output channels to zero
sv[zero_out_scales] = 0.0
# Output hadamard transform
blockwise_preapply_had_r_(weight, had_n)
# Input sign flips and scales
in_channel_scales = block_rms(weight, dim = 1, keepdim = True)
in_channel_scales[in_channel_scales.abs() < 1e-30] = 0.1
su = (su * in_channel_scales / (-codebook_scale) + 1e-10).float() # mustn't be inplace
weight /= su
blockwise_preapply_had_l_(weight, had_k)
@@ -790,16 +924,22 @@ def quantize_exl3(
k, n = weight.shape
# Get H, LDL decomp. and input/output sign flips
H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose)
if H.is_cuda: H = H.to(device)
if L.is_cuda: L = L.to(device)
if su.is_cuda: su = su.to(device)
if H_diag.is_cuda: H_diag = H_diag.to(device)
q_fallback, H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose)
if H.is_cuda:
H = H.to(device)
if L is not None and L.is_cuda:
L = L.to(device)
if su.is_cuda:
su = su.to(device)
if H_diag.is_cuda:
H_diag = H_diag.to(device)
sv = (torch.randn(n, device = device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0)
# Move stored L to CPU (if not already), move working L to device
H_data["L"] = H_data["L"].cpu()
L = L.to(device)
if H_data["L"] is not None:
H_data["L"] = H_data["L"].cpu()
if L is not None:
L = L.to(device)
if swap_to_device is not None:
weight = weight.to(swap_to_device)
@@ -820,7 +960,8 @@ def quantize_exl3(
quant_args,
verbose,
H_diag,
pb
pb,
q_fallback = q_fallback
)
if save_reg:
@@ -840,40 +981,46 @@ def quantize_exl3(
weight_r = weight_r.cpu()
# Quantize
weight_q, encoded_q = ldlq(weight_r, L, quant_args, pb)
del L
if not q_fallback:
weight_q, encoded_q = ldlq(weight_r, L, quant_args, pb) #zxc
del L
else:
weight_q, encoded_q = fallback_quant(weight_r, device, quant_args, pb) # zxc
pb.update(tiles_k)
# Metrics
try:
def block_trace(A, B, block_size = 1024):
total = 0.0
for j_start in range(0, B.shape[1], block_size):
j_end = min(j_start + block_size, B.shape[1])
B_block = B[:, j_start:j_end]
A_j_block = A[j_start:j_end, :]
partial = torch.einsum("ik,ij,jk->", A, B_block, A_j_block)
total += partial.item()
return total
E = weight_r - weight_q # may run on CPU
W = weight_r
Hd = H.to(device)
weight_r = None
E = E.to(device)
num = block_trace(E, Hd)
E = None
W = W.to(device)
den = block_trace(W, Hd)
W = None
Hd = None
proxy_err = num / max(den, 1e-8)
except torch.OutOfMemoryError:
weight_r = None
E = None
W = None
Hd = None
proxy_err = -1.0
if not q_fallback:
try:
def block_trace(A, B, block_size = 1024):
total = 0.0
for j_start in range(0, B.shape[1], block_size):
j_end = min(j_start + block_size, B.shape[1])
B_block = B[:, j_start:j_end]
A_j_block = A[j_start:j_end, :]
partial = torch.einsum("ik,ij,jk->", A, B_block, A_j_block)
total += partial.item()
return total
E = weight_r - weight_q # may run on CPU
W = weight_r
Hd = H.to(device)
weight_r = None
E = E.to(device)
num = block_trace(E, Hd)
E = None
W = W.to(device)
den = block_trace(W, Hd)
W = None
Hd = None
proxy_err = num / max(den, 1e-8)
except torch.OutOfMemoryError:
weight_r = None
E = None
W = None
Hd = None
proxy_err = -1.0
else:
proxy_err = 0.0
# free_mem()
@@ -918,6 +1065,7 @@ def quantize_exl3(
quant_args.update({
"apply_out_scales": apply_out_scales,
"g_scale": g_scale,
"q_fallback": q_fallback,
})
return weight_q, proxy_err, out_tensors

View File

@@ -170,13 +170,13 @@ def save_tensor_image(
t = t.detach().to("cpu", copy = True).float()
k = 3
mu, sigma = t.mean(), t.std()
lo, hi = mu - k * sigma, mu + k * sigma
_, sigma = t.mean(), t.std()
lo, hi = -k * sigma, k * sigma
t.clamp_(lo, hi)
t -= lo
t /= (hi - lo + 1e-8)
rgba = cm.get_cmap("gnuplot2")(t.numpy())
rgba = cm.get_cmap("berlin")(t.numpy())
rgb8 = (rgba[..., :3] * 255).astype("uint8")
im = Image.fromarray(rgb8)
im.save(path)