# Source: exllamav2/examples/multiple_caches.py
# (140 lines, 3.9 KiB, Python)
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import(
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
ExLlamaV2BaseGenerator,
ExLlamaV2Sampler
)
import torch
import time
import random
# Initialize model
# NOTE(review): hard-coded local path to an EXL2-quantized Llama2-7B — adjust before running.
model_directory = "/mnt/str/models/llama2-7b-exl2/4.0bpw/"
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()
model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
# Cache mode: when True, each sequence gets an 8-bit KV cache instead of the default (FP16) one
cache_8bit = False
# Create some sampling settings; this prototype is cloned once per active sequence below
settings_proto = ExLlamaV2Sampler.Settings()
settings_proto.temperature = 0.8
settings_proto.top_p = 0.75
# settings_proto.mirostat = True
# settings_proto.mirostat_tau = 5
# settings_proto.top_k = 1000
# Define some prompts to inference in parallel
prompts = ["Once you eliminate all the",
"C++ is",
"Once upon a time, I had the pleasure of meeting Toni Morrison. I was attending an event at the University of North Carolina Chapel Hill when she came to speak. She",
"A bird in the hand is worth two in the bush, but",
"Too many cooks spoil the",
"A lynx is a type of",
"Standing before the gates of"]
# Upper bound on how many sequences are generated simultaneously
max_parallel_seqs = 3
# Active sequences and corresponding caches and settings (parallel lists, indexed together)
input_ids = []
caches = []
settings = []
# Stats accumulated across the whole run
total_gen_tokens = 0
total_prompt_tokens = 0
prompt_time = 0
token_time = 0
# Continue generating as long as there is work to do: either prompts waiting to
# start or sequences still in flight.
while len(prompts) or len(input_ids):

    # If doing less than max_parallel_seqs, start some more. Prompt processing
    # isn't batched in this example, but would benefit much less from batching
    # anyway. Each new sequence gets its own KV cache and its own sampler
    # settings clone (individual settings are needed to support Mirostat state).
    while len(input_ids) < max_parallel_seqs and len(prompts):

        time_begin = time.time()

        prompt = prompts.pop()
        ids = tokenizer.encode(prompt)

        if cache_8bit:
            cache = ExLlamaV2Cache_8bit(model, max_seq_len = 256)  # (max_seq_len could be different for each cache)
        else:
            cache = ExLlamaV2Cache(model, max_seq_len = 256)  # (max_seq_len could be different for each cache)

        # Prefill: run all prompt tokens except the last through the model to
        # populate the cache; the last token is fed in the batched step below.
        model.forward(ids[:, :-1], cache, preprocess_only = True)

        input_ids.append(ids)
        caches.append(cache)
        settings.append(settings_proto.clone())  # Need individual settings per prompt to support Mirostat

        total_prompt_tokens += ids.shape[-1] - 1
        prompt_time += time.time() - time_begin

    # Create a batch tensor of the last token in each active sequence, forward
    # through the model using the list of active caches rather than a single,
    # batched cache. Then sample for each sequence individually with some
    # arbitrary stop condition.
    time_begin = time.time()

    inputs = torch.cat([x[:, -1:] for x in input_ids], dim = 0)
    logits = model.forward(inputs, caches, input_mask = None).float().cpu()

    eos = []
    r = random.random()  # One random draw shared by every sequence this step
    for i in range(len(input_ids)):

        token, _, _ = ExLlamaV2Sampler.sample(logits[i:i+1, :, :], settings[i], input_ids[i], r, tokenizer)
        input_ids[i] = torch.cat([input_ids[i], token], dim = 1)
        total_gen_tokens += 1

        # Stop on a newline token or when the sequence's cache is full.
        # Indices are collected in descending order so the pops below don't
        # shift the positions of entries still to be removed.
        if token.item() == tokenizer.newline_token_id or caches[i].current_seq_len == caches[i].max_seq_len:
            eos.insert(0, i)

    token_time += time.time() - time_begin

    # Output and drop any sequences completed in this step
    for i in eos:
        output = tokenizer.decode(input_ids[i])[0]
        print("-----")
        print(output.strip())
        input_ids.pop(i)
        caches.pop(i)
        settings.pop(i)

# Stats
print("-----")
print(f"Prompts: {total_prompt_tokens} tokens, {total_prompt_tokens / prompt_time:.2f} tokens/second")
print(f"Tokens: {total_gen_tokens} tokens, {total_gen_tokens / token_time:.2f} tokens/second")