Files
exllamav2/util/estimate_kld.py
2023-12-08 20:19:57 +01:00

166 lines
4.8 KiB
Python

from conversion.qparams import QParams, qparams_options
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Tokenizer
from conversion.tokenize import get_tokens, get_standard_calibration
from conversion.qparams_stats import qparams_stats
import torch
import torch.nn.functional as F
import sys, math, json, os
from safetensors import safe_open
# ---- Hard-coded paths for this experiment ----
# Directory of the unquantized reference model; handed to ExLlamaV2Config below.
model_dir = "/mnt/str/models/_exl2/llama2-7b/"
# Directory that replace_layer() searches for per-layer quantized .safetensors files.
tensor_dir = "/mnt/str/models/_exl2/__giga/out_tensor/"
# NOTE(review): measurement_file and log_file are not referenced anywhere in the
# visible part of this script - confirm whether they are still needed.
measurement_file = "/mnt/str/models/_exl2/__giga/measurement_1.json"
log_file = "/mnt/str/models/_exl2/__giga/log.csv"
# ---- Load the model and tokenizer (module-level side effect: runs on import) ----
config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
# ---- Evaluation data ----
# Alternative token source, currently disabled:
# eval_tokens = get_standard_calibration(True, tokenizer)
eval_dataset = "/mnt/str/datasets/c4_sample.parquet"
eval_rows = 32
eval_length = 2048
# ppl_test() walks eval_tokens one row at a time along dim 0.
# Presumably shaped (eval_rows, eval_length) - TODO confirm against get_tokens().
eval_tokens = get_tokens(eval_rows, eval_length, eval_dataset, tokenizer)
# Per-row softmax outputs of the reference model. Filled by ppl_test(reference = True)
# and read back, by row index, on every later non-reference call.
ref_probs = []
def ppl_test(reference = False, ref_device = "cuda:1"):
    """
    Run the currently loaded model over every row of `eval_tokens`.

    Perplexity is exp(-mean log-probability) of each next token, pooled over all
    rows. The per-position softmax outputs are additionally either recorded (the
    reference pass) or compared against the recorded ones (every other pass).

    Args:
        reference: When True, store this run's probabilities in the global
            `ref_probs` (one tensor per row, replacing earlier contents) and
            report a KL divergence of 0. When False, compare against the stored
            tensors; a reference pass must have been run first.
        ref_device: Device on which the probability tensors are kept and the KL
            divergence is evaluated. Defaults to the previously hard-coded
            "cuda:1".

    Returns:
        (perplexity, avg_kl_div). avg_kl_div is the mean over all rows of the
        per-row KL divergence, where each row's value is summed over the softmax
        dimension and averaged over positions. With F.kl_div(log(ref), probs)
        the quantity measured is KL(current || reference).
    """
    global ref_probs

    # A repeated reference pass must not leave stale tensors from an earlier
    # pass in front of the new ones: comparisons look rows up by index.
    if reference:
        ref_probs.clear()

    cache = None
    logprob_sum = 0.0
    logprob_count = 0
    kl_div_sum = 0.0
    num_rows = eval_tokens.shape[0]

    for i in range(num_rows):

        input_ids = eval_tokens[i:i + 1, :]
        if cache is not None: cache.current_seq_len = 0

        # Drop the final position: it has no next token to be scored against
        logits = model.forward(input_ids, cache)
        logits = logits[:, :-1, :]
        logits = logits.float() + 1e-10
        target_ids = input_ids[:, 1:].to(logits.device)

        # Distribution for this row, kept on the side device
        probs = F.softmax(logits[0].to(ref_device), dim = -1)
        if reference:
            ref_probs.append(probs)
        else:
            rprobs = torch.log(ref_probs[i] + 1e-10)
            kl_div = F.kl_div(rprobs, probs, reduction = 'none')
            # Accumulate so the result reflects every row, not only the last one
            kl_div_sum += kl_div.sum(dim = 1).mean().item()

        # Log-probability assigned to each actual next token
        log_probs = F.log_softmax(logits, dim = -1)
        token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum += token_log_probs.sum().item()
        logprob_count += target_ids.numel()

    mean_log_prob = logprob_sum / logprob_count
    perplexity = math.exp(-mean_log_prob)
    avg_kl_div = 0 if reference else kl_div_sum / num_rows

    return perplexity, avg_kl_div
# Reference pass over the unmodified model: fills ref_probs for the later
# comparisons and yields the baseline perplexity. The second return value is
# discarded because a reference pass reports no divergence.
base_ppl, _ = ppl_test(reference = True)
print("Base perplexity:", base_ppl)
def replace_layer(key, qp):
    """
    Reload the module `key` from a pre-quantized tensor file.

    The file is looked up in `tensor_dir` under the name
    "<qp.get_desc(True)>____<key>.safetensors". While the module loads, the
    entry for "<key>.weight" is taken out of `config.tensor_file_map` and every
    tensor name found in the file is mapped to that file. Afterwards the map is
    put back the way it was, whether or not loading succeeded, so a later
    unreplace_layer() reloads the original weight.

    Args:
        key: Module key, used both for `model.modules_dict` and as the tensor
            name prefix.
        qp: Quantization parameters; only `get_desc(True)` is used here, to
            build the file name.

    Raises:
        FileNotFoundError: The expected tensor file does not exist. This is
            checked before anything is modified.
        KeyError: `key` is not in `model.modules_dict`.
    """
    weight_key = key + ".weight"
    tensor_map = config.tensor_file_map

    # Resolve and validate everything up front so a bad key or a missing file
    # cannot leave the model with an unloaded module and a half-edited map
    fdesc = qp.get_desc(True)
    tensor_file = os.path.join(tensor_dir, fdesc + "____" + key + ".safetensors")
    if not os.path.exists(tensor_file):
        raise FileNotFoundError(f"Quantized tensor file not found: {tensor_file}")

    module = model.modules_dict[key]

    # Remove original weight (tolerating a map that has no entry for it)
    original_file = tensor_map.pop(weight_key, None)
    injected_keys = []

    try:
        module.unload()

        # Insert q tensor components into map
        with safe_open(tensor_file, framework = "pt", device = "cpu") as f:
            for k in f.keys():
                tensor_map[k] = tensor_file
                injected_keys.append(k)

        # Load the quantized layer
        module.load()

    finally:
        # Withdraw the temporary entries and restore the original weight mapping
        for k in injected_keys:
            tensor_map.pop(k, None)
        if original_file is not None:
            tensor_map[weight_key] = original_file
def unreplace_layer(key):
    """
    Unload the module registered under `key` and load it again.

    Used after replace_layer(): by then `config.tensor_file_map` points at the
    original weight again, so the fresh load brings back the unmodified layer.
    """
    target = model.modules_dict[key]
    target.unload()
    target.load()
layers = [0, 1, 2, 16, 24, 31]
results = []

# Sub-module suffixes, in the same order as the seven strategy slots of an entry
proj_suffixes = (
    ".self_attn.q_proj",
    ".self_attn.k_proj",
    ".self_attn.v_proj",
    ".self_attn.o_proj",
    ".mlp.gate_proj",
    ".mlp.up_proj",
    ".mlp.down_proj",
)

# The output is a Python literal that can be pasted back in as `qparams_stats`
print("qparams_stats = \\")
print("[")

for entry in qparams_stats:

    print(" [")

    # Echo whatever the entry already holds
    for item in entry:
        if isinstance(item, QParams):
            print(f" {item!s},")
        elif item is None:
            print(" None,")
        else:
            print(f" {item:1.10f},")

    # Only entries with exactly seven elements (one per sub-module) are measured;
    # any other length is printed back unchanged.
    if len(entry) == 7:

        # Slots holding a falsy value (e.g. None) are skipped
        active = [(suffix, qp) for suffix, qp in zip(proj_suffixes, entry) if qp]

        # NOTE(review): nesting reconstructed from a source without indentation.
        # The restore calls reuse the per-layer key, so replace -> test -> restore
        # is taken to happen once per layer index - confirm against the repository.
        for layer_idx in layers:
            lkey = "model.layers." + str(layer_idx)

            # Replace
            for suffix, qp in active:
                replace_layer(lkey + suffix, qp)

            # Test
            new_ppl, kldiv = ppl_test()
            print(f" {kldiv:1.10f},")

            # Restore
            for suffix, _qp in active:
                unreplace_layer(lkey + suffix)

    print(" ],")

print("]")