Mirror of https://github.com/turboderp-org/exllamav2.git, synced 2026-04-20 06:19:00 +00:00
Add script to compare quantized and unquantized model
model_diff.py: 257 lines (new file)
@@ -0,0 +1,257 @@
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
    model_init,
)

import argparse, os, math, time
import pandas, fastparquet
import torch
import torch.nn.functional as F
from conversion.tokenize import get_tokens
from exllamav2.util import list_live_tensors
import gc

import sys
import json

torch.cuda._lazy_init()
torch.set_printoptions(precision = 10)

parser = argparse.ArgumentParser(description = "Test layer-by-layer hidden state difference between two models")
parser.add_argument("-ed", "--eval_dataset", type = str, help = "Perplexity evaluation dataset (.parquet file)")
parser.add_argument("-er", "--eval_rows", type = int, default = 20, help = "Number of rows to apply from dataset")
parser.add_argument("-el", "--eval_length", type = int, default = 2048, help = "Max no. tokens per sample")
parser.add_argument("-ma", "--model_a", type = str, help = "Path to model A")
parser.add_argument("-mb", "--model_b", type = str, help = "Path to model B")
parser.add_argument("-k", "--keep_layers", type = int, default = 0, help = "Maintain state from model A for this many layers")
parser.add_argument("-tkm", "--topk_max", type = int, default = 5, help = "Max top-K interval to test")

args = parser.parse_args()
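
# Example invocation (hypothetical paths; any eval set in the .parquet format
# that get_tokens() below accepts will do):
#
#   python model_diff.py -ma /models/llama2-7b-fp16 \
#       -mb /models/llama2-7b-4.0bpw-exl2 \
#       -ed /data/wikitext2_test.parquet -er 20 -el 2048 -tkm 5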

# Initialize both models

print(f" -- Model A: {args.model_a}")
print(f" -- Model B: {args.model_b}")

config = (ExLlamaV2Config(), ExLlamaV2Config())
config[0].model_dir = args.model_a
config[1].model_dir = args.model_b
config[0].prepare()
config[1].prepare()
config[0].max_batch_size = 1
config[1].max_batch_size = 1

model = (ExLlamaV2(config[0]), ExLlamaV2(config[1]))
model[0].load(lazy = True)
model[1].load(lazy = True)

num_modules = len(model[0].modules)
assert len(model[1].modules) == num_modules

# Tokenizer

print(f" -- Loading tokenizer")
tokenizer = ExLlamaV2Tokenizer(config[0])
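
# Note: the tokenizer is taken from model A's directory only; the comparison
# below assumes both models share the same tokenizer and vocabulary.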

with torch.no_grad():

    # Input

    print(f" -- Tokenizing eval data")
    eval_tokens = get_tokens(args.eval_rows, args.eval_length, args.eval_dataset, tokenizer)
    num_rows, seq_len = eval_tokens.shape

    eval_tokens = [eval_tokens[i:i+1, :] for i in range(eval_tokens.shape[0])]
    attn_mask = model[0].build_attn_mask(1, seq_len, 0, None, "cuda:0")

    # Get embeddings

    print(f" -- Embeddings")
    hidden_state = [[], []]
    for i in [0, 1]:
        module = model[i].modules[0]
        module.load()
        for j in range(num_rows):
            hidden_state[i].append(module.forward(eval_tokens[j]))
        module.unload()

    # Forward

    rfn_error = []
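
    # Propagate both hidden-state streams through the remaining modules one
    # module at a time, keeping A and B independent (except for the first
    # --keep_layers modules, where B's state is overwritten with A's). After
    # each module, record the relative Frobenius-norm error
    # ||h_b - h_a||_F / ||h_a||_F, averaged over eval rows.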

    for idx in range(1, num_modules):

        for i in [0, 1]:

            module = model[i].modules[idx]
            if i == 0:
                print(f" -- {module.key + ' (' + module.name + ')':40}", end = "")

            module.load()

            for j in range(num_rows):
                if i == 1 and idx <= args.keep_layers:
                    hidden_state[1][j] = hidden_state[0][j].clone()
                else:
                    x = hidden_state[i][j].to("cuda:0")
                    x = module.forward(x, cache = None, attn_mask = attn_mask, past_len = 0, loras = None, position_offsets = None)
                    hidden_state[i][j] = x.to("cpu")
                    x = None

            module.unload()
            module = None

        max_error_ = 0
        rfn_error_sum = 0
        mse_sum = 0

        for j in range(num_rows):

            x = hidden_state[0][j].to("cuda:0").float()
            y = hidden_state[1][j].to("cuda:0").float()
            rfn_error_sum += (torch.linalg.norm(y[0] - x[0], 'fro') / torch.linalg.norm(x[0], 'fro')).item()
            x = None
            y = None

        rfn_error_ = rfn_error_sum / num_rows
        print(f" rfn_error: {rfn_error_:8.6f}")
        rfn_error.append(rfn_error_)

    # Test outputs

    def ppl(input_ids_, logits_):

        logprob_sum_ = 0.0
        logprob_count_ = 0

        chunksize = logits_.shape[1] * 16000 // logits_.shape[2]
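        # Score the logits in chunks so the float32 upcast below stays bounded
        # in memory; the chunk length shrinks as the vocabulary size grows.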
        b_ = 0
        while b_ < logits_.shape[1]:
            a_ = b_
            b_ = min(b_ + chunksize, logits_.shape[1])

            logits_f = logits_[:, a_:b_, :].float() + 1e-10
            target_ids = input_ids_[:, a_ + 1:b_ + 1].to(logits_.device)

            log_probs = F.log_softmax(logits_f, dim = -1)
            token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
            logprob_sum_ += token_log_probs.sum().item()
            logprob_count_ += target_ids.numel()

        return logprob_sum_, logprob_count_
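    # Perplexity per model is exp(-logprob_sum / logprob_count), i.e. the
    # exponentiated mean negative log-likelihood of the reference tokens.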

    topk_max = args.topk_max
    logprob_sum = [0, 0]
    logprob_count = [0, 0]
    kl_div_sum = 0
    kl_div_count = 0
    mse_sum = 0
    mse_count = 0
    topk_hits_sum = [[0] * topk_max, [0] * topk_max]
    topk_hits_count = [[0] * topk_max, [0] * topk_max]
    topk_agreement_sum = [0] * topk_max
    topk_agreement_count = [0] * topk_max

    print(f" -- Testing outputs")

    b = 0
    for j in range(num_rows):

        # Perplexity

        x = (hidden_state[0][j].to("cuda:0"), hidden_state[1][j].to("cuda:0"))
        input_ids = eval_tokens[j]

        top_indices = []

        for i in [0, 1]:
            logits = x[i][:, :-1, :]
            logprob_sum__, logprob_count__ = ppl(input_ids, logits)
            logprob_sum[i] += logprob_sum__
            logprob_count[i] += logprob_count__

            _, top_index = torch.topk(logits, topk_max, dim = -1)
            top_index = top_index.cpu().view(-1, topk_max)
            top_indices.append(top_index)
            targets = input_ids[:, 1:].view(-1, 1)
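
            # Count how often the reference next token lands in this model's
            # top-(t+1) predictions, for t = 0 .. topk_max - 1.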
            for t in range(topk_max):
                top_slice = top_index[:, :t + 1]
                hits = torch.eq(targets, top_slice)
                row_hits = hits.any(dim = 1)
                topk_hits_sum[i][t] += row_hits.sum().item()
                topk_hits_count[i][t] += top_slice.shape[0]
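
        # Count positions where models A and B produce the same top-(t+1)
        # token list, element for element (ordered top-K agreement).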
        for t in range(topk_max):
            top_slice_a = top_indices[0][:, :t + 1]
            top_slice_b = top_indices[1][:, :t + 1]
            hits = torch.eq(top_slice_a, top_slice_b)
            row_hits = hits.all(dim = 1)
            topk_agreement_sum[t] += row_hits.sum().item()
            topk_agreement_count[t] += top_slice_a.shape[0]
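
        # F.kl_div takes log-probabilities as its first argument, so this
        # accumulates KL(P_B || P_A) per position, epsilon-smoothed and
        # averaged over the sequence.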
        epsilon = 1e-10
        probs_a = torch.softmax(x[0].float(), dim = -1)
        probs_b = torch.softmax(x[1].float(), dim = -1)
        kl_div = F.kl_div(torch.log(probs_a + epsilon), probs_b, reduction = 'none')
        kl_div_sum += kl_div.sum(dim = -1).mean().item()

        mse_sum += F.mse_loss(probs_a, probs_b).item()
        mse_count += 1

    perplexity = (math.exp(-logprob_sum[0] / logprob_count[0]), math.exp(-logprob_sum[1] / logprob_count[1]))
    mse = mse_sum / mse_count
    kl_div = kl_div_sum / num_rows

    a_acc = []
    b_acc = []
    a_acc_str = ""
    b_acc_str = ""
    agree_str = ""
    topk_agree = []
    for t in range(topk_max):
        a_acc_ = topk_hits_sum[0][t] / topk_hits_count[0][t]
        b_acc_ = topk_hits_sum[1][t] / topk_hits_count[1][t]
        topk_agree_ = topk_agreement_sum[t] / topk_agreement_count[t]
        a_acc.append(a_acc_)
        b_acc.append(b_acc_)
        topk_agree.append(topk_agree_)
        a_acc_str += f"{a_acc_:6.4f} "
        b_acc_str += f"{b_acc_:6.4f} "
        agree_str += f"{topk_agree_:6.4f} "

    # CSV output

    print()
    print("-----------------")
    print()
    print(";".join([f"{p:.8f}" for p in perplexity]))
    print()
    print(f"{kl_div:.8f}")
    print(f"{mse:.8f}")
    print()
    for i in range(topk_max):
        print(f"{i+1};{a_acc[i]:.8f};{b_acc[i]:.8f};{topk_agree[i]:.8f}")
    print()
    for idx, err in enumerate(rfn_error):
        print(f"{idx};{err:.8f}")
    print()
    print("-----------------")
    print()

    # Results

    print(f" -- A, ppl: {perplexity[0]:11.8f} acc: {a_acc_str}")
    print(f" -- B, ppl: {perplexity[1]:11.8f} acc: {b_acc_str}")
    print(f" -- Top-K agreement: {agree_str}")
    print(f" -- KL divergence: {kl_div:11.8f}")
    print(f" -- MSE: {mse:11.8f}")
@@ -21,24 +21,25 @@ import torch

# model_base = "/mnt/str/models/"
# variants = ["mistral-7b-instruct"]
# model_base = "/mnt/str/models/mistral-7b-instruct-exl2"
# variants = ["2.5bpw"]

# model_base = "/mnt/str/models/mistral-7b-instruct-exl3"
# variants = ["8.0bpw"]
model_base = "/mnt/str/models/mixtral-8x7b-instruct-exl2/"
variants = ["4.0bpw"]

# model_base = "/mnt/str/models/tiefighter-13b-exl4/"

# variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]

variants = \
[
    "2.4bpw",
    "2.5bpw",
    "2.7bpw",
    "3.0bpw",
    "4.0bpw",
    "6.0bpw",
    "8.0bpw",
]
# variants = \
# [
#     "2.4bpw",
#     "2.5bpw",
#     "2.7bpw",
#     "3.0bpw",
#     "4.0bpw",
#     "6.0bpw",
#     "8.0bpw",
# ]

gpu_split = (16, 16, 24)

@@ -81,8 +82,8 @@ for variant in variants:
     torch.cuda.empty_cache()
     gc.collect()

-    batch_size = 20
-    num_samples_per_task = 5
+    batch_size = 10
+    num_samples_per_task = 1
     samples = []

     model, cache, tokenizer = get_model(model_base, variant, gpu_split, batch_size)

@@ -14,21 +14,22 @@ import torch

# Models to test

#model_base = "/mnt/str/models/_exl2"
# model_base = "/mnt/str/models/_exl2"
# model_base = "/mnt/str/models/mixtral-8x7b-instruct-exl2/"
model_base = "/mnt/str/models/tiefighter-13b-exl4/"

model_base = "/mnt/str/models/_gptq/TheBloke_Yi-34B-GPTQ/"
variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]

# variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
# variants = \
# [
#     "2.4bpw",
#     "2.5bpw",
#     "3.0bpw",
#     "4.0bpw",
#     "6.0bpw",
# ]

variants = \
[
    "",
    # "goliath-120b-exl2/3.0bpw",
    # "llama2-70b-exl2/3.0bpw",
    # "llama2-70b-exl2/4.65bpw",
]

gpu_split = (21.2, 24)
gpu_split = (20, 21.3, 24)

qa_set = "cais/mmlu"
qa_split = "test"
@@ -36,12 +37,11 @@ qa_split = "test"
 categories = \
 [
     "anatomy",
-    # "computer_security",
-    # "formal_logic",
-    # "logical_fallacies",
-    # "computer_security",
-    # "philosophy",
-    # "nutrition",
+    "computer_security",
+    "formal_logic",
+    "logical_fallacies",
+    "philosophy",
+    "nutrition",
 ]

 examples_per_category = 3

@@ -56,6 +56,7 @@ def get_model(base, variant_, gpu_split_, batch_size_):
     config = ExLlamaV2Config()
     config.model_dir = model_dir
     config.prepare()
+    config.max_seq_len = 2048
     config.max_batch_size = batch_size_

     model_ = ExLlamaV2(config)