CFG support in streaming gen

This commit is contained in:
turboderp
2024-01-01 23:48:24 +01:00
parent 13fe676ac2
commit 66d19b6aa9
3 changed files with 142 additions and 15 deletions

96
examples/cfg.py Normal file
View File

@@ -0,0 +1,96 @@
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import *
from exllamav2.generator import *

# Example: classifier-free guidance (CFG) with the streaming generator.
# Two prompts (a "positive" and a "negative" system prompt around the same
# question) are batched together, and the generator blends their logits
# according to settings.cfg_scale.

# Initialize model and cache

model_directory = "/mnt/str/models/llama2-70b-chat-exl2/4.0bpw"

config = ExLlamaV2Config()
config.model_dir = model_directory
config.max_batch_size = 2       # CFG runs both prompts as a batch of exactly 2
config.no_flash_attn = True     # NOTE(review): presumably required for the padding mask path — confirm
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)

cache = ExLlamaV2Cache(model, lazy = True, batch_size = 2)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

# Initialize generator

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

# Sampling settings

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.85
settings.top_k = 50
settings.top_p = 0.8
settings.top_a = 0.0
settings.token_repetition_penalty = 1.05

# Fix: the original set max_new_tokens = 250 here and then overwrote it with
# 200 inside the loop before it was ever read. Only 200 took effect, so a
# single assignment of the effective value is kept.
max_new_tokens = 200

# Prompts

positive = \
"""[INST] <<SYS>>
You are a cheerful, bubbly and respectful assistant.
<</SYS>>
{prompt} [/INST]"""

negative = \
"""[INST] <<SYS>>
You are a rude and obnoxious assistant.
<</SYS>>
{prompt} [/INST]"""

q = """Tell me about Homer Simpson."""
prompt_a = positive.replace("{prompt}", q)
prompt_b = negative.replace("{prompt}", q)

print("-------------------------------------------")
print("Prompt a:\n" + prompt_a + "\n")
print("-------------------------------------------")
print("Prompt b:\n" + prompt_b + "\n")

# Sweep cfg_scale from 0.0 to 2.0 in steps of 0.2
for x in range(11):

    # cfg_scale is the weight of the first prompt in the batch, while the second prompt is weighted as (1 - cfg_scale).
    #
    # - at cfg_scale == 0, only the second prompt is effective
    # - at 0 < cfg_scale < 1, the sampled logits will be a weighted average of the normalized outputs of both prompts
    # - at cfg_scale == 1, only the first prompt is effective
    # - at cfg_scale > 1, the second prompt will have a negative weight, emphasizing the difference between the two
    settings.cfg_scale = x / 5

    # Start a batched generation. CFG requires a batch size of exactly 2. Offsets and padding mask are required
    input_ids, offsets = tokenizer.encode([prompt_a, prompt_b], encode_special_tokens = True, return_offsets = True)
    mask = tokenizer.padding_mask(input_ids)
    generator.begin_stream(input_ids, settings, input_mask = mask, position_offsets = offsets)
    generator.set_stop_conditions([tokenizer.eos_token_id])

    print(f"---------------------------------------------------------------------------------------")
    print(f"cfg_scale = {settings.cfg_scale:.1f}")
    print()

    # Stream tokens until EOS or the generation limit is hit
    generated_tokens = 0
    while True:
        chunk, eos, _ = generator.stream()
        generated_tokens += 1
        print(chunk, end = "")
        sys.stdout.flush()
        if eos or generated_tokens == max_new_tokens: break

    print()

View File

@@ -1,8 +1,8 @@
import torch
import torch.nn.functional as F
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
class ExLlamaV2Sampler:
class Settings:
@@ -30,6 +30,7 @@ class ExLlamaV2Sampler:
mirostat_mu = None # (re)initialized from mirostat_tau on first sample
token_bias = None
cfg_scale = None
filters = []
@@ -105,9 +106,22 @@ class ExLlamaV2Sampler:
assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"
if settings.cfg_scale is not None: assert batch_size == 2, "CFG requires logits to be bsz 2"
else: assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"
logits = logits.squeeze(1)
# CFG
if settings.cfg_scale is not None:
logits = F.log_softmax(logits, dim = -1)
logits = settings.cfg_scale * logits[0] + (1 - settings.cfg_scale) * logits[1]
logits = logits.unsqueeze(0)
batch_size = 1
# Prepare filter
logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
ext_c.fast_fill_cpu_ones_bool(logit_filter)
@@ -117,7 +131,7 @@ class ExLlamaV2Sampler:
settings.token_frequency_penalty != 0.0 or \
settings.token_presence_penalty != 0.0:
ext_c.apply_rep_penalty(sequence_ids,
ext_c.apply_rep_penalty(sequence_ids[:1, :],
settings.token_repetition_penalty,
settings.token_repetition_range,
settings.token_repetition_decay,

View File

@@ -42,6 +42,9 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
accepted_draft_tokens: int = 0
active_loras = []
position_offsets = None
input_mask = None
def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
super().__init__(model, cache, tokenizer)
@@ -74,7 +77,14 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
else: raise ValueError("Unsupported type in stop_conditions")
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None):
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None, input_mask = None, position_offsets = None):
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
if input_ids.shape[0] == 2:
assert gen_settings.cfg_scale is not None, "No CFG scale set"
self.position_offsets = position_offsets
self.input_mask = input_mask
# Accept LoRA or list of LoRAs
if loras is not None and isinstance(loras, ExLlamaV2Lora): loras = [loras]
@@ -98,7 +108,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
if self.heal_next_token:
# Pop the last toke
# Pop the last token
old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
last_token = self.sequence_ids[:, -1:]
@@ -136,7 +146,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Decode the current tail end of the sequence
old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
# Generate a single token and append to the sequence
@@ -149,7 +159,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Decode the tail end of the sequence with the added token to get (actual) characters added
new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.tail_decode_tokens + 1):])[0]
new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
new_text = new_tail[len(old_tail):]
next_token, new_text = self._catch_utf8(next_token, new_text)
@@ -250,11 +260,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.sequence_ids = in_tokens.clone()
self.cache.current_seq_len = 0
self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, loras = self.active_loras)
self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets)
if self.draft_model is not None:
self.draft_cache.current_seq_len = 0
self.draft_model.forward(self.sequence_ids[:, :-1], self.draft_cache, preprocess_only = True)
self.draft_model.forward(self.sequence_ids[:1, :-1], self.draft_cache, preprocess_only = True)
self.future_logits = None
self.future_tokens = None
@@ -296,7 +306,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
start = self.cache.current_seq_len
self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)
self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, loras = self.active_loras)
self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets)
if self.draft_model is not None:
self.draft_model.forward(self.sequence_ids[:, start: -1], self.draft_cache, preprocess_only = True)
@@ -308,14 +318,18 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
if self.draft_model is None:
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras).float().cpu()
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
token, _, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token)
else:
token, eos = self._gen_single_token_speculative(gen_settings, prefix_token)
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
else:
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
gen_settings.feed_filters(token)
return token, eos
@@ -327,7 +341,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Generate draft
draft_gen_settings = gen_settings.greedy_clone()
draft_sequence_ids = self.sequence_ids.clone()
draft_sequence_ids = self.sequence_ids[:1, :]
num_drafted_tokens = 0
for k in range(self.num_speculative_tokens):
@@ -350,8 +364,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Forward last sampled token plus draft through model
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:]
self.future_logits = self.model.forward(self.future_tokens, self.cache, loras = self.active_loras).float().cpu()
if self.sequence_ids.shape[0] > 1:
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:].repeat(self.sequence_ids.shape[0], 1)
else:
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:]
self.future_logits = self.model.forward(self.future_tokens, self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
# Rewind model cache