mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
262 lines
7.9 KiB
Python
262 lines
7.9 KiB
Python
|
|
from exllamav2 import(
|
|
ExLlamaV2,
|
|
ExLlamaV2Config,
|
|
ExLlamaV2Cache,
|
|
ExLlamaV2Cache_8bit,
|
|
ExLlamaV2Tokenizer,
|
|
model_init,
|
|
)
|
|
|
|
from exllamav2.attn import ExLlamaV2Attention
|
|
|
|
import argparse, os, math, time
|
|
import pandas, fastparquet
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from exllamav2.conversion.tokenize import get_tokens
|
|
from exllamav2.util import list_live_tensors
|
|
import gc
|
|
|
|
import sys
|
|
import json
|
|
|
|
# Initialise torch's CUDA state eagerly (private torch API) and print tensors with 10 digits
torch.cuda._lazy_init()
torch.set_printoptions(precision = 10)

# Command-line options as (flags, add_argument keywords), registered in exactly this order
_cli_options = (
    (("-ed", "--eval_dataset"), dict(type = str, help = "Perplexity evaluation dataset (.parquet file)")),
    (("-er", "--eval_rows"), dict(type = int, default = 20, help = "Number of rows to apply from dataset")),
    (("-el", "--eval_length"), dict(type = int, default = 2048, help = "Max no. tokens per sample")),
    (("-ma", "--model_a"), dict(type = str, help = "Path to model A")),
    (("-mb", "--model_b"), dict(type = str, help = "Path to model B")),
    (("-k", "--keep_layers"), dict(type = int, default = 0, help = "Maintain state from model A for this many layers")),
    (("-tkm", "--topk_max"), dict(type = int, default = 5, help = "Max top-K interval to test")),
)

parser = argparse.ArgumentParser(description = "Test layer-by-layer hidden state difference between two models")
for _flags, _options in _cli_options:
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()
|
|
|
|
# Initialize both models
#
# `config` and `model` are 2-tuples: index 0 belongs to model A, index 1 to model B. Each setup step
# is applied to A and then to B before moving on to the next step.

print(f" -- Model A: {args.model_a}")
print(f" -- Model B: {args.model_b}")

config = tuple(ExLlamaV2Config() for _ in range(2))
for _cfg, _model_dir in zip(config, (args.model_a, args.model_b)):
    _cfg.model_dir = _model_dir
for _cfg in config:
    _cfg.prepare()
for _cfg in config:
    _cfg.max_batch_size = 1
for _cfg in config:
    _cfg.arch_compat_overrides()

model = tuple(ExLlamaV2(_cfg) for _cfg in config)
for _mdl in model:
    # lazy = True: presumably defers loading weights until each module's own load() is called,
    # which the rest of the script does one module at a time -- TODO confirm against ExLlamaV2.load
    _mdl.load(lazy = True)

# The two models are compared module by module, so they must have the same number of modules
num_modules = len(model[0].modules)
assert len(model[1].modules) == num_modules

# Tokenizer (built from model A's config only)

print(f" -- Loading tokenizer")
tokenizer = ExLlamaV2Tokenizer(config[0])
|
|
|
|
with torch.no_grad():

    # Input

    print(f" -- Tokenizing eval data")
    eval_tokens = get_tokens(args.eval_rows, args.eval_length, args.eval_dataset, tokenizer)
    num_rows, seq_len = eval_tokens.shape

    # Split the 2D token tensor into one single-row slice per sample, so every forward pass below
    # processes exactly one row at a time
    eval_tokens = [eval_tokens[row:row + 1, :] for row in range(num_rows)]

    # Positional arguments are presumably (batch_size, seq_len, past_len, ...) -- TODO confirm
    # against ExLlamaV2Attention.Params
    attn_params = ExLlamaV2Attention.Params(1, seq_len, 0, None, None)

    # Get embeddings
    #
    # hidden_state[i][j] is the current state of row j for model i (0 = A, 1 = B). It starts as the
    # output of each model's first module and is overwritten in place by every later module.

    print(f" -- Embeddings")
    hidden_state = [[], []]
    for i in [0, 1]:
        module = model[i].modules[0]
        module.load()
        for j in range(num_rows):
            hidden_state[i].append(module.forward(eval_tokens[j]))
        module.unload()

    # Forward
    #
    # Only one module is loaded at a time: load it, push every row through it, then unload it

    rfn_error = []

    for idx in range(1, num_modules):

        for i in [0, 1]:

            module = model[i].modules[idx]
            if i == 0:
                print(f" -- {module.key + ' (' + module.name + ')':40}", end = "")

            module.load()

            for j in range(num_rows):
                if i == 1 and idx <= args.keep_layers:
                    # Within the first keep_layers modules, model B is not run; it inherits model
                    # A's state so the two only start to diverge after that point
                    hidden_state[1][j] = hidden_state[0][j].clone()
                else:
                    x = hidden_state[i][j].to("cuda:0")
                    x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
                    hidden_state[i][j] = x.to("cpu")
                    x = None

            module.unload()
            module = None

        # Mean relative Frobenius-norm error of model B's state against model A's after this module

        rfn_error_sum = 0.0

        for j in range(num_rows):

            x = hidden_state[0][j].to("cuda:0").float()
            y = hidden_state[1][j].to("cuda:0").float()
            # .item() is applied to the whole ratio so that plain Python floats are accumulated;
            # previously it bound only to the denominator, leaving CUDA tensors in rfn_error
            rel_error = torch.linalg.norm(y[0] - x[0], 'fro') / torch.linalg.norm(x[0], 'fro')
            rfn_error_sum += rel_error.item()
            x = None
            y = None

        rfn_error_ = rfn_error_sum / num_rows
        print(f" rfn_error: {rfn_error_:8.6f}")
        rfn_error.append(rfn_error_)
|
|
|
|
|
|
# Test outputs
|
|
|
|
def ppl(input_ids_, logits_):
    """
    Sum the log-probabilities that `logits_` assigns to the next-token targets in `input_ids_`.

    The logits are processed in chunks along the position axis so that only one chunk at a time is
    converted to float32 and passed through log-softmax.

    :param input_ids_:
        Token IDs indexed as (batch, position). The token at position p + 1 is the target for the
        logits at position p, so this needs one more position than `logits_`.

    :param logits_:
        Logits indexed as (batch, position, vocab).

    :return:
        Tuple of (sum of target-token log-probabilities as a float, number of target tokens).
    """

    logprob_sum_ = 0.0
    logprob_count_ = 0

    num_positions = logits_.shape[1]
    vocab_size = logits_.shape[2]

    # Clamp to at least one position per chunk: the integer division gives 0 whenever the vocab is
    # larger than num_positions * 16000, and a zero chunk size would never advance the loop below
    chunksize = max(1, num_positions * 16000 // vocab_size)

    b_ = 0
    while b_ < num_positions:
        a_ = b_
        b_ = min(b_ + chunksize, num_positions)

        # The constant offset is kept from the original; adding the same value to every logit does
        # not change the softmax mathematically
        logits_f = logits_[:, a_:b_, :].float() + 1e-10
        target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)

        log_probs = F.log_softmax(logits_f, dim=-1)
        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum_ += token_log_probs.sum().item()
        logprob_count_ += target_ids.numel()

    return logprob_sum_, logprob_count_
|
|
|
|
topk_max = args.topk_max

# Running totals over all rows. Where there are two entries, index 0 is model A and index 1 is
# model B; the top-K lists have one slot per K in 1..topk_max.
logprob_sum = [0, 0]
logprob_count = [0, 0]
kl_div_sum = 0
mse_sum = 0
mse_count = 0
topk_hits_sum = [[0] * topk_max, [0] * topk_max]
topk_hits_count = [[0] * topk_max, [0] * topk_max]
topk_agreement_sum = [0] * topk_max
topk_agreement_count = [0] * topk_max

print(f" -- Testing outputs")

# Evaluation only, so autograd is kept disabled for the whole loop
with torch.no_grad():

    for j in range(num_rows):

        # Perplexity
        #
        # The final hidden states of both models are treated as logits from here on

        x = (hidden_state[0][j].to("cuda:0"), hidden_state[1][j].to("cuda:0"))
        input_ids = eval_tokens[j]

        top_indices = []

        for i in [0, 1]:
            # Drop the last position: it has no following token to serve as a target
            logits = x[i][:, :-1, :]
            logprob_sum__, logprob_count__ = ppl(input_ids, logits)
            logprob_sum[i] += logprob_sum__
            logprob_count[i] += logprob_count__

            # Top-K accuracy: is the actual next token among the model's K highest-scoring tokens?
            _, top_index = torch.topk(logits, topk_max, dim = -1)
            top_index = top_index.cpu().view(-1, topk_max)
            top_indices.append(top_index)
            targets = input_ids[:, 1:].view(-1, 1)

            for t in range(topk_max):
                top_slice = top_index[:, :t + 1]
                hits = torch.eq(targets, top_slice)
                row_hits = hits.any(dim = 1)
                topk_hits_sum[i][t] += row_hits.sum().item()
                topk_hits_count[i][t] += top_slice.shape[0]

        # Top-K agreement: a position counts only if both models produce the same first K tokens in
        # the same order
        for t in range(topk_max):
            top_slice_a = top_indices[0][:, :t + 1]
            top_slice_b = top_indices[1][:, :t + 1]
            hits = torch.eq(top_slice_a, top_slice_b)
            row_hits = hits.all(dim = 1)
            topk_agreement_sum[t] += row_hits.sum().item()
            topk_agreement_count[t] += top_slice_a.shape[0]

        # KL divergence and MSE between the two output distributions. F.kl_div takes log-probs as
        # its first argument and probabilities as its target, so model B is the target here.
        epsilon = 1e-10
        probs_a = torch.softmax(x[0].float(), dim = -1)
        probs_b = torch.softmax(x[1].float(), dim = -1)
        kl_elements = F.kl_div(torch.log(probs_a + epsilon), probs_b, reduction = 'none')
        kl_div_sum += kl_elements.sum(dim = -1).mean().item()

        # .item() keeps the running total a plain float, consistent with kl_div_sum above
        mse_sum += F.mse_loss(probs_a, probs_b).item()
        mse_count += 1

perplexity = (math.exp(-logprob_sum[0] / logprob_count[0]), math.exp(-logprob_sum[1] / logprob_count[1]))
mse = mse_sum / mse_count
kl_div = kl_div_sum / num_rows

# Per-K accuracy of each model and per-K agreement between them, plus preformatted strings for the
# summary lines
a_acc = [hit / count for hit, count in zip(topk_hits_sum[0], topk_hits_count[0])]
b_acc = [hit / count for hit, count in zip(topk_hits_sum[1], topk_hits_count[1])]
topk_agree = [hit / count for hit, count in zip(topk_agreement_sum, topk_agreement_count)]

a_acc_str = "".join(f"{acc:6.4f} " for acc in a_acc)
b_acc_str = "".join(f"{acc:6.4f} " for acc in b_acc)
agree_str = "".join(f"{agree:6.4f} " for agree in topk_agree)
|
|
|
|
# CSV output
#
# Semicolon-separated dump between two ruler lines: both perplexities, then KL divergence and MSE,
# then one line per K (K; accuracy A; accuracy B; agreement), then one line per entry of rfn_error
# (index; error).

print()
print("-----------------")
print()
print(";".join(f"{value:.8f}" for value in perplexity))
print()
print(f"{kl_div:.8f}")
print(f"{mse:.8f}")
print()
for rank, (acc_a, acc_b, agree) in enumerate(zip(a_acc, b_acc, topk_agree), start = 1):
    print(f"{rank};{acc_a:.8f};{acc_b:.8f};{agree:.8f}")
print()
for position, error in enumerate(rfn_error):
    print(f"{position};{error:.8f}")
print()
print("-----------------")
print()

# Results (human-readable summary of the same numbers)

ppl_a, ppl_b = perplexity[0], perplexity[1]
print(f" -- A, ppl: {ppl_a:11.8f} acc: {a_acc_str}")
print(f" -- B, ppl: {ppl_b:11.8f} acc: {b_acc_str}")
print(f" -- Top-K agreement: {agree_str}")
print(f" -- KL divergence: {kl_div:11.8f}")
print(f" -- MSE: {mse:11.8f}")
|
|
|
|
|
|
|