# Mirror of https://github.com/turboderp-org/exllamav3.git
# Synced 2026-03-15 00:07:24 +00:00
# 356 lines, 13 KiB, Python
import sys, os
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import argparse
|
|
from exllamav3.util.file import disk_lru_cache
|
|
from exllamav3.util.progress import ProgressBar
|
|
from exllamav3.util.memory import free_mem
|
|
from exllamav3 import Config, Model, Tokenizer
|
|
from exllamav3.ext import exllamav3_ext as ext
|
|
from datasets import load_dataset
|
|
from exllamav3.modules import Linear
|
|
from exllamav3.modules.quant.exl3_lib.quantize import regularize
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
# ANSI codes
|
|
ESC = "\u001b"
|
|
col_default = "\u001b[0m"
|
|
col_yellow = "\u001b[33;1m"
|
|
col_blue = "\u001b[34;1m"
|
|
col_green = "\u001b[32;1m"
|
|
col_red = "\u001b[31;1m"
|
|
col_purple = "\u001b[35;1m"
|
|
col_cyan = "\u001b[36;1m"
|
|
col_white = "\u001b[37;1m"
|
|
|
|
block_chars = [" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"]
|
|
|
|
@disk_lru_cache("get_dataset_text")
|
|
def get_dataset_text(spec: dict):
|
|
assert spec["dataset"] == "wiki2", "Only wiki2 implemented atm"
|
|
dataset_text = "\n\n".join(
|
|
load_dataset("wikitext", "wikitext-2-raw-v1", split = "test")
|
|
["text"]
|
|
)
|
|
return dataset_text
|
|
|
|
|
|
def get_test_tokens(tokenizer, rows, bos, eval_len = 2048, eval_stride = 512):
|
|
with ProgressBar("Tokenizing", rows) as pb:
|
|
dataset_spec = { "dataset": "wiki2" }
|
|
eval_tokens = tokenizer.encode(get_dataset_text(dataset_spec))
|
|
num_tokens = eval_tokens.shape[-1]
|
|
seqs = []
|
|
for a in range(0, num_tokens - eval_len, eval_stride):
|
|
b = a + eval_len
|
|
if bos is not None:
|
|
r = torch.cat((bos, eval_tokens[:, a:b-1]), dim = -1)
|
|
else:
|
|
r = eval_tokens[:, a:b]
|
|
seqs.append(r)
|
|
pb.update(len(seqs))
|
|
if len(seqs) >= rows:
|
|
break
|
|
return torch.cat(seqs, dim = 0)[:, :]
|
|
|
|
|
|
|
|
# @torch.compile(fullgraph = True, mode = "reduce-overhead")
|
|
def count_threshold(x: torch.Tensor, abs_threshold: float) -> torch.Tensor:
|
|
return (x.abs() > abs_threshold).sum(dtype = torch.int64)
|
|
|
|
|
|
def bchart(bins, min_value, max_value, cc, height = 10):
|
|
maxcount = bins.max()
|
|
lines = []
|
|
for r in range(height):
|
|
line = cc
|
|
for c in range(len(bins)):
|
|
if maxcount == 0:
|
|
b = block_chars[0]
|
|
else:
|
|
y = (bins[c] / maxcount) * height - r
|
|
y = max(min(y, 1), 0)
|
|
b = block_chars[int(y * 8)]
|
|
line += b
|
|
line += col_default
|
|
lines.append(line)
|
|
lines.reverse()
|
|
return lines
|
|
|
|
|
|
def histogram(args, tensor):
|
|
|
|
nbins = args.histogram_bins
|
|
stddev = tensor.std()
|
|
min_value = tensor.amin().item()
|
|
max_value = tensor.amax().item()
|
|
|
|
# Middle histogram
|
|
m_min_value = -stddev * 3
|
|
m_max_value = stddev * 3
|
|
|
|
m_bins = torch.empty((nbins // 2,), dtype = torch.long, device = tensor.device)
|
|
ext.histogram(tensor, m_bins, m_min_value, m_max_value, True)
|
|
m_count = m_bins.sum().item()
|
|
m_bins = m_bins.cpu().numpy()
|
|
|
|
# Low histogram
|
|
l_min_value = min_value
|
|
l_max_value = m_min_value
|
|
l_bins = torch.empty((nbins // 4,), dtype = torch.long, device = tensor.device)
|
|
ext.histogram(tensor, l_bins, l_min_value, l_max_value, True)
|
|
l_count = l_bins.sum().item()
|
|
l_bins = l_bins.cpu().numpy()
|
|
|
|
# High histogram
|
|
h_min_value = m_max_value + 0.001
|
|
h_max_value = max_value
|
|
h_bins = torch.empty((nbins // 4,), dtype = torch.long, device = tensor.device)
|
|
ext.histogram(tensor, h_bins, h_min_value, h_max_value, True)
|
|
h_count = h_bins.sum().item()
|
|
h_bins = h_bins.cpu().numpy()
|
|
|
|
total = l_count + m_count + h_count
|
|
l_pct = l_count / total * 100.0
|
|
m_pct = m_count / total * 100.0
|
|
h_pct = h_count / total * 100.0
|
|
|
|
bc_l = bchart(l_bins, l_min_value, l_max_value, col_yellow)
|
|
bc_m = bchart(m_bins, m_min_value, m_max_value, col_cyan)
|
|
bc_h = bchart(h_bins, h_min_value, h_max_value, col_yellow)
|
|
bc = [f"{l} {m} {h}" for l, m, h in zip(bc_l, bc_m, bc_h)]
|
|
|
|
bc.append(col_default + ("─" * (nbins // 4)) + "┬" + ("─" * (nbins // 2)) + "┬" + ("─" * (nbins // 4 )) + col_default)
|
|
|
|
for vl, vml, vmh, vh in zip(
|
|
[f"{l_min_value:.2e}", f"{l_count:,} elem", f"{l_pct:.5f} %"],
|
|
[f"{m_min_value:.2e}", f"{m_count:,} elem", f"{m_pct:.5f} %"],
|
|
[f"{m_max_value:.2e}", "", ""],
|
|
[f"{h_max_value:.2e}", f"{h_count:,} elem", f"{h_pct:.5f} %"],
|
|
):
|
|
ax = ""
|
|
ax += col_yellow + vl + (" " * ((nbins // 4) - len(vl)))
|
|
ax += " " + col_cyan + vml + (" " * ((nbins // 2) - len(vml) - len(vmh))) + vmh + " "
|
|
ax += (" " * ((nbins // 4) - len(vh))) + col_yellow + vh + col_default
|
|
bc.append(ax)
|
|
|
|
print("\n".join(bc))
|
|
print()
|
|
|
|
|
|
def stats(args, tensor):
|
|
|
|
# inf/NaN values
|
|
inf_nan = torch.zeros(2, dtype = torch.long, device = tensor.device)
|
|
ext.count_inf_nan(tensor, inf_nan)
|
|
inf = inf_nan[0].item()
|
|
nan = inf_nan[1].item()
|
|
if inf or nan:
|
|
print(f"{col_red}inf values : {col_white}{inf:,}{col_default}")
|
|
print(f"{col_red}NaN values : {col_white}{nan:,}{col_default}")
|
|
print(f"{col_red}Total numel : {col_white}{tensor.numel():,}{col_default}")
|
|
print()
|
|
|
|
min_value = tensor.amin().item()
|
|
max_value = tensor.amax().item()
|
|
print(f"Min value : {col_white}{min_value:18.8f}{col_default}")
|
|
print(f"Max value : {col_white}{max_value:18.8f}{col_default}")
|
|
|
|
sigma = tensor.std(unbiased = False)
|
|
print(f"Std. deviation : {col_white}{sigma:18.8f}{col_default}")
|
|
|
|
# mu = tensor.mean()
|
|
# std2 = tensor.var(unbiased = False)
|
|
# std4 = std2 ** 2
|
|
# kurt = ((tensor - mu) ** 4).mean() / (std4 + 1e-10)
|
|
# print(f"Kurtosis : {col_white}{kurt:18.8f}{col_default}")
|
|
|
|
n6 = count_threshold(tensor, 6 * sigma)
|
|
n6p = n6 / tensor.numel()
|
|
if n6 > 0:
|
|
print(f"Six-sigma exceedance : {col_white}{n6p:18.8f}{col_default} ({col_white}{n6:,} elem{col_default})")
|
|
else:
|
|
print(f"Six-sigma exceedance : {col_white}None{col_default}")
|
|
|
|
print()
|
|
|
|
|
|
def inspect_state(args, state):
|
|
state = state[:, args.skip_tokens:, :].to(torch.device(args.device))
|
|
print(f"{col_blue}Hidden states{col_default}")
|
|
print("─────────────")
|
|
stats(args, state)
|
|
histogram(args, state)
|
|
|
|
|
|
def inspect_module(args, module):
|
|
linears = [m for m in module if isinstance(m, Linear)]
|
|
for linear in linears:
|
|
print(f"{col_blue}{linear.key}{col_default}")
|
|
print("─" * len(linear.key))
|
|
w = linear.inner.get_weight_tensor()
|
|
stats(args, w)
|
|
|
|
nbins = args.histogram_bins
|
|
# bitrate = args.regularize_bits
|
|
stddev = w.std(unbiased = False)
|
|
|
|
min_value = -stddev * 4
|
|
max_value = stddev * 4
|
|
bins = torch.empty((nbins // 2,), dtype = torch.long, device = w.device)
|
|
ext.histogram(w, bins, min_value, max_value, False)
|
|
|
|
k, n = w.shape
|
|
su = (torch.randn(k, device = w.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(1)
|
|
sv = (torch.randn(n, device = w.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0)
|
|
quant_args = {
|
|
# "K": bitrate,
|
|
"apply_out_scales": None,
|
|
"devices": [args.device],
|
|
}
|
|
apply_out_scales, w, g_scale, su, sv = regularize(
|
|
w.float(),
|
|
su,
|
|
sv,
|
|
quant_args,
|
|
False,
|
|
None,
|
|
None,
|
|
skip_g_scale = True
|
|
)
|
|
|
|
r_stddev = w.std(unbiased = False)
|
|
r_min_value = -r_stddev * 4
|
|
r_max_value = r_stddev * 4
|
|
r_bins = torch.empty((nbins // 2,), dtype = torch.long, device = w.device)
|
|
ext.histogram(w, r_bins, r_min_value, r_max_value, False)
|
|
|
|
print(f"Reg. std. deviation : {col_white}{r_stddev:18.8f}{col_default}")
|
|
w4 = count_threshold(w, 4)
|
|
w4p = w4 / w.numel()
|
|
if w4 > 0:
|
|
print(f"Reg. outliers >4 : {col_yellow}{w4p:18.8f}{col_default} ({col_yellow}{w4:,} elem{col_default})")
|
|
else:
|
|
print(f"Reg. outliers >4 : {col_white}None{col_default}")
|
|
w8 = count_threshold(w, 8)
|
|
w8p = w8 / w.numel()
|
|
if w8 > 0:
|
|
print(f"Reg. outliers >8 : {col_yellow}{w8p:18.8f}{col_default} ({col_yellow}{w8:,} elem{col_default})")
|
|
else:
|
|
print(f"Reg. outliers >8 : {col_white}None{col_default}")
|
|
print()
|
|
|
|
bc_pre = bchart(bins, min_value, max_value, col_default)
|
|
bc_reg = bchart(r_bins, r_min_value, r_max_value, col_purple)
|
|
bc = [f"{p} {r}" for p, r in zip(bc_pre, bc_reg)]
|
|
|
|
bc.append(col_default + ("─" * (nbins // 2)) + " " + ("─" * (nbins // 2)) + col_default)
|
|
|
|
for a, b, c, d in zip(
|
|
[f"{min_value:.2e}", "Input layer"],
|
|
[f"{max_value:.2e}", ""],
|
|
[f"{r_min_value:.2e}", "Regularized layer"],
|
|
[f"{r_max_value:.2e}", ""],
|
|
):
|
|
ax = ""
|
|
ax += col_default + a + (" " * ((nbins // 2) - len(a) - len(b))) + b + col_default
|
|
ax += " "
|
|
ax += col_purple + c + (" " * ((nbins // 2) - len(c) - len(d))) + d + col_default
|
|
bc.append(ax)
|
|
|
|
print("\n".join(bc))
|
|
print()
|
|
|
|
|
|
@torch.inference_mode()
|
|
def main(args):
|
|
|
|
# Create model config
|
|
config = Config.from_directory(args.model_dir)
|
|
config.override_dynamic_seq_len(2048)
|
|
tokenizer = Tokenizer.from_config(config)
|
|
model = Model.from_config(config)
|
|
|
|
# Input state
|
|
bos = None if not args.bos else torch.tensor([[config.bos_token_id]], dtype = torch.long)
|
|
eval_ids = get_test_tokens(tokenizer, args.rows, bos)
|
|
state = eval_ids
|
|
|
|
# Streaming forward pass
|
|
for idx, module in enumerate(model.modules):
|
|
|
|
# Load next module
|
|
print(f" -- Loading module: {col_green}{module.key}{col_default}")
|
|
print()
|
|
config.stc.begin_deferred_load()
|
|
module.load(torch.device(args.device) if not module.caps.get("prefer_cpu") else "cpu")
|
|
config.stc.end_deferred_load()
|
|
if (
|
|
(args.from_layer is None or idx >= args.from_layer) and
|
|
(args.to_layer is None or idx < args.to_layer) and
|
|
not args.no_inspect_modules
|
|
):
|
|
inspect_module(args, module)
|
|
|
|
# Forward pass
|
|
print(f" -- Forward pass")
|
|
print()
|
|
params = {}
|
|
state = module.prepare_for_device(state, params)
|
|
state = module.forward(state, params)
|
|
if (args.from_layer is None or idx >= args.from_layer) and (args.to_layer is None or idx < args.to_layer):
|
|
inspect_state(args, state)
|
|
|
|
# Unload current module
|
|
module.unload()
|
|
config.stc.close()
|
|
free_mem()
|
|
|
|
# Test perplexity
|
|
vocab_size = tokenizer.actual_vocab_size
|
|
logprob_sum = 0.0
|
|
logprob_count = 0
|
|
with ProgressBar("Evaluating", args.rows) as pb:
|
|
for row in range(state.shape[0]):
|
|
pb.update(row)
|
|
input_ids = eval_ids[row:row + 1, :]
|
|
logits = state[row:row+1, ...]
|
|
logits = logits[:, :-1, :vocab_size].float()
|
|
log_probs = F.log_softmax(logits, dim = -1)
|
|
del logits
|
|
target_ids = input_ids[:, 1:].to(log_probs.device)
|
|
del input_ids
|
|
target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
|
|
logprob_sum += target_log_probs.sum().item()
|
|
logprob_count += target_ids.numel()
|
|
del log_probs
|
|
del target_log_probs
|
|
del target_ids
|
|
torch.cuda.empty_cache()
|
|
pb.update(args.rows)
|
|
mean_log_prob = logprob_sum / logprob_count
|
|
perplexity = math.exp(-mean_log_prob)
|
|
|
|
print(f"{col_blue}Outputs{col_default}")
|
|
print("───────")
|
|
print(f"Perplexity : {col_white}{perplexity:.6f}{col_default}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
|
|
parser.add_argument("-d", "--device", type = int, help = "CUDA device index", default = 0)
|
|
parser.add_argument("-r", "--rows", type = int, help = "Number of rows", default = 10)
|
|
parser.add_argument("-hb", "--histogram_bins", type = int, help = "Histogram bins", default = 160)
|
|
parser.add_argument("-bos", "--bos", action = "store_true", help = "Add BOS token on each row")
|
|
parser.add_argument("-skip", "--skip_tokens", type = int, help = "Skip tokens at start of context", default = 0)
|
|
parser.add_argument("-fl", "--from_layer", type = int, help = "From layer", default = None)
|
|
parser.add_argument("-tl", "--to_layer", type = int, help = "To layer", default = None)
|
|
parser.add_argument("-nim", "--no_inspect_modules", action = "store_true", help = "Skip module inspection")
|
|
# parser.add_argument("-rb", "--regularize_bits", type = int, help = "Target bitrate for regularization test", default = 4)
|
|
_args = parser.parse_args()
|
|
main(_args)
|