mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
CFG support in streaming gen
This commit is contained in:
examples/cfg.py | 96 ++++++++++++++++++++++++++++++++++++ (new file, @@ -0,0 +1,96 @@)
|
||||
|
||||
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import *
from exllamav2.generator import *

# Example: classifier-free guidance (CFG) with the streaming generator.
# Two prompts (a "positive" and a "negative" system prompt) are run as a
# batch of 2, and their logits are blended per-token according to cfg_scale.

# Initialize model and cache

model_directory = "/mnt/str/models/llama2-70b-chat-exl2/4.0bpw"

config = ExLlamaV2Config()
config.model_dir = model_directory
config.max_batch_size = 2           # CFG needs room for exactly two sequences
config.no_flash_attn = True         # NOTE(review): presumably required because the padding mask path bypasses flash-attn — confirm
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)

cache = ExLlamaV2Cache(model, lazy = True, batch_size = 2)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)

# Initialize generator

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

# Settings

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.85
settings.top_k = 50
settings.top_p = 0.8
settings.top_a = 0.0
settings.token_repetition_penalty = 1.05

# Cap on tokens generated per CFG scale (single definition — previously this
# was set to 250 here and then shadowed to 200 inside the loop).
max_new_tokens = 200

# Prompt

positive = \
"""[INST] <<SYS>>
You are a cheerful, bubbly and respectful assistant.
<</SYS>>

{prompt} [/INST]"""

negative = \
"""[INST] <<SYS>>
You are a rude and obnoxious assistant.
<</SYS>>

{prompt} [/INST]"""

q = """Tell me about Homer Simpson."""

prompt_a = positive.replace("{prompt}", q)
prompt_b = negative.replace("{prompt}", q)

print("-------------------------------------------")
print("Prompt a:\n" + prompt_a + "\n")

print("-------------------------------------------")
print("Prompt b:\n" + prompt_b + "\n")

# Sweep cfg_scale from 0.0 to 2.0 in steps of 0.2 (x / 5 for x in 0..10)
for x in range(11):

    # cfg_scale is the weight of the first prompt in the batch, while the second prompt is weighted as (1 - cfg_scale).
    #
    # - at cfg_scale == 0, only the second prompt is effective
    # - at 0 < cfg_scale < 1, the sampled logits will be a weighted average of the normalized outputs of both prompts
    # - at cfg_scale == 1, only the first prompt is effective
    # - at cfg_scale > 1, the second prompt will have a negative weight, emphasizing the difference between the two

    settings.cfg_scale = x / 5

    # Start a batched generation. CFG requires a batch size of exactly 2. Offsets and padding mask are required

    input_ids, offsets = tokenizer.encode([prompt_a, prompt_b], encode_special_tokens = True, return_offsets = True)
    mask = tokenizer.padding_mask(input_ids)
    generator.begin_stream(input_ids, settings, input_mask = mask, position_offsets = offsets)
    generator.set_stop_conditions([tokenizer.eos_token_id])

    print(f"---------------------------------------------------------------------------------------")
    print(f"cfg_scale = {settings.cfg_scale:.1f}")
    print()

    generated_tokens = 0
    while True:
        chunk, eos, _ = generator.stream()
        generated_tokens += 1
        print(chunk, end = "")
        sys.stdout.flush()
        # >= rather than == so a miscounted step can never loop forever
        if eos or generated_tokens >= max_new_tokens: break

    print()
|
||||
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from exllamav2 import ExLlamaV2Tokenizer
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
|
||||
|
||||
class ExLlamaV2Sampler:
|
||||
|
||||
class Settings:
|
||||
@@ -30,6 +30,7 @@ class ExLlamaV2Sampler:
|
||||
mirostat_mu = None # (re)initialized from mirostat_tau on first sample
|
||||
|
||||
token_bias = None
|
||||
cfg_scale = None
|
||||
|
||||
filters = []
|
||||
|
||||
@@ -105,9 +106,22 @@ class ExLlamaV2Sampler:
|
||||
|
||||
assert logits.shape[1] == 1, "Logits tensor is incorrect shape, must be (bsz, 1, vocab_size)"
|
||||
assert prefix_token is None or prefix_token.shape == (batch_size, 1), "Prefix token list doesn't match batch shape"
|
||||
assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"
|
||||
if settings.cfg_scale is not None: assert batch_size == 2, "CFG requires logits to be bsz 2"
|
||||
else: assert batch_size == 1 or len(settings.filters) == 0, "Filters not implemented for batch size > 1"
|
||||
|
||||
logits = logits.squeeze(1)
|
||||
|
||||
# CFG
|
||||
|
||||
if settings.cfg_scale is not None:
|
||||
|
||||
logits = F.log_softmax(logits, dim = -1)
|
||||
logits = settings.cfg_scale * logits[0] + (1 - settings.cfg_scale) * logits[1]
|
||||
logits = logits.unsqueeze(0)
|
||||
batch_size = 1
|
||||
|
||||
# Prepare filter
|
||||
|
||||
logit_filter = torch.empty((batch_size, vocab_size), dtype = torch.bool)
|
||||
ext_c.fast_fill_cpu_ones_bool(logit_filter)
|
||||
|
||||
@@ -117,7 +131,7 @@ class ExLlamaV2Sampler:
|
||||
settings.token_frequency_penalty != 0.0 or \
|
||||
settings.token_presence_penalty != 0.0:
|
||||
|
||||
ext_c.apply_rep_penalty(sequence_ids,
|
||||
ext_c.apply_rep_penalty(sequence_ids[:1, :],
|
||||
settings.token_repetition_penalty,
|
||||
settings.token_repetition_range,
|
||||
settings.token_repetition_decay,
|
||||
|
||||
@@ -42,6 +42,9 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
accepted_draft_tokens: int = 0
|
||||
|
||||
active_loras = []
|
||||
position_offsets = None
|
||||
input_mask = None
|
||||
|
||||
|
||||
def __init__(self, model, cache, tokenizer, draft_model = None, draft_cache = None, num_speculative_tokens = 5):
|
||||
super().__init__(model, cache, tokenizer)
|
||||
@@ -74,7 +77,14 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
else: raise ValueError("Unsupported type in stop_conditions")
|
||||
|
||||
|
||||
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None):
|
||||
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None, input_mask = None, position_offsets = None):
|
||||
|
||||
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
|
||||
if input_ids.shape[0] == 2:
|
||||
assert gen_settings.cfg_scale is not None, "No CFG scale set"
|
||||
|
||||
self.position_offsets = position_offsets
|
||||
self.input_mask = input_mask
|
||||
|
||||
# Accept LoRA or list of LoRAs
|
||||
if loras is not None and isinstance(loras, ExLlamaV2Lora): loras = [loras]
|
||||
@@ -98,7 +108,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
if self.heal_next_token:
|
||||
|
||||
# Pop the last toke
|
||||
# Pop the last token
|
||||
|
||||
old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
|
||||
last_token = self.sequence_ids[:, -1:]
|
||||
@@ -136,7 +146,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
# Decode the current tail end of the sequence
|
||||
|
||||
old_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
|
||||
old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
|
||||
|
||||
# Generate a single token and append to the sequence
|
||||
|
||||
@@ -149,7 +159,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
# Decode the tail end of the sequence with the added token to get (actual) characters added
|
||||
|
||||
new_tail = self.tokenizer.decode(self.sequence_ids[:, -(self.tail_decode_tokens + 1):])[0]
|
||||
new_tail = self.tokenizer.decode(self.sequence_ids[:1, -(self.tail_decode_tokens + 1):])[0]
|
||||
new_text = new_tail[len(old_tail):]
|
||||
|
||||
next_token, new_text = self._catch_utf8(next_token, new_text)
|
||||
@@ -250,11 +260,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
self.sequence_ids = in_tokens.clone()
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, loras = self.active_loras)
|
||||
self.model.forward(self.sequence_ids[:, :-1], self.cache, preprocess_only = True, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets)
|
||||
|
||||
if self.draft_model is not None:
|
||||
self.draft_cache.current_seq_len = 0
|
||||
self.draft_model.forward(self.sequence_ids[:, :-1], self.draft_cache, preprocess_only = True)
|
||||
self.draft_model.forward(self.sequence_ids[:1, :-1], self.draft_cache, preprocess_only = True)
|
||||
self.future_logits = None
|
||||
self.future_tokens = None
|
||||
|
||||
@@ -296,7 +306,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
start = self.cache.current_seq_len
|
||||
self.sequence_ids = torch.cat((self.sequence_ids, in_tokens), dim = 1)
|
||||
|
||||
self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, loras = self.active_loras)
|
||||
self.model.forward(self.sequence_ids[:, start : -1], self.cache, preprocess_only = True, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets)
|
||||
|
||||
if self.draft_model is not None:
|
||||
self.draft_model.forward(self.sequence_ids[:, start: -1], self.draft_cache, preprocess_only = True)
|
||||
@@ -308,14 +318,18 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
if self.draft_model is None:
|
||||
|
||||
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras).float().cpu()
|
||||
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
|
||||
token, _, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token)
|
||||
|
||||
else:
|
||||
|
||||
token, eos = self._gen_single_token_speculative(gen_settings, prefix_token)
|
||||
|
||||
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
|
||||
if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
|
||||
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
|
||||
else:
|
||||
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
|
||||
|
||||
gen_settings.feed_filters(token)
|
||||
return token, eos
|
||||
|
||||
@@ -327,7 +341,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
# Generate draft
|
||||
|
||||
draft_gen_settings = gen_settings.greedy_clone()
|
||||
draft_sequence_ids = self.sequence_ids.clone()
|
||||
draft_sequence_ids = self.sequence_ids[:1, :]
|
||||
num_drafted_tokens = 0
|
||||
|
||||
for k in range(self.num_speculative_tokens):
|
||||
@@ -350,8 +364,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
# Forward last sampled token plus draft through model
|
||||
|
||||
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:]
|
||||
self.future_logits = self.model.forward(self.future_tokens, self.cache, loras = self.active_loras).float().cpu()
|
||||
if self.sequence_ids.shape[0] > 1:
|
||||
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:].repeat(self.sequence_ids.shape[0], 1)
|
||||
else:
|
||||
self.future_tokens = draft_sequence_ids[:, -1 - num_drafted_tokens:]
|
||||
self.future_logits = self.model.forward(self.future_tokens, self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
|
||||
|
||||
# Rewind model cache
|
||||
|
||||
|
||||
Reference in New Issue
Block a user