mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Fix multiple caches not working with 8-bit cache mode
This commit is contained in:
@@ -5,6 +5,7 @@ from exllamav2 import(
     ExLlamaV2,
     ExLlamaV2Config,
     ExLlamaV2Cache,
+    ExLlamaV2Cache_8bit,
     ExLlamaV2Tokenizer,
 )
 
@@ -19,7 +20,7 @@ import random
 
 # Initialize model
 
-model_directory = "/mnt/str/models/_exl2/llama2-7b-exl2/4.0bpw/"
+model_directory = "/mnt/str/models/llama2-7b-exl2/4.0bpw/"
 
 config = ExLlamaV2Config()
 config.model_dir = model_directory
@@ -32,6 +33,10 @@ model.load()
 
 tokenizer = ExLlamaV2Tokenizer(config)
 
+# Cache mode
+
+cache_8bit = False
+
 # Create some sampling settings
 
 settings_proto = ExLlamaV2Sampler.Settings()
@@ -79,7 +84,10 @@ while len(prompts) or len(input_ids):
 
         prompt = prompts.pop()
         ids = tokenizer.encode(prompt)
-        cache = ExLlamaV2Cache(model, max_seq_len = 256)  # (max_seq_len could be different for each cache)
+        if cache_8bit:
+            cache = ExLlamaV2Cache_8bit(model, max_seq_len = 256)  # (max_seq_len could be different for each cache)
+        else:
+            cache = ExLlamaV2Cache(model, max_seq_len = 256)  # (max_seq_len could be different for each cache)
 
         model.forward(ids[:, :-1], cache, preprocess_only = True)
         input_ids.append(ids)
@@ -443,12 +443,16 @@ class ExLlamaV2Attention(ExLlamaV2Module):
 
             # Add keys and values to cache
 
-            batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, batch_size, 0, past_len)
+            batch_keys, batch_values = cache[i].get_kv_state(self.layer_idx, 1, 0, past_len[1][i].item())
             new_keys = batch_keys.narrow(1, past_len[1][i], q_len)
             new_values = batch_values.narrow(1, past_len[1][i], q_len)
             new_keys.copy_(k_states.narrow(0, i, 1))
             new_values.copy_(v_states.narrow(0, i, 1))
 
             # Store updated cache values
 
+            cache[i].store_kv_state(self.layer_idx, 1, past_len[1][i].item(), q_len)
 
             # Key/value tensors with past
 
             k_states_b = batch_keys.narrow(1, 0, past_len[1][i] + q_len)
Reference in New Issue
Block a user