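"""
Perplexity comparison script: evaluates an FP16 baseline and AWQ, GPTQ and EXL2
quantized models on WikiText-2, C4 and FineWeb, then prints the results
(perplexity per dataset and peak VRAM use) as a markdown table.
"""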
import torch
from tqdm import tqdm
from datasets import load_dataset
import torch.nn as nn
from awq import AutoAWQForCausalLM
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
import os, json, gc
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from py_markdown_table.markdown_table import markdown_table

seqlen = 2048
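
# Compute perplexity over non-overlapping windows of `seqlen` tokens, using the
# supplied tokenization and forward-pass callbacks so the same loop works for
# both HF-style and ExLlamaV2 models.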
def evaluate_perplexity(dataset, model, tokenizer, get_tokens, get_logits):
    global seqlen

    def _perplexity(nlls, n_samples, seqlen):
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    data = get_tokens(tokenizer, model, dataset)
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc = "Perplexity -") as progress_bar:
        for i in progress_bar:

            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch = data[:, start_index:end_index]
            with torch.no_grad():
                logits = get_logits(model, batch)

            # Standard next-token cross-entropy: predictions for tokens 0..n-2
            # of the window are scored against tokens 1..n-1.
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = data[:, start_index:end_index][:, 1:].to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)

    return ppl.item()
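
# Forward/tokenize helpers for Hugging Face style (AWQ / GPTQ) models.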
def get_logits_hf(model, batch):
    return model(batch).logits

def get_tokens_hf(tokenizer, model, text):
    data = tokenizer("\n\n".join(text), return_tensors="pt")
    data = data.input_ids.to(model.device)
    return data
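
# Forward/tokenize helpers for ExLlamaV2 models.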
def get_logits_exl2(model, batch):
    return model.forward(batch)

def get_tokens_exl2(tokenizer, model, text):
    data = tokenizer.encode("\n\n".join(text), add_bos = True)
    return data
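
# Stream up to `max_rows` rows of text from a Hugging Face dataset, caching the
# result as JSON next to this script so repeated runs skip the download.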
def get_dataset(ds_path, ds_name, ds_split, cache_file, max_rows):
    module_dir = os.path.dirname(os.path.abspath(__file__))
    filename = os.path.join(module_dir, cache_file)
    if os.path.exists(filename):
        print("Loading cached dataset...")
        with open(filename, "r") as f:
            return json.load(f)
    else:
        print(f"Loading dataset: {ds_path}...")
        c_dataset = load_dataset(ds_path, ds_name, split = ds_split, streaming = True)
        c_rows = []
        for row in c_dataset:
            c_rows.append(row["text"])
            max_rows -= 1
            if not max_rows: break
        with open(filename, "w") as f:
            f.write(json.dumps(c_rows, indent = 4))
        return c_rows
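
# Model loaders: each returns a (model, tokenizer) pair for one backend.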
def model_instance_awq(model_dir):
    model = AutoAWQForCausalLM.from_pretrained(model_dir, device_map="auto").model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model.eval(), tokenizer

def model_instance_gptq(model_dir):
    model = AutoGPTQForCausalLM.from_quantized(model_dir, device_map="auto").model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model.eval(), tokenizer
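
# ExLlamaV2 loader: context length is capped at `seqlen`. The commented lines
# show the alternative manual gpu_split and cache-based autosplit loading paths.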
def model_instance_exl2(model_dir):
    config = ExLlamaV2Config(model_dir)
    config.max_input_len = seqlen
    config.max_output_len = seqlen
    config.max_seq_len = seqlen
    model = ExLlamaV2(config)
    model.load()
    # model.load(gpu_split = [0,24,0,0])
    # cache = ExLlamaV2Cache(model, lazy = True)
    # model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config, lazy_init = True)
    return model, tokenizer
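
# Release as much CUDA memory as possible between model loads.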
def flush():
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()
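
# Evaluate each (framework, model_dir, label) entry on all three datasets and
# record per-dataset perplexity plus peak VRAM across all visible CUDA devices.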
def run_tests(test_models):

    wikitext = get_dataset("wikitext", "wikitext-2-raw-v1", "test", "wikitext_cached.json", 4358)
    c4 = get_dataset("allenai/c4", "en", "validation", "c4_cached.json", 750)
    fineweb = get_dataset("HuggingFaceFW/fineweb", "default", "train", "fineweb_cached.json", 500)

    datasets = [wikitext, c4, fineweb]
    # datasets = [d[:100] for d in datasets]
    results = []

    num_devices = torch.cuda.device_count()

    for (fw, model_dir, text) in test_models:

        flush()
        for d in range(num_devices): torch.cuda.reset_peak_memory_stats(d)

        print(f"Loading {fw}: {model_dir}")

        if fw == "awq":
            model, tokenizer = model_instance_awq(model_dir)
            ppl = [evaluate_perplexity(ds, model, tokenizer, get_tokens_hf, get_logits_hf) for ds in datasets]
            model, tokenizer = None, None

        if fw in ["exl2", "exl2_fp16"]:
            model, tokenizer = model_instance_exl2(model_dir)
            ppl = [evaluate_perplexity(ds, model, tokenizer, get_tokens_exl2, get_logits_exl2) for ds in datasets]
            model.unload()
            model, tokenizer = None, None

        if fw == "gptq":
            model, tokenizer = model_instance_gptq(model_dir)
            ppl = [evaluate_perplexity(ds, model, tokenizer, get_tokens_hf, get_logits_hf) for ds in datasets]
            model, tokenizer = None, None

        total_memory = sum(torch.cuda.max_memory_allocated(d) for d in range(num_devices))
        max_mem_gb = total_memory / (1024 ** 3)
        print(f"Max CUDA memory: {max_mem_gb:.2f} GB")

        results.append({
            "": text,
            "Wikitext": f"{ppl[0]:.3f}",
            "C4": f"{ppl[1]:.3f}",
            "FineWeb": f"{ppl[2]:.3f}",
            "Max VRAM": f"{max_mem_gb:.2f} GB"
        })

    return results
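
# Models to compare. The paths below are local to the author's machine; adjust
# them to point at your own FP16, AWQ, GPTQ and EXL2 checkpoints before running.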
if __name__ == "__main__":

    all_models = [
        ("exl2_fp16", "/mnt/str/models/llama3-8b-instruct", "FP16"),
        ("awq", "/mnt/str/models/_awq/llama3-8b-instruct-awq/", "AWQ"),
        ("gptq", "/mnt/str/models/_gptq/mistral-8b-instruct-v0.2-gptq/4bit-act-128g/", "GPTQ 4b-128g-act"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/4.0bpw", "EXL2 4.00 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/4.15bpw", "EXL2 4.15 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/5.0bpw", "EXL2 5.00 bpw"),
        ("exl2", "/mnt/str/models/llama3-8b-instruct-exl2/5.3bpw", "EXL2 5.30 bpw"),
    ]

    results = run_tests(all_models)
    markdown = markdown_table(results).set_params(row_sep = 'markdown', quote = False).get_markdown()
    print(markdown)