# Mirror of https://github.com/turboderp-org/exllamav2.git, synced 2026-03-15 00:07:26 +00:00
import sys, os, gc, time, random
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2CacheBase,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler
)
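
# Module-level state shared by all test functions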
model: ExLlamaV2
config: ExLlamaV2Config
tokenizer: ExLlamaV2Tokenizer
cache: ExLlamaV2CacheBase
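
# Release the current model and cache and reclaim GPU memory between test runs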
def unload():
    global model, config, tokenizer, cache

    model.unload()
    model = None
    config = None
    cache = None
    tokenizer = None

    gc.collect()
    torch.cuda.empty_cache()
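
# Load model and tokenizer from model_dir and allocate a batch-size-4 cache (FP16 or 8-bit),
# optionally with a manual GPU split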
def load_model(model_dir, split = None, cache_8bit = False):
    global model, config, tokenizer, cache

    config = ExLlamaV2Config()
    config.model_dir = model_dir
    config.prepare()

    model = ExLlamaV2(config)
    print(" -- Loading model: " + model_dir)

    model.load(split)

    tokenizer = ExLlamaV2Tokenizer(config)

    if cache_8bit:
        print(" -- Creating 8-bit cache")
        cache = ExLlamaV2Cache_8bit(model, batch_size = 4)
    else:
        print(" -- Creating 16-bit cache")
        cache = ExLlamaV2Cache(model, batch_size = 4)
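
# Non-streaming generation with the base generator, timing the complete response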
def test_gen_normal(prompt, max_new_tokens):
    global model, config, tokenizer, cache

    print("--------------------------------")
    print("Generating, normal")
    print()

    generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.85
    settings.top_k = 50
    settings.top_p = 0.8
    settings.top_a = 0.0
    settings.token_repetition_penalty = 1.15
    settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

    generator.warmup()
    time_begin = time.time()

    output = generator.generate_simple(prompt, settings, max_new_tokens, seed = 1234)

    time_end = time.time()
    time_total = time_end - time_begin

    print(output)
    print()
    print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, {max_new_tokens / time_total:.2f} tokens/second")
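
# Streaming generation, timing prompt processing and token generation separately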
def test_gen_streaming(prompt, max_new_tokens):
    global model, config, tokenizer, cache

    print("--------------------------------")
    print("Generating, streaming")
    print()

    generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.85
    settings.top_k = 50
    settings.top_p = 0.8
    settings.top_a = 0.0
    settings.token_repetition_penalty = 1.15
    settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

    input_ids = tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]

    print(prompt, end = "")
    sys.stdout.flush()

    time_begin_prompt = time.time()

    generator.set_stop_conditions([])
    generator.begin_stream(input_ids, settings)

    time_begin_stream = time.time()
    generated_tokens = 0

    while True:
        chunk, eos, _ = generator.stream()
        generated_tokens += 1
        print(chunk, end = "")
        sys.stdout.flush()
        if eos or generated_tokens == max_new_tokens: break

    time_end = time.time()

    time_prompt = time_begin_stream - time_begin_prompt
    time_tokens = time_end - time_begin_stream

    print()
    print()
    print(f"Prompt processed in {time_prompt:.2f} seconds, {prompt_tokens} tokens, {prompt_tokens / time_prompt:.2f} tokens/second")
    print(f"Response generated in {time_tokens:.2f} seconds, {generated_tokens} tokens, {generated_tokens / time_tokens:.2f} tokens/second")
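
# Batched generation over four prompts; token healing repairs the deliberately truncated prompt endings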
def test_gen_batch(max_new_tokens):

    print("--------------------------------")
    print("Generating, batched")
    print()

    generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.85
    settings.top_k = 50
    settings.top_p = 0.8
    settings.top_a = 0.0
    settings.token_repetition_penalty = 1.15
    settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

    generator.warmup()
    time_begin = time.time()

    prompts = ["Here's how to create a powerful love potio",
               "For once,",
               "The events of the American Civil W",
               "A bird in the hand is worth"]

    output = generator.generate_simple(prompts, settings, max_new_tokens, seed = 1234, token_healing = True)

    time_end = time.time()
    time_total = time_end - time_begin

    for o in output:
        print(o)
        print("---")

    print()
    print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second")
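
# Batched generation with a separate cache per sequence: each prompt is prefilled into its own cache,
# then the forward pass takes the list of caches and each sequence is sampled manually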
def test_multicache(max_new_tokens):

    print("--------------------------------")
    print("Generating, batched multi cache")
    print()

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.85
    settings.top_k = 50
    settings.top_p = 0.8
    settings.top_a = 0.0
    settings.token_repetition_penalty = 1.15
    settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])

    prompts = ["Here's how to create a powerful love potion",
               "For once,",
               "The events of the American Civil War",
               "A bird in the hand is worth"]

    caches = [ExLlamaV2Cache(model, max_seq_len = 256) for _ in range(len(prompts))]
    input_ids = []

    for i in range(len(prompts)):

        input_ids.append(tokenizer.encode(prompts[i]))
        model.forward(input_ids[i][:, :-1], caches[i], input_mask = None, preprocess_only = True)

    time_begin = time.time()

    for i in range(max_new_tokens):

        inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0)
        logits = model.forward(inputs, caches, input_mask = None).float().cpu()

        r = random.random()
        for j in range(len(input_ids)):
            token, _, _ = ExLlamaV2Sampler.sample(logits[j:j + 1, :, :], settings, input_ids[j], r, tokenizer)
            input_ids[j] = torch.cat([input_ids[j], token], dim = 1)

    output = [tokenizer.decode(ids)[0] for ids in input_ids]

    time_end = time.time()
    time_total = time_end - time_begin

    for o in output:
        print(o)
        print("---")

    print()
    print(f"Response generated in {time_total:.2f} seconds, {max_new_tokens} tokens, throughput {4 * max_new_tokens / time_total:.2f} tokens/second")
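
# Run the full set of generation tests against one model, with the given cache precision
# and an optional manual two-GPU split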
def tests(model_dir, cache_8bit, use_split):

    if use_split: split = [1, 24]
    else: split = None
    print("--------------------------------")
    print(f" -- Split: {split}")
    load_model(model_dir, split = split, cache_8bit = cache_8bit)

    test_gen_normal("Our story begins in the Scottish town of Auchtermuchty, where once", 150)
    test_gen_streaming("Our story begins in the Scottish town of Auchtermuchty, where once", 150)
    test_gen_batch(40)
    if model.is_quant(): test_multicache(40)

    unload()
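
# Model paths for the test runs: a quantized EXL2 model and a full-precision checkpoint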
q_model_directory = "/mnt/str/models/mistral-7b-instruct-exl2/4.0bpw/"
f_model_directory = "/mnt/str/models/tinyllama-1b-ckpt503/"

tests(q_model_directory, False, False)
tests(q_model_directory, False, True)
tests(q_model_directory, True, False)
tests(q_model_directory, True, True)
tests(f_model_directory, False, False)
tests(f_model_directory, False, True)
tests(f_model_directory, True, False)
tests(f_model_directory, True, True)