# File: exllamav2/tests/compare_quants.py
# 2024-04-30 20:38:26 +02:00
#
# 180 lines
# 6.3 KiB
# Python

import torch
from tqdm import tqdm
from datasets import load_dataset
import torch.nn as nn
from awq import AutoAWQForCausalLM
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
import os, json, gc
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from py_markdown_table.markdown_table import markdown_table
# Evaluation window length in tokens. evaluate_perplexity() splits the token
# stream into chunks of this size, and model_instance_exl2() uses the same
# value for the ExLlamaV2 config's max_input_len / max_output_len / max_seq_len.
seqlen = 2048
def evaluate_perplexity(dataset, model, tokenizer, get_tokens, get_logits):
    """Compute the perplexity of ``model`` over ``dataset``.

    The text is tokenized once with ``get_tokens(tokenizer, model, dataset)``
    and cut into consecutive, non-overlapping windows of the module-level
    ``seqlen`` tokens; a trailing partial window is dropped. Each window is
    scored with next-token cross-entropy and the running perplexity is shown
    in the progress bar description.

    Args:
        dataset: Passed straight through to ``get_tokens``.
        model: Opaque model object, passed to ``get_tokens`` and ``get_logits``.
        tokenizer: Opaque tokenizer object, passed to ``get_tokens``.
        get_tokens: Callable ``(tokenizer, model, dataset) -> tensor`` of token
            ids. The result is sliced as ``data[:, a:b]``, so it is treated as
            2-D -- presumably ``(1, n_tokens)``; confirm against the helpers.
        get_logits: Callable ``(model, batch) -> logits`` whose result is
            sliced as ``logits[:, :-1, :]``.

    Returns:
        The perplexity over all full windows, as a Python float.

    Raises:
        ValueError: If the tokenized text is shorter than one full window.
    """
    data = get_tokens(tokenizer, model, dataset)
    n_samples = data.numel() // seqlen
    if n_samples == 0:
        # Without this guard torch.stack() fails on an empty list with an
        # error message that says nothing about the actual cause.
        raise ValueError(
            f"Dataset too short: {data.numel()} tokens, but at least "
            f"{seqlen} are required for one evaluation window"
        )

    loss_fct = nn.CrossEntropyLoss()
    nlls = []

    def _perplexity(n_windows):
        # exp(total negative log-likelihood / number of tokens accounted for)
        return torch.exp(torch.stack(nlls).sum() / (n_windows * seqlen))

    with tqdm(range(n_samples), desc = "Perplexity -") as progress_bar:
        for i in progress_bar:
            batch = data[:, i * seqlen:(i + 1) * seqlen]

            with torch.no_grad():
                logits = get_logits(model, batch)

            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:].to(shift_logits.device)

            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            # The mean loss is scaled by seqlen (not seqlen - 1); this matches
            # the normalisation in _perplexity and is kept unchanged so the
            # reported numbers stay comparable with earlier runs.
            nlls.append(loss.float() * seqlen)

            curr_ppl = _perplexity(i + 1)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    return _perplexity(n_samples).item()
def get_logits_hf(model, batch):
    """Call ``model`` on ``batch`` and return the ``logits`` attribute of the result."""
    output = model(batch)
    return output.logits
def get_tokens_hf(tokenizer, model, text):
    """Tokenize ``text`` as a single document and move the ids to the model's device.

    The strings in ``text`` are joined with blank lines, passed to ``tokenizer``
    with ``return_tensors="pt"``, and the resulting ``input_ids`` are moved to
    ``model.device``.
    """
    document = "\n\n".join(text)
    encoded = tokenizer(document, return_tensors="pt")
    return encoded.input_ids.to(model.device)
def get_logits_exl2(model, batch):
    """Return the result of ``model.forward(batch)``."""
    logits = model.forward(batch)
    return logits
def get_tokens_exl2(tokenizer, model, text):
    """Encode ``text`` as a single document with ``tokenizer.encode``.

    The strings in ``text`` are joined with blank lines and encoded with
    ``add_bos = True``. ``model`` is unused; it is accepted so the signature
    matches ``get_tokens_hf``.
    """
    document = "\n\n".join(text)
    return tokenizer.encode(document, add_bos = True)
def get_dataset(ds_path, ds_name, ds_split, cache_file, max_rows):
    """Return up to ``max_rows`` text rows of a dataset, cached as JSON on disk.

    The cache lives next to this module under the name ``cache_file``. If it
    exists it is returned as-is (``max_rows`` is not re-applied). Otherwise
    the dataset is streamed with ``load_dataset``, the ``"text"`` field of the
    first ``max_rows`` rows is collected, written to the cache and returned.

    Args:
        ds_path: Dataset path passed to ``load_dataset``.
        ds_name: Dataset configuration name passed to ``load_dataset``.
        ds_split: Split passed to ``load_dataset`` as ``split``.
        cache_file: File name of the JSON cache, relative to this module's
            directory.
        max_rows: Maximum number of rows to collect. A value <= 0 yields an
            empty list (previously the countdown never hit zero and the whole
            stream was consumed).

    Returns:
        A list of strings.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    filename = os.path.join(module_dir, cache_file)

    if os.path.exists(filename):
        print("Loading cached dataset...")
        with open(filename, "r", encoding = "utf-8") as f:
            return json.load(f)

    print(f"Loading dataset: {ds_path}...")
    c_rows = []
    if max_rows > 0:
        c_dataset = load_dataset(ds_path, ds_name, split = ds_split, streaming = True)
        for row in c_dataset:
            c_rows.append(row["text"])
            if len(c_rows) >= max_rows:
                break

    with open(filename, "w", encoding = "utf-8") as f:
        json.dump(c_rows, f, indent = 4)
    return c_rows
def model_instance_awq(model_dir):
    """Load an AWQ model and its tokenizer from ``model_dir``.

    Returns ``(model, tokenizer)`` where ``model`` is the ``.model`` attribute
    of the ``AutoAWQForCausalLM`` wrapper, switched to eval mode.
    """
    wrapper = AutoAWQForCausalLM.from_pretrained(model_dir, device_map="auto")
    inner_model = wrapper.model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return inner_model.eval(), tokenizer
def model_instance_gptq(model_dir):
    """Load a GPTQ-quantized model and its tokenizer from ``model_dir``.

    Returns ``(model, tokenizer)`` where ``model`` is the ``.model`` attribute
    of the ``AutoGPTQForCausalLM`` wrapper, switched to eval mode.
    """
    wrapper = AutoGPTQForCausalLM.from_quantized(model_dir, device_map="auto")
    inner_model = wrapper.model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return inner_model.eval(), tokenizer
def model_instance_exl2(model_dir):
    """Load an ExLlamaV2 model and its tokenizer from ``model_dir``.

    All three length limits on the config are set to the module-level
    ``seqlen`` before the model is constructed and loaded.

    Returns:
        ``(model, tokenizer)``.
    """
    config = ExLlamaV2Config(model_dir)
    for limit in ("max_input_len", "max_output_len", "max_seq_len"):
        setattr(config, limit, seqlen)

    model = ExLlamaV2(config)
    model.load()

    # Alternative loading strategies kept from the original for reference:
    # model.load(gpu_split = [0,24,0,0])
    # cache = ExLlamaV2Cache(model, lazy = True)
    # model.load_autosplit(cache)

    tokenizer = ExLlamaV2Tokenizer(config, lazy_init = True)
    return model, tokenizer
def flush():
    """Synchronize CUDA, run the Python garbage collector, then empty the CUDA cache."""
    cuda = torch.cuda
    cuda.synchronize()
    gc.collect()
    cuda.empty_cache()
def _backend_for(fw):
    """Return ``(load, get_tokens, get_logits, needs_unload)`` for framework key ``fw``."""
    if fw == "awq":
        return model_instance_awq, get_tokens_hf, get_logits_hf, False
    if fw == "gptq":
        return model_instance_gptq, get_tokens_hf, get_logits_hf, False
    if fw in ("exl2", "exl2_fp16"):
        return model_instance_exl2, get_tokens_exl2, get_logits_exl2, True
    raise ValueError(
        f"Unknown framework {fw!r}; expected one of: awq, gptq, exl2, exl2_fp16"
    )


def run_tests(test_models):
    """Evaluate perplexity and peak CUDA memory for each model in ``test_models``.

    Args:
        test_models: Iterable of ``(framework, model_dir, label)`` tuples, where
            ``framework`` is one of ``"awq"``, ``"gptq"``, ``"exl2"`` or
            ``"exl2_fp16"``.

    Returns:
        A list with one dict per model, holding the label (under the empty
        key), the perplexity on each of the three datasets formatted to three
        decimals, and the peak allocated CUDA memory summed over all devices.

    Raises:
        ValueError: If any framework key is not recognised. This is checked
            before any dataset or model is loaded; previously an unknown key
            either crashed with NameError or silently reused the previous
            model's perplexities.
    """
    test_models = list(test_models)
    for fw, _, _ in test_models:
        _backend_for(fw)  # fail fast on a typo rather than after hours of evaluation

    wikitext = get_dataset("wikitext", "wikitext-2-raw-v1", "test", "wikitext_cached.json", 4358)
    c4 = get_dataset("allenai/c4", "en", "validation", "c4_cached.json", 750)
    fineweb = get_dataset("HuggingFaceFW/fineweb", "default", "train", "fineweb_cached.json", 500)
    datasets = [wikitext, c4, fineweb]
    # datasets = [d[:100] for d in datasets]

    results = []
    num_devices = torch.cuda.device_count()

    for fw, model_dir, text in test_models:
        # Start each model with freed memory and fresh peak counters.
        flush()
        for d in range(num_devices):
            torch.cuda.reset_peak_memory_stats(d)

        print(f"Loading {fw}: {model_dir}")
        load, get_tokens, get_logits, needs_unload = _backend_for(fw)
        model, tokenizer = load(model_dir)
        ppl = [evaluate_perplexity(ds, model, tokenizer, get_tokens, get_logits) for ds in datasets]
        if needs_unload:
            model.unload()
        # Drop the references so the next iteration's flush() can reclaim memory.
        model, tokenizer = None, None

        total_memory = sum(torch.cuda.max_memory_allocated(d) for d in range(num_devices))
        max_mem_gb = total_memory / (1024 ** 3)
        print(f"Max CUDA memory: {max_mem_gb:.2f} GB")

        results.append({
            "": text,
            "Wikitext": f"{ppl[0]:.3f}",
            "C4": f"{ppl[1]:.3f}",
            "FineWeb": f"{ppl[2]:.3f}",
            "Max VRAM": f"{max_mem_gb:.2f} GB"
        })

    return results
if __name__ == "__main__":

    # Each entry is (framework key, model directory, label shown in the table).
    fp16_baseline = [
        ("exl2_fp16", "/mnt/str/models/llama3-8b-instruct", "FP16"),
    ]
    other_quants = [
        ("awq", "/mnt/str/models/_awq/llama3-8b-instruct-awq/", "AWQ"),
        ("gptq", "/mnt/str/models/_gptq/mistral-8b-instruct-v0.2-gptq/4bit-act-128g/", "GPTQ 4b-128g-act"),
    ]
    exl2_quants = [
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw", "EXL2 4.00 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/4.15bpw", "EXL2 4.15 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/5.0bpw", "EXL2 5.00 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/5.3bpw", "EXL2 5.30 bpw"),
    ]

    table_rows = run_tests(fp16_baseline + other_quants + exl2_quants)

    # Render the collected rows as a Markdown table and print it.
    table = markdown_table(table_rows).set_params(row_sep = 'markdown', quote = False)
    print(table.get_markdown())