mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Add prequant_test script
This commit is contained in:
349
eval/prequant_test.py
Normal file
349
eval/prequant_test.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import argparse
|
||||
from exllamav3.util.file import disk_lru_cache, disk_lru_cache_clear
|
||||
from exllamav3.util.progress import ProgressBar
|
||||
from exllamav3.util.memory import free_mem
|
||||
from exllamav3 import Config, Model, Cache, Tokenizer, model_init
|
||||
from exllamav3.ext import exllamav3_ext as ext
|
||||
from datasets import load_dataset
|
||||
from exllamav3.modules import Linear
|
||||
from exllamav3.modules.quant import LinearFP16, LinearEXL3
|
||||
from exllamav3.modules.quant.exl3_lib.quantize import regularize
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import numpy as np
|
||||
import termplotlib as tpl
|
||||
|
||||
# ANSI escape codes used to colorize terminal output
ESC = "\u001b"
col_default = "\u001b[0m"    # reset all attributes
col_yellow = "\u001b[33;1m"  # bright yellow
col_blue = "\u001b[34;1m"    # bright blue
col_green = "\u001b[32;1m"   # bright green
col_red = "\u001b[31;1m"     # bright red
col_purple = "\u001b[35;1m"  # bright magenta
col_cyan = "\u001b[36;1m"    # bright cyan
col_white = "\u001b[37;1m"   # bright white

# Eighth-height block glyphs, indexed 0..8 by fill level, used to draw bar charts
block_chars = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"]
|
||||
|
||||
@disk_lru_cache("get_dataset_text")
def get_dataset_text(spec: dict):
    """Fetch the raw evaluation text for the dataset described by `spec`.

    Only the wikitext-2 raw test split is supported. All rows are joined
    with blank lines into one string; the result is cached on disk by the
    decorator, keyed on `spec`.
    """
    assert spec["dataset"] == "wiki2", "Only wiki2 implemented atm"
    rows = load_dataset("wikitext", "wikitext-2-raw-v1", split = "test")["text"]
    return "\n\n".join(rows)
|
||||
|
||||
|
||||
def get_test_tokens(tokenizer, rows, bos, eval_len = 2048, eval_stride = 512):
    """Tokenize the eval corpus and slice it into up to `rows` sequences.

    Rows are `eval_len` tokens long and start at multiples of `eval_stride`,
    so consecutive rows overlap. When `bos` is given it is prepended to each
    row and the row's final token is dropped so the length stays `eval_len`.
    Returns a (rows, eval_len) tensor of token ids.
    """
    with ProgressBar("Tokenizing", rows) as pb:
        spec = { "dataset": "wiki2" }
        all_ids = tokenizer.encode(get_dataset_text(spec))
        total = all_ids.shape[-1]
        out_rows = []
        for start in range(0, total - eval_len, eval_stride):
            end = start + eval_len
            if bos is None:
                row = all_ids[:, start:end]
            else:
                # Keep the row length fixed: drop the last token to make room for BOS
                row = torch.cat((bos, all_ids[:, start:end - 1]), dim = -1)
            out_rows.append(row)
            pb.update(len(out_rows))
            if len(out_rows) >= rows:
                break
    return torch.cat(out_rows, dim = 0)
|
||||
|
||||
|
||||
|
||||
@torch.compile(fullgraph = True, mode = "reduce-overhead")
|
||||
def count_threshold(x: torch.Tensor, abs_threshold: float) -> torch.Tensor:
|
||||
return (x.abs() > abs_threshold).sum(dtype = torch.int64)
|
||||
|
||||
|
||||
def bchart(bins, min_value, max_value, cc, height = 10):
    """Render a bar chart of `bins` as a list of `height` terminal lines.

    Each column's height is proportional to its count relative to the
    largest bin; partial rows use eighth-block glyphs. `cc` is the ANSI
    color prefix for the bars. `min_value`/`max_value` are accepted for
    symmetry with callers but do not affect rendering. Lines are returned
    top row first.
    """
    peak = bins.max()
    rows = []
    for level in range(height):
        row = cc
        for count in bins:
            if peak == 0:
                glyph = block_chars[0]
            else:
                # Fraction of this character cell that the bar fills, clamped to [0, 1]
                fill = (count / peak) * height - level
                fill = min(max(fill, 0), 1)
                glyph = block_chars[int(fill * 8)]
            row += glyph
        rows.append(row + col_default)
    rows.reverse()
    return rows
|
||||
|
||||
|
||||
def histogram(args, tensor):
    """Print a three-panel terminal histogram of `tensor`.

    The distribution is split into a middle panel covering +/- 3 standard
    deviations (half of args.histogram_bins) and two tail panels (a quarter
    of the bins each) covering the values below/above that range. Each panel
    is drawn with bchart() and annotated with its value bounds, element count
    and percentage of the total.
    """

    nbins = args.histogram_bins
    stddev = tensor.std()
    min_value = tensor.amin().item()
    max_value = tensor.amax().item()

    # Middle histogram: +/- 3 sigma around zero
    # NOTE(review): m_min_value/m_max_value are 0-dim tensors here; ext.histogram
    # and the :.2e format below presumably accept them -- confirm
    m_min_value = -stddev * 3
    m_max_value = stddev * 3

    m_bins = torch.empty((nbins // 2,), dtype = torch.long, device = tensor.device)
    ext.histogram(tensor, m_bins, m_min_value, m_max_value, True)
    m_count = m_bins.sum().item()
    m_bins = m_bins.cpu().numpy()

    # Low histogram: everything from the global minimum up to -3 sigma
    l_min_value = min_value
    l_max_value = m_min_value
    l_bins = torch.empty((nbins // 4,), dtype = torch.long, device = tensor.device)
    ext.histogram(tensor, l_bins, l_min_value, l_max_value, True)
    l_count = l_bins.sum().item()
    l_bins = l_bins.cpu().numpy()

    # High histogram: from just above +3 sigma up to the global maximum
    # (small offset avoids double-counting the middle panel's upper bound)
    h_min_value = m_max_value + 0.001
    h_max_value = max_value
    h_bins = torch.empty((nbins // 4,), dtype = torch.long, device = tensor.device)
    ext.histogram(tensor, h_bins, h_min_value, h_max_value, True)
    h_count = h_bins.sum().item()
    h_bins = h_bins.cpu().numpy()

    # Per-panel share of all binned elements
    total = l_count + m_count + h_count
    l_pct = l_count / total * 100.0
    m_pct = m_count / total * 100.0
    h_pct = h_count / total * 100.0

    # Render the three panels side by side: tails in yellow, middle in cyan
    bc_l = bchart(l_bins, l_min_value, l_max_value, col_yellow)
    bc_m = bchart(m_bins, m_min_value, m_max_value, col_cyan)
    bc_h = bchart(h_bins, h_min_value, h_max_value, col_yellow)
    bc = [f"{l} {m} {h}" for l, m, h in zip(bc_l, bc_m, bc_h)]

    # Horizontal rule under the chart, with tick marks between panels
    bc.append(col_default + ("─" * (nbins // 4)) + "┬" + ("─" * (nbins // 2)) + "┬" + ("─" * (nbins // 4 )) + col_default)

    # Axis annotation rows: bounds, element counts, percentages per panel
    for vl, vml, vmh, vh in zip(
        [f"{l_min_value:.2e}", f"{l_count:,} elem", f"{l_pct:.5f} %"],
        [f"{m_min_value:.2e}", f"{m_count:,} elem", f"{m_pct:.5f} %"],
        [f"{m_max_value:.2e}", "", ""],
        [f"{h_max_value:.2e}", f"{h_count:,} elem", f"{h_pct:.5f} %"],
    ):
        ax = ""
        # Left tail label, left-padded middle labels, right-aligned high label
        ax += col_yellow + vl + (" " * ((nbins // 4) - len(vl)))
        ax += " " + col_cyan + vml + (" " * ((nbins // 2) - len(vml) - len(vmh))) + vmh + " "
        ax += (" " * ((nbins // 4) - len(vh))) + col_yellow + vh + col_default
        bc.append(ax)

    print("\n".join(bc))
    print()
|
||||
|
||||
|
||||
def stats(args, tensor):
    """Print a summary of `tensor`: inf/NaN counts (only when present),
    min/max values, standard deviation, and six-sigma outlier statistics.
    """

    # Count non-finite values in one pass via the extension kernel
    counters = torch.zeros(2, dtype = torch.long, device = tensor.device)
    ext.count_inf_nan(tensor, counters)
    num_inf = counters[0].item()
    num_nan = counters[1].item()
    if num_inf or num_nan:
        print(f"{col_red}inf values : {col_white}{num_inf:,}{col_default}")
        print(f"{col_red}NaN values : {col_white}{num_nan:,}{col_default}")
        print(f"{col_red}Total numel : {col_white}{tensor.numel():,}{col_default}")
        print()

    lo = tensor.amin().item()
    hi = tensor.amax().item()
    print(f"Min value : {col_white}{lo:18.8f}{col_default}")
    print(f"Max value : {col_white}{hi:18.8f}{col_default}")

    # Population (biased) standard deviation
    std_dev = tensor.std(unbiased = False)
    print(f"Std. deviation : {col_white}{std_dev:18.8f}{col_default}")

    # Elements whose magnitude exceeds six standard deviations
    outliers = count_threshold(tensor, 6 * std_dev)
    outlier_frac = outliers / tensor.numel()
    if outliers > 0:
        print(f"Six-sigma exceedance : {col_white}{outlier_frac:18.8f}{col_default} ({col_white}{outliers:,} elem{col_default})")
    else:
        print(f"Six-sigma exceedance : {col_white}None{col_default}")

    print()
|
||||
|
||||
|
||||
def inspect_state(args, state):
    """Print stats and a histogram for a hidden-state tensor, after dropping
    the first `args.skip_tokens` positions and moving it to the target device."""
    trimmed = state[:, args.skip_tokens:, :].to(torch.device(args.device))
    print(f"{col_blue}Hidden states{col_default}")
    print("─────────────")
    stats(args, trimmed)
    histogram(args, trimmed)
|
||||
|
||||
|
||||
def inspect_module(args, module):
    """For each Linear layer in `module`, print weight statistics, run the
    EXL3 regularizer on the weights, and render before/after histograms
    side by side to show the regularizer's effect on the distribution.
    """
    linears = [m for m in module if isinstance(m, Linear)]
    for linear in linears:
        print(f"{col_blue}{linear.key}{col_default}")
        print("─" * len(linear.key))
        w = linear.inner.get_weight_tensor()
        stats(args, w)

        nbins = args.histogram_bins
        # bitrate = args.regularize_bits
        stddev = w.std(unbiased = False)

        # Pre-regularization histogram over +/- 3 sigma (half of the bin budget)
        # NOTE(review): final flag differs from histogram()'s True -- presumably
        # controls clamping of out-of-range values; confirm against the kernel
        min_value = -stddev * 3
        max_value = stddev * 3
        bins = torch.empty((nbins // 2,), dtype = torch.long, device = w.device)
        ext.histogram(w, bins, min_value, max_value, False)

        # Random +/-1 sign vectors for rows (su) and columns (sv); the 1e-5
        # bias keeps sign() away from exact zero
        k, n = w.shape
        su = (torch.randn(k, device = w.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(1)
        sv = (torch.randn(n, device = w.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0)
        quant_args = {
            # "K": bitrate,
            "apply_out_scales": None,
            "devices": [args.device],
        }
        # Regularize in place of the original weights; verbose=False, no H_diag,
        # no progress bar, and skip the g_scale search (test path)
        apply_out_scales, w, g_scale, su, sv = regularize(
            w.float(),
            su,
            sv,
            quant_args,
            False,
            None,
            None,
            skip_g_scale = True
        )

        # Post-regularization histogram over the new +/- 3 sigma range
        r_stddev = w.std(unbiased = False)
        r_min_value = -r_stddev * 3
        r_max_value = r_stddev * 3
        r_bins = torch.empty((nbins // 2,), dtype = torch.long, device = w.device)
        ext.histogram(w, r_bins, r_min_value, r_max_value, False)

        # Outlier counts at fixed absolute thresholds on the regularized weights
        print(f"Reg. std. deviation : {col_white}{r_stddev:18.8f}{col_default}")
        w4 = count_threshold(w, 4)
        w4p = w4 / w.numel()
        if w4 > 0:
            print(f"Reg. outliers >4 : {col_yellow}{w4p:18.8f}{col_default} ({col_yellow}{w4:,} elem{col_default})")
        else:
            print(f"Reg. outliers >4 : {col_white}None{col_default}")
        w8 = count_threshold(w, 8)
        w8p = w8 / w.numel()
        if w8 > 0:
            print(f"Reg. outliers >8 : {col_yellow}{w8p:18.8f}{col_default} ({col_yellow}{w8:,} elem{col_default})")
        else:
            print(f"Reg. outliers >8 : {col_white}None{col_default}")
        print()

        # Render input vs regularized distributions side by side
        bc_pre = bchart(bins, min_value, max_value, col_default)
        bc_reg = bchart(r_bins, r_min_value, r_max_value, col_purple)
        bc = [f"{p} {r}" for p, r in zip(bc_pre, bc_reg)]

        bc.append(col_default + ("─" * (nbins // 2)) + " " + ("─" * (nbins // 2)) + col_default)

        # Axis rows: value bounds, then panel titles
        for a, b, c, d in zip(
            [f"{min_value:.2e}", "Input layer"],
            [f"{max_value:.2e}", ""],
            [f"{r_min_value:.2e}", "Regularized layer"],
            [f"{r_max_value:.2e}", ""],
        ):
            ax = ""
            ax += col_default + a + (" " * ((nbins // 2) - len(a) - len(b))) + b + col_default
            ax += " "
            ax += col_purple + c + (" " * ((nbins // 2) - len(c) - len(d))) + d + col_default
            bc.append(ax)

        print("\n".join(bc))
        print()
|
||||
|
||||
|
||||
@torch.inference_mode()
def main(args):
    """Stream the model through memory one module at a time, inspecting each
    module's weights (pre/post regularization) and the hidden state after its
    forward pass, then compute perplexity from the final logits.
    """

    # Create model config
    config = Config.from_directory(args.model_dir)
    config.override_dynamic_seq_len(2048)
    tokenizer = Tokenizer.from_config(config)
    model = Model.from_config(config)

    # Input state: token ids become the initial "hidden state" fed to module 0
    bos = None if not args.bos else torch.tensor([[config.bos_token_id]], dtype = torch.long)
    eval_ids = get_test_tokens(tokenizer, args.rows, bos)
    state = eval_ids

    # Streaming forward pass: load one module, inspect it, run it, unload it,
    # so only a single module's weights are resident at a time
    for idx, module in enumerate(model.modules):

        # Load next module (deferred load batches the tensor reads)
        print(f" -- Loading module: {col_green}{module.key}{col_default}")
        print()
        config.stc.begin_deferred_load()
        module.load(torch.device(args.device) if not module.caps.get("prefer_cpu") else "cpu")
        config.stc.end_deferred_load()
        inspect_module(args, module)

        # Forward pass
        print(f" -- Forward pass")
        print()
        params = {}
        state = module.prepare_for_device(state, params)
        state = module.forward(state, params)
        inspect_state(args, state)

        # Unload current module and release its memory before the next one
        module.unload()
        config.stc.close()
        free_mem()

    # Test perplexity: after the loop `state` holds the final logits.
    # Accumulate mean log-prob of each next token, row by row, freeing
    # intermediates eagerly to keep peak memory down.
    vocab_size = tokenizer.actual_vocab_size
    logprob_sum = 0.0
    logprob_count = 0
    with ProgressBar("Evaluating", args.rows) as pb:
        for row in range(state.shape[0]):
            pb.update(row)
            input_ids = eval_ids[row:row + 1, :]
            logits = state[row:row+1, ...]
            # Drop the last position (no target) and any padded vocab tail
            logits = logits[:, :-1, :vocab_size].float()
            log_probs = F.log_softmax(logits, dim = -1)
            del logits
            # Targets are the inputs shifted left by one
            target_ids = input_ids[:, 1:].to(log_probs.device)
            del input_ids
            target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
            logprob_sum += target_log_probs.sum().item()
            logprob_count += target_ids.numel()
            del log_probs
            del target_log_probs
            del target_ids
            torch.cuda.empty_cache()
        pb.update(args.rows)
    mean_log_prob = logprob_sum / logprob_count
    perplexity = math.exp(-mean_log_prob)

    print(f"{col_blue}Outputs{col_default}")
    print("───────")
    print(f"Perplexity : {col_white}{perplexity:.6f}{col_default}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the evaluation
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
    ap.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
    ap.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 10)
    ap.add_argument("-hb", "--histogram_bins", type = int, help = "Histogram bins", default = 160)
    ap.add_argument("-bos", "--bos", action = "store_true", help = "Add BOS token on each row")
    ap.add_argument("-skip", "--skip_tokens", type = int, help = "Skip tokens at start of context", default = 0)
    main(ap.parse_args())
|
||||
@@ -642,8 +642,9 @@ def regularize(
|
||||
sv: torch.Tensor,
|
||||
quant_args: dict,
|
||||
verbose: bool,
|
||||
H_diag: torch.Tensor,
|
||||
pb: ProgressBar
|
||||
H_diag: torch.Tensor | None,
|
||||
pb: ProgressBar | None,
|
||||
skip_g_scale: bool = False
|
||||
):
|
||||
force_out_scales = quant_args["apply_out_scales"]
|
||||
|
||||
@@ -660,17 +661,21 @@ def regularize(
|
||||
# the input to the linear layer is very irregular. After some testing, set the cutoff at 15% of the RMS sum
|
||||
# on 2% of the channels
|
||||
# TODO: More science
|
||||
diag = H_diag.sqrt()
|
||||
diag, _ = torch.sort(diag, descending = True)
|
||||
cutoff = diag.shape[0] // 50
|
||||
skew_factor = diag[:cutoff].sum() / diag.sum()
|
||||
if verbose:
|
||||
print(f" - input state skew: {skew_factor.item():.6f}")
|
||||
if H_diag is not None:
|
||||
diag = H_diag.sqrt()
|
||||
diag, _ = torch.sort(diag, descending = True)
|
||||
cutoff = diag.shape[0] // 50
|
||||
skew_factor = diag[:cutoff].sum() / diag.sum()
|
||||
if verbose:
|
||||
print(f" - input state skew: {skew_factor.item():.6f}")
|
||||
|
||||
if force_out_scales is None:
|
||||
apply_out_scales = skew_factor.item() < 0.15
|
||||
else:
|
||||
apply_out_scales = force_out_scales
|
||||
|
||||
if force_out_scales is None:
|
||||
apply_out_scales = skew_factor.item() < 0.15
|
||||
else:
|
||||
apply_out_scales = force_out_scales
|
||||
apply_out_scales = True if force_out_scales is None else force_out_scales
|
||||
|
||||
# Apply output scales
|
||||
if apply_out_scales:
|
||||
@@ -695,7 +700,10 @@ def regularize(
|
||||
blockwise_preapply_had_l_(weight, had_k)
|
||||
|
||||
# Determine best scale for matrix by test quantizing a sample of tiles along a wrapped diagonal
|
||||
g_scale, mse_scale = g_scale_gss(weight, False, quant_args, pb = pb)
|
||||
if not skip_g_scale:
|
||||
g_scale, mse_scale = g_scale_gss(weight, False, quant_args, pb = pb)
|
||||
else:
|
||||
g_scale = 1.0
|
||||
weight *= g_scale
|
||||
su /= g_scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user