# exllamav2/model_diff.py
# Snapshot of 2024-07-06 11:55:37 +02:00 (262 lines, 7.9 KiB, Python)

from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
model_init,
)
from exllamav2.attn import ExLlamaV2Attention
import argparse, os, math, time
import pandas, fastparquet
import torch
import torch.nn.functional as F
from exllamav2.conversion.tokenize import get_tokens
from exllamav2.util import list_live_tensors
import gc
import sys
import json
# Presumably forces CUDA initialization up front rather than on first use (private torch API)
torch.cuda._lazy_init()
# Print tensors with 10 digits of precision
torch.set_printoptions(precision = 10)

# Command-line arguments. The dataset and both model paths have no usable default, so they are
# required: without them the script would only fail much later, deep inside model/dataset loading.

parser = argparse.ArgumentParser(description = "Test layer-by-layer hidden state difference between two models")
parser.add_argument("-ed", "--eval_dataset", type = str, required = True, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 20, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-ma", "--model_a", type = str, required = True, help = "Path to model A")
parser.add_argument("-mb", "--model_b", type = str, required = True, help = "Path to model B")
parser.add_argument("-k", "--keep_layers", type = int, default = 0, help = "Maintain state from model A for this many layers")
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")
args = parser.parse_args()

# The output comparison calls torch.topk(..., topk_max) and .view(-1, topk_max); a value below 1
# would crash there with an unhelpful reshape error, so reject it here with a usage message.
if args.topk_max < 1:
    parser.error("--topk_max must be at least 1")
# Initialize both models. Throughout the script, index 0 is model A and index 1 is model B.

print(f" -- Model A: {args.model_a}")
print(f" -- Model B: {args.model_b}")

config = (ExLlamaV2Config(), ExLlamaV2Config())
for cfg, model_dir in zip(config, (args.model_a, args.model_b)):
    cfg.model_dir = model_dir
    cfg.prepare()
    # Every forward pass in this script processes a single row
    cfg.max_batch_size = 1
    cfg.arch_compat_overrides()

model = tuple(ExLlamaV2(cfg) for cfg in config)
for m in model:
    # lazy = True presumably defers weight loading; modules are load()ed / unload()ed individually later
    m.load(lazy = True)

# The layer-by-layer comparison walks both module lists in lockstep, so their lengths must match.
# Raise explicitly instead of using assert, which is stripped under `python -O`.
num_modules = len(model[0].modules)
if len(model[1].modules) != num_modules:
    raise ValueError(
        f"Model A has {num_modules} modules but model B has {len(model[1].modules)}; "
        "the models cannot be compared layer by layer"
    )

# Tokenizer. Only model A's tokenizer is used; the top-K index comparisons between the two models
# therefore assume both models share a vocabulary.

print(f" -- Loading tokenizer")
tokenizer = ExLlamaV2Tokenizer(config[0])
with torch.no_grad():

    # Input: tokenize the evaluation set and split it into single-row (1, seq_len) tensors

    print(f" -- Tokenizing eval data")
    eval_tokens = get_tokens(args.eval_rows, args.eval_length, args.eval_dataset, tokenizer)
    num_rows, seq_len = eval_tokens.shape
    eval_tokens = [eval_tokens[i:i+1, :] for i in range(eval_tokens.shape[0])]

    # Attention parameters reused by every module.forward() call below; the positional arguments are
    # presumably (batch_size, seq_len, past_len, ...) -- confirm against ExLlamaV2Attention.Params
    attn_params = ExLlamaV2Attention.Params(1, seq_len, 0, None, None)

    # Embeddings: module 0 of each model produces the initial hidden state for every row

    print(f" -- Embeddings")
    hidden_state = [[], []]
    for i in [0, 1]:
        module = model[i].modules[0]
        module.load()
        for j in range(num_rows):
            hidden_state[i].append(module.forward(eval_tokens[j]))
        module.unload()

    # Forward: advance both models through modules 1..num_modules-1 in lockstep. Each module is loaded
    # just before use and unloaded right after; between modules the hidden states are kept on the CPU.

    rfn_error = []
    for idx in range(1, num_modules):

        for i in [0, 1]:
            module = model[i].modules[idx]
            if i == 0:
                print(f" -- {module.key + ' (' + module.name + ')':40}", end = "")

            if i == 1 and idx <= args.keep_layers:
                # For the first keep_layers modules, model B simply takes over model A's state. B's
                # module is never run in that case, so there is no point loading its weights.
                for j in range(num_rows):
                    hidden_state[1][j] = hidden_state[0][j].clone()
                continue

            module.load()
            for j in range(num_rows):
                x = hidden_state[i][j].to("cuda:0")
                x = module.forward(x, cache = None, attn_params = attn_params, past_len = 0, loras = None)
                hidden_state[i][j] = x.to("cpu")
                x = None
            module.unload()
            module = None

        # Relative Frobenius-norm error ||B - A||_F / ||A||_F of the hidden states after this module,
        # averaged over rows
        rfn_error_sum = 0.0
        for j in range(num_rows):
            x = hidden_state[0][j].to("cuda:0").float()
            y = hidden_state[1][j].to("cuda:0").float()
            # .item() must wrap the whole ratio: applied to the denominator only, the accumulator (and
            # every rfn_error entry) would become a CUDA tensor instead of a Python float
            rel_err = torch.linalg.norm(y[0] - x[0], 'fro') / torch.linalg.norm(x[0], 'fro')
            rfn_error_sum += rel_err.item()
            x = None
            y = None
        rfn_error_ = rfn_error_sum / num_rows
        print(f" rfn_error: {rfn_error_:8.6f}")
        rfn_error.append(rfn_error_)
# Test outputs
def ppl(input_ids_, logits_):
    """
    Sum the log-probabilities that a set of logits assigns to the actual next tokens.

    Logit position t is scored against token t + 1 of input_ids_, so callers pass logits with the
    final position dropped. The positions are processed in chunks (roughly vocab_size / 16000 of
    them) so the float32 log-softmax is never materialized for the whole sequence at once.

    :param input_ids_:
        2-D tensor of token IDs indexed as (batch, position); the caller in this script passes one
        (1, seq_len) row. It must have at least logits_.shape[1] + 1 positions.
    :param logits_:
        3-D tensor indexed as (batch, position, vocab)
    :return:
        Tuple (logprob_sum, logprob_count): the total log-probability as a Python float and the
        number of scored tokens. Perplexity is exp(-logprob_sum / logprob_count).
    """
    logprob_sum_ = 0.0
    logprob_count_ = 0
    num_positions = logits_.shape[1]
    vocab_size = logits_.shape[2]

    # Clamp to 1: when num_positions * 16000 < vocab_size the integer division yields 0, and a zero
    # chunk size would never advance through the sequence (an infinite loop).
    chunksize = max(1, num_positions * 16000 // vocab_size)

    for a_ in range(0, num_positions, chunksize):
        b_ = min(a_ + chunksize, num_positions)
        # The constant offset is a no-op for log_softmax, which is invariant to shifting all logits
        logits_f = logits_[:, a_:b_, :].float() + 1e-10
        target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)
        log_probs = F.log_softmax(logits_f, dim = -1)
        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum_ += token_log_probs.sum().item()
        logprob_count_ += target_ids.numel()

    return logprob_sum_, logprob_count_
# Accumulators for the output comparison; in every [A, B] pair, index 0 is model A and index 1 is model B
topk_max = args.topk_max
logprob_sum = [0, 0]
logprob_count = [0, 0]
kl_div_sum = 0
mse_sum = 0
mse_count = 0
topk_hits_sum = [[0] * topk_max, [0] * topk_max]
topk_hits_count = [[0] * topk_max, [0] * topk_max]
topk_agreement_sum = [0] * topk_max
topk_agreement_count = [0] * topk_max

print(f" -- Testing outputs")
with torch.no_grad():
    for j in range(num_rows):

        # Final hidden states of both models for this row; they are scored as logits below
        x = (hidden_state[0][j].to("cuda:0"), hidden_state[1][j].to("cuda:0"))
        input_ids = eval_tokens[j]

        # Perplexity and top-K accuracy against the true next token, for each model
        top_indices = []
        for i in [0, 1]:
            # Position t predicts token t + 1, so the last position has no target
            logits = x[i][:, :-1, :]
            logprob_sum__, logprob_count__ = ppl(input_ids, logits)
            logprob_sum[i] += logprob_sum__
            logprob_count[i] += logprob_count__

            _, top_index = torch.topk(logits, topk_max, dim = -1)
            top_index = top_index.cpu().view(-1, topk_max)
            top_indices.append(top_index)

            # A position is a top-(t+1) hit if the true token is among its t+1 highest-ranked tokens
            targets = input_ids[:, 1:].view(-1, 1)
            for t in range(topk_max):
                top_slice = top_index[:, :t + 1]
                hits = torch.eq(targets, top_slice)
                row_hits = hits.any(dim = 1)
                topk_hits_sum[i][t] += row_hits.sum().item()
                topk_hits_count[i][t] += top_slice.shape[0]

        # Top-K agreement: positions where both models rank the same t+1 tokens in the same order
        for t in range(topk_max):
            top_slice_a = top_indices[0][:, :t + 1]
            top_slice_b = top_indices[1][:, :t + 1]
            hits = torch.eq(top_slice_a, top_slice_b)
            row_hits = hits.all(dim = 1)
            topk_agreement_sum[t] += row_hits.sum().item()
            topk_agreement_count[t] += top_slice_a.shape[0]

        # Distance between the two output distributions over all positions
        epsilon = 1e-10
        probs_a = torch.softmax(x[0].float(), dim = -1)
        probs_b = torch.softmax(x[1].float(), dim = -1)
        # F.kl_div(log P_A, P_B) computes sum P_B * (log P_B - log P_A), i.e. KL(P_B || P_A)
        kl_div = F.kl_div(torch.log(probs_a + epsilon), probs_b, reduction = 'none')
        kl_div_sum += kl_div.sum(dim = -1).mean().item()
        # .item() keeps the accumulator a Python float, matching kl_div_sum, instead of a CUDA tensor
        mse_sum += F.mse_loss(probs_a, probs_b).item()
        mse_count += 1
# Aggregate the per-row accumulators into final metrics: perplexity for A and B, mean MSE and KL
# divergence, and for each K the top-K accuracy of A and B plus the A/B top-K agreement rate.

perplexity = tuple(math.exp(-total / count) for total, count in zip(logprob_sum, logprob_count))
mse = mse_sum / mse_count
kl_div = kl_div_sum / num_rows

a_acc = [hit / seen for hit, seen in zip(topk_hits_sum[0], topk_hits_count[0])]
b_acc = [hit / seen for hit, seen in zip(topk_hits_sum[1], topk_hits_count[1])]
topk_agree = [hit / seen for hit, seen in zip(topk_agreement_sum, topk_agreement_count)]

# Space-terminated fixed-width columns, one per K, for the human-readable summary
a_acc_str = "".join(f"{value:6.4f} " for value in a_acc)
b_acc_str = "".join(f"{value:6.4f} " for value in b_acc)
agree_str = "".join(f"{value:6.4f} " for value in topk_agree)
# CSV output: semicolon-separated values framed by separator rules

separator_rule = "-----------------"

print()
print(separator_rule)
print()

# Perplexity of model A, then model B
print(";".join(f"{p:.8f}" for p in perplexity))
print()

# KL divergence, then MSE, between the two output distributions
print(f"{kl_div:.8f}")
print(f"{mse:.8f}")
print()

# One line per K: K;accuracy A;accuracy B;agreement
for k, (acc_a, acc_b, agreement) in enumerate(zip(a_acc, b_acc, topk_agree), start = 1):
    print(f"{k};{acc_a:.8f};{acc_b:.8f};{agreement:.8f}")
print()

# One line per compared module (entry 0 is the first module after the embeddings)
for position, module_err in enumerate(rfn_error):
    print(f"{position};{module_err:.8f}")
print()

print(separator_rule)
print()

# Results (human-readable summary)

print(f" -- A, ppl: {perplexity[0]:11.8f} acc: {a_acc_str}")
print(f" -- B, ppl: {perplexity[1]:11.8f} acc: {b_acc_str}")
print(f" -- Top-K agreement: {agree_str}")
print(f" -- KL divergence: {kl_div:11.8f}")
print(f" -- MSE: {mse:11.8f}")