New API for streaming generator

This commit is contained in:
turboderp
2024-02-11 20:31:58 +01:00
parent 944e523109
commit 1c67f97f3d
5 changed files with 138 additions and 73 deletions

View File

@@ -63,8 +63,9 @@ std::vector<float> sample_basic
float typical,
float random,
torch::Tensor output_tokens, // shape [bsz, 1]
torch::Tensor output_probs, // shape [bsz, 1] or [bsz, 1, num_probs]
torch::Tensor output_ptokens, // None or [bsz, 1, num_probs]
torch::Tensor output_probs, // shape [bsz, 1]
torch::Tensor output_kprobs, // None or [bsz, 1, num_probs]
torch::Tensor output_ktokens, // None or [bsz, 1, num_probs]
torch::Tensor logit_filter, // shape [bsz, vocab_size]
bool mirostat,
std::vector<float>& mirostat_mu,
@@ -93,9 +94,9 @@ std::vector<float> sample_basic
int* temp_indices = (int*) malloc(vocab_size * sizeof(int));
float* logits_ptr = (float*) logits.data_ptr();
int num_probs = 1;
if (output_probs.dim() == 3)
num_probs = output_probs.size(2);
int num_probs = 0;
if (!output_kprobs.device().is_meta())
num_probs = output_kprobs.size(2);
bool* logits_filter_ptr = (bool*) logit_filter.data_ptr();
@@ -177,12 +178,9 @@ std::vector<float> sample_basic
multinomial_cpu(num_candidates, temp_probs, temp_indices, random);
output_tokens[i][0] = temp_indices[0];
output_probs[i][0] = temp_probs[0];
if (num_probs == 1)
{
output_probs[i][0] = temp_probs[0];
}
else
if (num_probs)
{
num_candidates = pre_sort_descending(num_candidates, temp_probs, temp_indices);
sort_descending(num_candidates, temp_probs, temp_indices, num_probs);
@@ -193,8 +191,8 @@ std::vector<float> sample_basic
float tp = temp_probs[j];
if (tp == 0.0f) break;
output_ptokens[i][0][j] = temp_indices[j];
output_probs[i][0][j] = tp;
output_ktokens[i][0][j] = temp_indices[j];
output_kprobs[i][0][j] = tp;
}
// Candidate tokens are only valid up to num_candidates, so fake the ones with prob == 0
@@ -205,8 +203,8 @@ std::vector<float> sample_basic
for (int k = 0; k < num_candidates; ++k)
if (temp_indices[k] == fake_idx) { fake_idx++; k = 0; }
output_ptokens[i][0][j] = fake_idx;
output_probs[i][0][j] = 0.0f;
output_ktokens[i][0][j] = fake_idx;
output_kprobs[i][0][j] = 0.0f;
fake_idx++;
}
}

View File

@@ -22,8 +22,9 @@ std::vector<float> sample_basic
float typical,
float random,
torch::Tensor output_tokens, // shape [bsz, 1]
torch::Tensor output_probs, // shape [bsz, 1] or [bsz, 1, num_probs]
torch::Tensor output_ptokens, // None or [bsz, 1, num_probs]
torch::Tensor output_probs, // shape [bsz, 1]
torch::Tensor output_kprobs, // None or [bsz, 1, num_probs]
torch::Tensor output_ktokens, // None or [bsz, 1, num_probs]
torch::Tensor logit_filter, // shape [bsz, vocab_size]
bool mirostat,
std::vector<float>& mirostat_mu,

View File

@@ -106,7 +106,7 @@ class ExLlamaV2BaseGenerator:
for i in range(num_tokens):
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask = mask, loras = loras, position_offsets = position_offsets).float().cpu()
token, _, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
token, _, _, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
eos = False
if stop_token is not None:

View File

@@ -108,7 +108,13 @@ class ExLlamaV2Sampler:
@staticmethod
def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None, num_probs = 1):
def sample(logits: torch.tensor,
settings: Settings,
sequence_ids: torch.tensor,
random: float,
tokenizer: ExLlamaV2Tokenizer,
prefix_token = None,
return_top_tokens = 0):
batch_size, _, vocab_size = logits.shape
@@ -201,12 +207,13 @@ class ExLlamaV2Sampler:
batch_size = logits.shape[0]
output_tokens = torch.empty((batch_size, 1), device="cpu", dtype=torch.long)
if num_probs == 1:
output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
output_ptokens = none_tensor
output_probs = torch.empty((batch_size, 1), device="cpu", dtype=torch.float)
if return_top_tokens == 0:
output_ktokens = none_tensor
output_kprobs = none_tensor
else:
output_probs = torch.empty((batch_size, 1, num_probs), device = "cpu", dtype = torch.float)
output_ptokens = torch.empty((batch_size, 1, num_probs), device = "cpu", dtype = torch.long)
output_ktokens = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.long)
output_kprobs = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.float)
m = ext_c.sample_basic(logits,
1.0 if settings.temperature_last else settings.temperature,
@@ -219,7 +226,8 @@ class ExLlamaV2Sampler:
random,
output_tokens,
output_probs,
output_ptokens,
output_kprobs,
output_ktokens,
logit_filter,
settings.mirostat,
settings.mirostat_mu if settings.mirostat else [],
@@ -238,4 +246,4 @@ class ExLlamaV2Sampler:
end_filter = False
if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True
return output_tokens, output_ptokens, output_probs, end_filter
return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter

View File

@@ -18,29 +18,32 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
tail_decode_tokens: int = 2
remaining_tokens: int = 0
held_text: str = ""
held_utf8_tokens: torch.tensor = None
expect_utf8: int = 0
held_tokens: torch.Tensor or None = None
held_ptokens: torch.Tensor or None = None
held_probs: torch.Tensor or None = None
held_logits: torch.Tensor or None = None
settings: ExLlamaV2Sampler.Settings = None
stop_strings: set = set()
stop_tokens: set = set()
draft_model: ExLlamaV2 or None = None
draft_cache: ExLlamaV2Cache or None = None
remaining_tokens: int = 0
held_utf8_tokens: torch.tensor = None
expect_utf8: int = 0
no_tokens: torch.Tensor = None
no_ptokens: torch.Tensor = None
no_probs: torch.Tensor = None
no_pprobs: torch.Tensor = None
no_logits: torch.Tensor = None
held_text: str = ""
held_tokens: torch.Tensor or None = None
held_ptokens: torch.Tensor or None = None
held_probs: torch.Tensor or None = None
held_pprobs: torch.Tensor or None = None
held_logits: torch.Tensor or None = None
first_token = False
heal_next_token = False
draft_model: ExLlamaV2 or None = None
draft_cache: ExLlamaV2Cache or None = None
future_logits: torch.tensor or None = None
future_tokens: torch.tensor or None = None
num_speculative_tokens: int
@@ -49,9 +52,9 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
total_tokens: int = 0
accepted_draft_tokens: int = 0
return_probabilities: bool = False # Return final sampling probabilities, one per token
return_probabilities_k: int = 1 # Number of probabilities to return per token
return_logits: bool = False # Return raw logits prior to softmax, per token
return_probabilities: bool = False
return_top_tokens: int = 0
return_logits: bool = False
active_loras = []
position_offsets = None
@@ -87,16 +90,40 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
else: raise ValueError("Unsupported type in stop_conditions")
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None, input_mask = None, position_offsets = None):
# Begin stream
self.no_logits = torch.empty((0, ((self.model.config.vocab_size + 31) // 32) * 32), dtype = torch.float)
self.no_tokens = torch.empty((1, 0), dtype=torch.long)
if self.return_probabilities_k == 1:
self.no_ptokens = torch.empty((1, 0), dtype = torch.long)
self.no_probs = torch.empty((1, 0), dtype = torch.float)
else:
self.no_ptokens = torch.empty((1, 0, self.return_probabilities_k), dtype = torch.long)
self.no_probs = torch.empty((1, 0, self.return_probabilities_k), dtype = torch.float)
def begin_stream_ex(self,
input_ids: torch.Tensor,
gen_settings: ExLlamaV2Sampler.Settings,
token_healing = False,
loras = None,
input_mask = None,
position_offsets = None,
return_probabilities = False,
return_top_tokens = 0,
return_logits = False):
# Extended entry point for starting a stream. Records which optional outputs
# stream_ex() should include in its result dict, then delegates the actual
# setup to begin_stream().
#
# input_ids, gen_settings, token_healing, loras, input_mask, position_offsets:
#     forwarded unchanged to begin_stream().
# return_probabilities: when true, stream_ex() adds a "probs" entry.
# return_top_tokens:    when > 0, stream_ex() adds "top_probs" and "top_tokens"
#                       entries; the value is used as the size of their last
#                       dimension (k).
# return_logits:        when true, stream_ex() adds a "logits" entry.
#
# The flags must be stored before begin_stream() runs: begin_stream() reads
# self.return_top_tokens to shape its empty top-k placeholder tensors.
self.return_probabilities = return_probabilities
self.return_top_tokens = return_top_tokens
self.return_logits = return_logits
self.begin_stream(input_ids,
gen_settings,
token_healing,
loras,
input_mask,
position_offsets)
# Legacy function
def begin_stream(self,
input_ids: torch.Tensor,
gen_settings: ExLlamaV2Sampler.Settings,
token_healing = False,
loras = None,
input_mask = None,
position_offsets = None):
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
if input_ids.shape[0] == 2:
@@ -109,12 +136,19 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
if loras is not None and isinstance(loras, ExLlamaV2Lora): loras = [loras]
self.active_loras = loras
self.no_logits = torch.empty((0, ((self.model.config.vocab_size + 31) // 32) * 32), dtype=torch.float)
self.no_tokens = torch.empty((1, 0), dtype=torch.long)
self.no_probs = torch.empty((1, 0), dtype=torch.float)
self.no_ptokens = torch.empty((1, 0, self.return_top_tokens), dtype=torch.long)
self.no_pprobs = torch.empty((1, 0, self.return_top_tokens), dtype=torch.float)
self.held_text = ""
self.held_utf8_tokens = self.no_tokens
self.expect_utf8 = 0
self.held_tokens = self.no_tokens
self.held_ptokens = self.no_ptokens
self.held_probs = self.no_probs
self.held_pprobs = self.no_pprobs
self.held_logits = self.no_logits
self.settings = gen_settings
self._gen_begin_reuse(input_ids, gen_settings)
@@ -122,27 +156,49 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.heal_next_token = (token_healing and self.sequence_ids.shape[-1] >= 2)
# Get the next chunk of text in the stream
def stream_ex(self):
# Advance the stream by one step via _stream() and return the result as a dict.
# "chunk", "eos" and "chunk_token_ids" are always present. "probs",
# "top_probs"/"top_tokens" and "logits" are only added when the corresponding
# return_probabilities / return_top_tokens / return_logits attribute enables them.
chunk, eos, chunk_token_ids, probs, ptokens, pprobs, logits = self._stream()
ret = { "chunk": chunk, # output text
"eos": eos, # True if stop condition met
"chunk_token_ids": chunk_token_ids } # token(s) corresponding to output text: [1, n] (torch.long)
if self.return_probabilities:
ret["probs"] = probs # Probability of selected token(s): [1, n] (torch.float)
if self.return_top_tokens > 0:
ret["top_probs"] = pprobs # Top k probs: [1, n, k] (torch.float)
ret["top_tokens"] = ptokens # Top k tokens: [1, n, k] (torch.long)
if self.return_logits:
# unsqueeze(0) prepends a dimension so the logits line up with the other
# [1, n, ...] outputs; presumably _stream() returns them as [n, vocab_size].
ret["logits"] = logits.unsqueeze(0) # Raw output logits: [1, n, vocab_size] (torch.float)
return ret
# Legacy function
def stream(self) -> Union[Tuple[str, bool, torch.Tensor],
Tuple[str, bool, torch.Tensor, torch.Tensor],
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor]]:
chunk, eos, chunk_token_ids, probs, ptokens, logits = self._stream()
assert self.return_top_tokens == 0, "Use stream_ex() to return top K probs"
chunk, eos, chunk_token_ids, probs, _, _, logits = self._stream()
ret = [chunk, eos, chunk_token_ids]
if self.return_probabilities:
ret.append(probs)
if self.return_probabilities_k > 1:
ret.append(ptokens)
if self.return_logits:
ret.append(logits)
return tuple(ret)
# Get the next chunk of text in the stream. Returns eos if stop condition has been met but does not count tokens
def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
# Token healing
@@ -165,7 +221,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Regenerate the last token again, with prefix
healed_token, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
healed_token, _, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
self.held_text += new_tail[len(old_tail):]
@@ -173,7 +229,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# In case we only needed the healed token
if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
# Start filters when not healing
@@ -184,19 +240,18 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.settings.begin_filters()
self.first_token = False
# Decode the current tail end of the sequence
old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
# Generate a single token and append to the sequence
next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings)
next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits = self._gen_single_token(self.settings)
# End immediately if it was a stop token
if next_token.item() in self.stop_tokens:
return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
# Decode the tail end of the sequence with the added token to get (actual) characters added
@@ -209,14 +264,15 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.held_tokens = torch.cat([self.held_tokens, next_token], dim = -1)
if self.return_probabilities:
self.held_probs = torch.cat([self.held_probs, next_prob], dim = 1)
if self.return_probabilities_k > 1:
self.held_ptokens = torch.cat([self.held_ptokens, next_ptokens], dim = 1)
if self.return_top_tokens > 0:
self.held_ptokens = torch.cat([self.held_ptokens, next_ptokens], dim = 1)
self.held_pprobs = torch.cat([self.held_pprobs, next_pprobs], dim = 1)
if self.return_logits:
self.held_logits = torch.cat([self.held_logits, next_logits], dim = 0)
# Return now if newly added token ends a filter
if eos: return self.held_text, True, self.held_tokens, self.held_probs, self.held_ptokens, self.held_logits
if eos: return self.held_text, True, self.held_tokens, self.held_probs, self.held_ptokens, self.held_pprobs, self.held_logits
# Hold text as long as it contains part of a stop string
@@ -227,7 +283,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
position = self.held_text.find(ss)
if position != -1:
return self.held_text[:position], True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
return self.held_text[:position], True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
# Check for overlap between end of held_text and start of stop string
@@ -239,7 +295,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# If holding text because of a partial stop condition, return nothing but also EOS = False
if partial_ss:
return "", False, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
return "", False, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
# No stop condition, so return whatever is being held
@@ -247,13 +303,15 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
stream_tokens = self.held_tokens
stream_probs = self.held_probs
stream_ptokens = self.held_ptokens
stream_pprobs = self.held_pprobs
stream_logits = self.held_logits
self.held_text = ""
self.held_tokens = self.no_tokens
self.held_probs = self.no_probs
self.held_ptokens = self.no_ptokens
self.held_pprobs = self.no_pprobs
self.held_logits = self.no_logits
return stream_text, False, stream_tokens, stream_probs, stream_ptokens, stream_logits
return stream_text, False, stream_tokens, stream_probs, stream_ptokens, stream_pprobs, stream_logits
def _decode_utf8(self):
@@ -372,11 +430,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
if self.draft_model is None:
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_top_tokens)
else:
token, ptokens, prob, eos, logits = self._gen_single_token_speculative(gen_settings, prefix_token)
token, ptokens, pprobs, prob, eos, logits = self._gen_single_token_speculative(gen_settings, prefix_token)
if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
@@ -384,7 +442,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
gen_settings.feed_filters(token)
return token, ptokens, prob, eos, logits.flatten(1)
return token, ptokens, pprobs, prob, eos, logits.flatten(1)
def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
@@ -400,7 +458,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
for k in range(self.num_speculative_tokens):
logits = self.draft_model.forward(draft_sequence_ids[:, -1:], self.draft_cache).float().cpu()
token, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
if prob < self.speculative_prob_threshold:
self.draft_cache.current_seq_len -= 1
@@ -430,7 +488,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
# Sample the first future logits
logits = self.future_logits[:, :1, :]
token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_top_tokens)
self.future_logits = self.future_logits[:, 1:, :]
self.future_tokens = self.future_tokens[:, 1:]
self.cache.current_seq_len += 1
@@ -445,7 +503,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
self.accepted_draft_tokens += 1
self.total_tokens += 1
return token, ptokens, prob, eos, logits
return token, ptokens, pprobs, prob, eos, logits
def reset_sd_stats(self):