Files
exllamav2/tests/test_alloc.py
2023-09-10 06:15:33 +02:00

164 lines
4.4 KiB
Python

import sys, os, math
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
ExLlamaV2BaseGenerator,
ExLlamaV2Sampler
)
import time
import torch
# The allocation to test
#
# `model_directory` is the model folder handed to ExLlamaV2Config, and
# `allocation` is the list passed to model.load(). This script later prints
# each allocation entry multiplied by 1024 with an "MB" suffix, i.e. it treats
# the entries as GB, indexed in the same order as the CUDA devices.
#
# Alternative models (swap the active `model_directory` line to test one):
# model_directory = "/mnt/str/models/_exl2/openllama-3b-3.0bpw-h6-exl2/"
# model_directory = "/mnt/str/models/_exl2/llama-7b-3.0bpw-h6-exl2/"
# model_directory = "/mnt/str/models/_exl2/llama2-70b-chat-2.5bpw-h6-exl2/"
model_directory = "/mnt/str/models/_exl2/codellama-34b-instruct-4.0bpw-h6-exl2/"
allocation = [18, 24]
# Prime CUDA and initialize mem measurement
#
# Every visible CUDA device gets one small allocation and one multiply before
# any baseline is taken (presumably so one-time CUDA start-up allocations are
# not attributed to the model -- confirm).

device_count = torch.cuda.device_count()
torch_devices = [f"cuda:{i}" for i in range(device_count)]
torch.cuda.init()

# One 1024x1024 float tensor per device, then a doubled copy of each.
temp = [
    torch.randn((1024, 1024), dtype = torch.float, device = x)
    for x in torch_devices
]
temp2 = [x * 2 for x in temp]

# Drop the scratch tensors and return the cached blocks to the driver.
temp = []
temp2 = []
torch.cuda.empty_cache()

# Record the per-device baseline: reset the peak counter, then read it back.
mem_base = {}
for dev in torch_devices:
    torch.cuda.reset_peak_memory_stats(dev)
    mem_base[dev] = torch.cuda.max_memory_allocated(dev)
# Initialize and load model
#
# Builds the config from `model_directory`, loads the model with the
# per-device `allocation`, then restarts peak-memory tracking on every device
# so that the final report measures from the loaded state.

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
# Set after prepare(), so this value is the one the model is built with.
config.max_seq_len = 8192

model = ExLlamaV2(config)
print("Loading model: " + model_directory)
# With stats = True, load() returns a pair; only the second element is used
# here (one entry per allocation entry, subtracted from it further down).
_, stats = model.load(allocation, stats = True)

torch.cuda.empty_cache()
# Reset the peak counter on *every* device. The previous code reused the loop
# variable left over from the baseline loop, so only the last device was
# reset and the others kept their load-time peak.
for dev in torch_devices:
    torch.cuda.reset_peak_memory_stats(dev)
# Load tokenizer

tokenizer = ExLlamaV2Tokenizer(config)

# Initialize and measure cache
#
# `cache_fp` holds one footprint value per device; this script treats the
# values as bytes (they are divided by 1024**3 before being added to GB
# figures).

cache = ExLlamaV2Cache(model)
cache_fp = cache.footprint()

# Expected usage per device, in the same unit as `allocation`: each allocation
# entry minus the matching entry of `stats` (presumably the budget left unused
# by the loader -- confirm against ExLlamaV2.load).
expected = [budget - remaining for budget, remaining in zip(allocation, stats)]

# Same figures with the cache footprint added on top.
expected_with_cache = list(expected)
for dev_idx, footprint in enumerate(cache_fp):
    expected_with_cache[dev_idx] += footprint / 1024**3
# Initialize generator

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Generate some text
#
# Sampling settings for the benchmark run. The EOS token is disallowed; the
# throughput figure below divides by `max_new_tokens`, i.e. it assumes the
# full token budget is always generated.

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.85
settings.top_k = 50
settings.top_p = 0.8
settings.token_repetition_penalty = 1.15
settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

prompt = "Our story begins in the Scottish town of Auchtermuchty, where once"
max_new_tokens = 150

# Warm up first so the timed section only covers the generation call.
generator.warmup()

time_begin = time.time()
output = generator.generate_simple(prompt, settings, max_new_tokens, seed = 1234)
time_end = time.time()
time_total = time_end - time_begin

print(output)
print()

tokens_per_second = max_new_tokens / time_total
print(
    f"Response generated in {time_total:.2f} seconds, "
    f"{max_new_tokens} tokens, {tokens_per_second:.2f} tokens/second"
)
print()
# Prompt processing benchmark: push a random sequence of max_seq_len - 1
# tokens through the model in preprocess-only mode and time it.

prompt_len = model.config.max_seq_len - 1
print(f"Prompt processing, {prompt_len} tokens...")

# Start from an empty cache.
cache.current_seq_len = 0

# The timed section includes creating the random input ids.
time_begin = time.time()
input_ids = torch.randint(0, model.config.vocab_size - 1, (1, prompt_len))
model.forward(input_ids, cache, preprocess_only = True)
# Wait for queued GPU work to finish before stopping the clock.
torch.cuda.synchronize()
time_end = time.time()
time_total = time_end - time_begin

print(f"Prompt processed in {time_total:.2f} seconds, {prompt_len / time_total:.2f} tokens/second")
print()
# Report
#
# One row per metric; each row lists every CUDA device as "[cuda:N] x MB",
# with devices separated by " - ".

reported_cells = []        # peak bytes seen by Torch, minus the baseline
expected_cells = []        # allocation minus loader stats
expected_cache_cells = []  # the above plus the cache footprint
allocated_cells = []       # the configured allocation
cache_cells = []           # cache footprint alone

# Running totals in bytes; only consumed by the optional diagnostic at the
# very end of this section.
mem_total = 0
mem_exp = 0

for idx, device in enumerate(torch_devices):
    mem_this = torch.cuda.max_memory_allocated(device) - mem_base[device]
    mem_total += mem_this
    mem_exp += expected_with_cache[idx] * 1024 ** 3

    # Devices beyond the end of cache_fp are shown with a zero-sized cache.
    cache_bytes = cache_fp[idx] if idx < len(cache_fp) else 0

    reported_cells.append(f"[{device}] {mem_this / (1024 ** 2):,.2f} MB")
    expected_cells.append(f"[{device}] {expected[idx] * 1024:,.2f} MB")
    expected_cache_cells.append(f"[{device}] {expected_with_cache[idx] * 1024:,.2f} MB")
    allocated_cells.append(f"[{device}] {allocation[idx] * 1024:,.2f} MB")
    cache_cells.append(f"[{device}] {cache_bytes / (1024 ** 2):,.2f} MB")

separator = " - "
res1 = " ** VRAM reported by Torch : " + separator.join(reported_cells)
res2 = " ** VRAM expected : " + separator.join(expected_cells)
res3 = " ** VRAM expected (with cache) : " + separator.join(expected_cache_cells)
res4 = " ** VRAM allocated (max) : " + separator.join(allocated_cells)
res5 = " ** Cache size : " + separator.join(cache_cells)

# Printed from the configured budget down to what Torch actually measured.
for row in (res4, res2, res5, res3, res1):
    print(row)
print()

print(f"Max sequence length: {config.max_seq_len}")
print(f"Hidden size: {config.hidden_size}")
print(f"Attention heads: {config.num_attention_heads}")
print(f"Key/value heads: {config.num_key_value_heads}")
print(f"Max attention size: {math.sqrt(config.max_attention_size)} ** 2")
print(f"Max input len: {config.max_input_len}")

# Optional diagnostic (disabled): measured minus expected, in bytes.
# print(f"Correction amount: {mem_total - mem_exp:,.2f} B")