mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
New API for streaming generator
This commit is contained in:
@@ -63,8 +63,9 @@ std::vector<float> sample_basic
|
||||
float typical,
|
||||
float random,
|
||||
torch::Tensor output_tokens, // shape [bsz, 1]
|
||||
torch::Tensor output_probs, // shape [bsz, 1] or [bsz, 1, num_probs]
|
||||
torch::Tensor output_ptokens, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor output_probs, // shape [bsz, 1]
|
||||
torch::Tensor output_kprobs, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor output_ktokens, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor logit_filter, // shape [bsz, vocab_size]
|
||||
bool mirostat,
|
||||
std::vector<float>& mirostat_mu,
|
||||
@@ -93,9 +94,9 @@ std::vector<float> sample_basic
|
||||
int* temp_indices = (int*) malloc(vocab_size * sizeof(int));
|
||||
float* logits_ptr = (float*) logits.data_ptr();
|
||||
|
||||
int num_probs = 1;
|
||||
if (output_probs.dim() == 3)
|
||||
num_probs = output_probs.size(2);
|
||||
int num_probs = 0;
|
||||
if (!output_kprobs.device().is_meta())
|
||||
num_probs = output_kprobs.size(2);
|
||||
|
||||
bool* logits_filter_ptr = (bool*) logit_filter.data_ptr();
|
||||
|
||||
@@ -177,12 +178,9 @@ std::vector<float> sample_basic
|
||||
multinomial_cpu(num_candidates, temp_probs, temp_indices, random);
|
||||
|
||||
output_tokens[i][0] = temp_indices[0];
|
||||
output_probs[i][0] = temp_probs[0];
|
||||
|
||||
if (num_probs == 1)
|
||||
{
|
||||
output_probs[i][0] = temp_probs[0];
|
||||
}
|
||||
else
|
||||
if (num_probs)
|
||||
{
|
||||
num_candidates = pre_sort_descending(num_candidates, temp_probs, temp_indices);
|
||||
sort_descending(num_candidates, temp_probs, temp_indices, num_probs);
|
||||
@@ -193,8 +191,8 @@ std::vector<float> sample_basic
|
||||
float tp = temp_probs[j];
|
||||
if (tp == 0.0f) break;
|
||||
|
||||
output_ptokens[i][0][j] = temp_indices[j];
|
||||
output_probs[i][0][j] = tp;
|
||||
output_ktokens[i][0][j] = temp_indices[j];
|
||||
output_kprobs[i][0][j] = tp;
|
||||
}
|
||||
|
||||
// Candidate tokens are only valid up to num_candidates, so fake the ones with prob == 0
|
||||
@@ -205,8 +203,8 @@ std::vector<float> sample_basic
|
||||
for (int k = 0; k < num_candidates; ++k)
|
||||
if (temp_indices[k] == fake_idx) { fake_idx++; k = 0; }
|
||||
|
||||
output_ptokens[i][0][j] = fake_idx;
|
||||
output_probs[i][0][j] = 0.0f;
|
||||
output_ktokens[i][0][j] = fake_idx;
|
||||
output_kprobs[i][0][j] = 0.0f;
|
||||
fake_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,8 +22,9 @@ std::vector<float> sample_basic
|
||||
float typical,
|
||||
float random,
|
||||
torch::Tensor output_tokens, // shape [bsz, 1]
|
||||
torch::Tensor output_probs, // shape [bsz, 1] or [bsz, 1, num_probs]
|
||||
torch::Tensor output_ptokens, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor output_probs, // shape [bsz, 1]
|
||||
torch::Tensor output_kprobs, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor output_ktokens, // None or [bsz, 1, num_probs]
|
||||
torch::Tensor logit_filter, // shape [bsz, vocab_size]
|
||||
bool mirostat,
|
||||
std::vector<float>& mirostat_mu,
|
||||
|
||||
@@ -106,7 +106,7 @@ class ExLlamaV2BaseGenerator:
|
||||
for i in range(num_tokens):
|
||||
|
||||
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, input_mask = mask, loras = loras, position_offsets = position_offsets).float().cpu()
|
||||
token, _, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
|
||||
token, _, _, _, _ = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids, random.random(), self.tokenizer, prefix_token = unhealed_token)
|
||||
|
||||
eos = False
|
||||
if stop_token is not None:
|
||||
|
||||
@@ -108,7 +108,13 @@ class ExLlamaV2Sampler:
|
||||
|
||||
|
||||
@staticmethod
|
||||
def sample(logits: torch.tensor, settings: Settings, sequence_ids: torch.tensor, random: float, tokenizer: ExLlamaV2Tokenizer, prefix_token = None, num_probs = 1):
|
||||
def sample(logits: torch.tensor,
|
||||
settings: Settings,
|
||||
sequence_ids: torch.tensor,
|
||||
random: float,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
prefix_token = None,
|
||||
return_top_tokens = 0):
|
||||
|
||||
batch_size, _, vocab_size = logits.shape
|
||||
|
||||
@@ -201,12 +207,13 @@ class ExLlamaV2Sampler:
|
||||
batch_size = logits.shape[0]
|
||||
|
||||
output_tokens = torch.empty((batch_size, 1), device="cpu", dtype=torch.long)
|
||||
if num_probs == 1:
|
||||
output_probs = torch.empty((batch_size, 1), device = "cpu", dtype = torch.float)
|
||||
output_ptokens = none_tensor
|
||||
output_probs = torch.empty((batch_size, 1), device="cpu", dtype=torch.float)
|
||||
if return_top_tokens == 0:
|
||||
output_ktokens = none_tensor
|
||||
output_kprobs = none_tensor
|
||||
else:
|
||||
output_probs = torch.empty((batch_size, 1, num_probs), device = "cpu", dtype = torch.float)
|
||||
output_ptokens = torch.empty((batch_size, 1, num_probs), device = "cpu", dtype = torch.long)
|
||||
output_ktokens = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.long)
|
||||
output_kprobs = torch.empty((batch_size, 1, return_top_tokens), device = "cpu", dtype = torch.float)
|
||||
|
||||
m = ext_c.sample_basic(logits,
|
||||
1.0 if settings.temperature_last else settings.temperature,
|
||||
@@ -219,7 +226,8 @@ class ExLlamaV2Sampler:
|
||||
random,
|
||||
output_tokens,
|
||||
output_probs,
|
||||
output_ptokens,
|
||||
output_kprobs,
|
||||
output_ktokens,
|
||||
logit_filter,
|
||||
settings.mirostat,
|
||||
settings.mirostat_mu if settings.mirostat else [],
|
||||
@@ -238,4 +246,4 @@ class ExLlamaV2Sampler:
|
||||
end_filter = False
|
||||
if len(settings.filters) > 0 and output_tokens[0].item() in end_tokens: end_filter = True
|
||||
|
||||
return output_tokens, output_ptokens, output_probs, end_filter
|
||||
return output_tokens, output_ktokens, output_kprobs, output_probs, end_filter
|
||||
|
||||
@@ -18,29 +18,32 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
tail_decode_tokens: int = 2
|
||||
|
||||
remaining_tokens: int = 0
|
||||
held_text: str = ""
|
||||
held_utf8_tokens: torch.tensor = None
|
||||
expect_utf8: int = 0
|
||||
held_tokens: torch.Tensor or None = None
|
||||
held_ptokens: torch.Tensor or None = None
|
||||
held_probs: torch.Tensor or None = None
|
||||
held_logits: torch.Tensor or None = None
|
||||
settings: ExLlamaV2Sampler.Settings = None
|
||||
stop_strings: set = set()
|
||||
stop_tokens: set = set()
|
||||
draft_model: ExLlamaV2 or None = None
|
||||
draft_cache: ExLlamaV2Cache or None = None
|
||||
|
||||
remaining_tokens: int = 0
|
||||
held_utf8_tokens: torch.tensor = None
|
||||
expect_utf8: int = 0
|
||||
|
||||
no_tokens: torch.Tensor = None
|
||||
no_ptokens: torch.Tensor = None
|
||||
no_probs: torch.Tensor = None
|
||||
no_pprobs: torch.Tensor = None
|
||||
no_logits: torch.Tensor = None
|
||||
|
||||
held_text: str = ""
|
||||
held_tokens: torch.Tensor or None = None
|
||||
held_ptokens: torch.Tensor or None = None
|
||||
held_probs: torch.Tensor or None = None
|
||||
held_pprobs: torch.Tensor or None = None
|
||||
held_logits: torch.Tensor or None = None
|
||||
|
||||
first_token = False
|
||||
heal_next_token = False
|
||||
|
||||
draft_model: ExLlamaV2 or None = None
|
||||
draft_cache: ExLlamaV2Cache or None = None
|
||||
|
||||
future_logits: torch.tensor or None = None
|
||||
future_tokens: torch.tensor or None = None
|
||||
num_speculative_tokens: int
|
||||
@@ -49,9 +52,9 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
total_tokens: int = 0
|
||||
accepted_draft_tokens: int = 0
|
||||
|
||||
return_probabilities: bool = False # Return final sampling probabilities, one per token
|
||||
return_probabilities_k: int = 1 # Number of probabilities to return per token
|
||||
return_logits: bool = False # Return raw logits prior to softmax, per token
|
||||
return_probabilities: bool = False
|
||||
return_top_tokens: int = 0
|
||||
return_logits: bool = False
|
||||
|
||||
active_loras = []
|
||||
position_offsets = None
|
||||
@@ -87,16 +90,40 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
else: raise ValueError("Unsupported type in stop_conditions")
|
||||
|
||||
|
||||
def begin_stream(self, input_ids: torch.Tensor, gen_settings: ExLlamaV2Sampler.Settings, token_healing = False, loras = None, input_mask = None, position_offsets = None):
|
||||
# Begin stream
|
||||
|
||||
self.no_logits = torch.empty((0, ((self.model.config.vocab_size + 31) // 32) * 32), dtype = torch.float)
|
||||
self.no_tokens = torch.empty((1, 0), dtype=torch.long)
|
||||
if self.return_probabilities_k == 1:
|
||||
self.no_ptokens = torch.empty((1, 0), dtype = torch.long)
|
||||
self.no_probs = torch.empty((1, 0), dtype = torch.float)
|
||||
else:
|
||||
self.no_ptokens = torch.empty((1, 0, self.return_probabilities_k), dtype = torch.long)
|
||||
self.no_probs = torch.empty((1, 0, self.return_probabilities_k), dtype = torch.float)
|
||||
def begin_stream_ex(self,
                    input_ids: torch.Tensor,
                    gen_settings: ExLlamaV2Sampler.Settings,
                    token_healing = False,
                    loras = None,
                    input_mask = None,
                    position_offsets = None,
                    return_probabilities = False,
                    return_top_tokens = 0,
                    return_logits = False):
    """
    Begin a stream and select which optional per-token outputs to collect.

    The three return_* arguments are stored on the generator; every other argument is passed
    through to begin_stream() positionally, in the same order, without modification.

    :param input_ids: forwarded to begin_stream()
    :param gen_settings: forwarded to begin_stream()
    :param token_healing: forwarded to begin_stream()
    :param loras: forwarded to begin_stream()
    :param input_mask: forwarded to begin_stream()
    :param position_offsets: forwarded to begin_stream()
    :param return_probabilities: stored as self.return_probabilities
    :param return_top_tokens: stored as self.return_top_tokens
    :param return_logits: stored as self.return_logits
    """

    # Store the output options first, so they are already in place when begin_stream() runs
    self.return_probabilities = return_probabilities
    self.return_top_tokens = return_top_tokens
    self.return_logits = return_logits

    # Same positional order as the begin_stream() parameter list
    stream_args = (input_ids, gen_settings, token_healing, loras, input_mask, position_offsets)
    self.begin_stream(*stream_args)
|
||||
|
||||
|
||||
# Legacy function
|
||||
|
||||
def begin_stream(self,
|
||||
input_ids: torch.Tensor,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
token_healing = False,
|
||||
loras = None,
|
||||
input_mask = None,
|
||||
position_offsets = None):
|
||||
|
||||
assert input_ids.shape[0] <= 2, "Streaming generator does not support batch size > 1"
|
||||
if input_ids.shape[0] == 2:
|
||||
@@ -109,12 +136,19 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
if loras is not None and isinstance(loras, ExLlamaV2Lora): loras = [loras]
|
||||
self.active_loras = loras
|
||||
|
||||
self.no_logits = torch.empty((0, ((self.model.config.vocab_size + 31) // 32) * 32), dtype=torch.float)
|
||||
self.no_tokens = torch.empty((1, 0), dtype=torch.long)
|
||||
self.no_probs = torch.empty((1, 0), dtype=torch.float)
|
||||
self.no_ptokens = torch.empty((1, 0, self.return_top_tokens), dtype=torch.long)
|
||||
self.no_pprobs = torch.empty((1, 0, self.return_top_tokens), dtype=torch.float)
|
||||
|
||||
self.held_text = ""
|
||||
self.held_utf8_tokens = self.no_tokens
|
||||
self.expect_utf8 = 0
|
||||
self.held_tokens = self.no_tokens
|
||||
self.held_ptokens = self.no_ptokens
|
||||
self.held_probs = self.no_probs
|
||||
self.held_pprobs = self.no_pprobs
|
||||
self.held_logits = self.no_logits
|
||||
self.settings = gen_settings
|
||||
self._gen_begin_reuse(input_ids, gen_settings)
|
||||
@@ -122,27 +156,49 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.heal_next_token = (token_healing and self.sequence_ids.shape[-1] >= 2)
|
||||
|
||||
|
||||
# Get the next chunk of text in the stream
|
||||
|
||||
# Return the next chunk of the stream as a dict instead of a tuple.
#
# Calls self._stream() once and repackages its seven results. The keys "chunk", "eos" and
# "chunk_token_ids" are always present. Optional keys are gated on generator flags:
#   "probs"                    when self.return_probabilities is set
#   "top_probs", "top_tokens"  when self.return_top_tokens > 0
#   "logits"                   when self.return_logits is set
#
# NOTE(review): indentation is not visible in this view; the three flag checks appear to be
# independent siblings rather than nested - confirm against the real file.
def stream_ex(self):
|
||||
|
||||
chunk, eos, chunk_token_ids, probs, ptokens, pprobs, logits = self._stream()
|
||||
|
||||
ret = { "chunk": chunk, # output text
|
||||
"eos": eos, # True if stop condition met
|
||||
"chunk_token_ids": chunk_token_ids } # token(s) corresponding to output text: [1, n] (torch.long)
|
||||
|
||||
if self.return_probabilities:
|
||||
ret["probs"] = probs # Probability of selected token(s): [1, n] (torch.float)
|
||||
|
||||
if self.return_top_tokens > 0:
|
||||
ret["top_probs"] = pprobs # Top k probs: [1, n, k] (torch.float)
|
||||
ret["top_tokens"] = ptokens # Top k tokens: [1, n, k] (torch.long)
|
||||
|
||||
if self.return_logits:
|
||||
ret["logits"] = logits.unsqueeze(0) # Raw output logits: [1, n, vocab_size] (torch.float)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
# Legacy function
|
||||
|
||||
def stream(self) -> Union[Tuple[str, bool, torch.Tensor],
|
||||
Tuple[str, bool, torch.Tensor, torch.Tensor],
|
||||
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
Tuple[str, bool, torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
|
||||
chunk, eos, chunk_token_ids, probs, ptokens, logits = self._stream()
|
||||
assert self.return_top_tokens == 0, "Use stream_ex() to return top K probs"
|
||||
|
||||
chunk, eos, chunk_token_ids, probs, _, _, logits = self._stream()
|
||||
ret = [chunk, eos, chunk_token_ids]
|
||||
|
||||
if self.return_probabilities:
|
||||
ret.append(probs)
|
||||
if self.return_probabilities_k > 1:
|
||||
ret.append(ptokens)
|
||||
|
||||
|
||||
if self.return_logits:
|
||||
ret.append(logits)
|
||||
|
||||
return tuple(ret)
|
||||
|
||||
|
||||
# Get the next chunk of text in the stream. Returns eos if stop condition has been met but does not count tokens
|
||||
|
||||
def _stream(self) -> (str, bool, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
|
||||
|
||||
# Token healing
|
||||
@@ -165,7 +221,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
# Regenerate the last token again, with prefix
|
||||
|
||||
healed_token, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
|
||||
healed_token, _, _, _, eos, logits = self._gen_single_token(self.settings, prefix_token = last_token)
|
||||
new_tail = self.tokenizer.decode(self.sequence_ids[:, -self.tail_decode_tokens:])[0]
|
||||
self.held_text += new_tail[len(old_tail):]
|
||||
|
||||
@@ -173,7 +229,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
# In case we only needed the healed token
|
||||
|
||||
if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
|
||||
if eos: return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
|
||||
|
||||
# Start filters when not healing
|
||||
|
||||
@@ -184,19 +240,18 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.settings.begin_filters()
|
||||
self.first_token = False
|
||||
|
||||
|
||||
# Decode the current tail end of the sequence
|
||||
|
||||
old_tail = self.tokenizer.decode(self.sequence_ids[:1, -self.tail_decode_tokens:])[0]
|
||||
|
||||
# Generate a single token and append to the sequence
|
||||
|
||||
next_token, next_ptokens, next_prob, eos, next_logits = self._gen_single_token(self.settings)
|
||||
next_token, next_ptokens, next_pprobs, next_prob, eos, next_logits = self._gen_single_token(self.settings)
|
||||
|
||||
# End immediately if it was a stop token
|
||||
|
||||
if next_token.item() in self.stop_tokens:
|
||||
return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
|
||||
return self.held_text, True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
|
||||
|
||||
# Decode the tail end of the sequence with the added token to get (actual) characters added
|
||||
|
||||
@@ -209,14 +264,15 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.held_tokens = torch.cat([self.held_tokens, next_token], dim = -1)
|
||||
if self.return_probabilities:
|
||||
self.held_probs = torch.cat([self.held_probs, next_prob], dim = 1)
|
||||
if self.return_probabilities_k > 1:
|
||||
self.held_ptokens = torch.cat([self.held_ptokens, next_ptokens], dim = 1)
|
||||
if self.return_top_tokens > 0:
|
||||
self.held_ptokens = torch.cat([self.held_ptokens, next_ptokens], dim = 1)
|
||||
self.held_pprobs = torch.cat([self.held_pprobs, next_pprobs], dim = 1)
|
||||
if self.return_logits:
|
||||
self.held_logits = torch.cat([self.held_logits, next_logits], dim = 0)
|
||||
|
||||
# Return now if newly added token ends a filter
|
||||
|
||||
if eos: return self.held_text, True, self.held_tokens, self.held_probs, self.held_ptokens, self.held_logits
|
||||
if eos: return self.held_text, True, self.held_tokens, self.held_probs, self.held_ptokens, self.held_pprobs, self.held_logits
|
||||
|
||||
# Hold text as long as it contains part of a stop string
|
||||
|
||||
@@ -227,7 +283,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
|
||||
position = self.held_text.find(ss)
|
||||
if position != -1:
|
||||
return self.held_text[:position], True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
|
||||
return self.held_text[:position], True, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
|
||||
|
||||
# Check for overlap between end of held_text and start of stop string
|
||||
|
||||
@@ -239,7 +295,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
# If holding text because of a partial stop condition, return nothing but also EOS = False
|
||||
|
||||
if partial_ss:
|
||||
return "", False, self.no_tokens, self.no_probs, self.no_ptokens, self.no_logits
|
||||
return "", False, self.no_tokens, self.no_probs, self.no_ptokens, self.no_pprobs, self.no_logits
|
||||
|
||||
# No stop condition, so return whatever is being held
|
||||
|
||||
@@ -247,13 +303,15 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
stream_tokens = self.held_tokens
|
||||
stream_probs = self.held_probs
|
||||
stream_ptokens = self.held_ptokens
|
||||
stream_pprobs = self.held_pprobs
|
||||
stream_logits = self.held_logits
|
||||
self.held_text = ""
|
||||
self.held_tokens = self.no_tokens
|
||||
self.held_probs = self.no_probs
|
||||
self.held_ptokens = self.no_ptokens
|
||||
self.held_pprobs = self.no_pprobs
|
||||
self.held_logits = self.no_logits
|
||||
return stream_text, False, stream_tokens, stream_probs, stream_ptokens, stream_logits
|
||||
return stream_text, False, stream_tokens, stream_probs, stream_ptokens, stream_pprobs, stream_logits
|
||||
|
||||
|
||||
def _decode_utf8(self):
|
||||
@@ -372,11 +430,11 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
if self.draft_model is None:
|
||||
|
||||
logits = self.model.forward(self.sequence_ids[:, -1:], self.cache, loras = self.active_loras, input_mask = self.input_mask, position_offsets = self.position_offsets).float().cpu()
|
||||
token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)
|
||||
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_top_tokens)
|
||||
|
||||
else:
|
||||
|
||||
token, ptokens, prob, eos, logits = self._gen_single_token_speculative(gen_settings, prefix_token)
|
||||
token, ptokens, pprobs, prob, eos, logits = self._gen_single_token_speculative(gen_settings, prefix_token)
|
||||
|
||||
if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1:
|
||||
self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1)
|
||||
@@ -384,7 +442,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1)
|
||||
|
||||
gen_settings.feed_filters(token)
|
||||
return token, ptokens, prob, eos, logits.flatten(1)
|
||||
return token, ptokens, pprobs, prob, eos, logits.flatten(1)
|
||||
|
||||
|
||||
def _gen_single_token_speculative(self, gen_settings, prefix_token = None):
|
||||
@@ -400,7 +458,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
for k in range(self.num_speculative_tokens):
|
||||
|
||||
logits = self.draft_model.forward(draft_sequence_ids[:, -1:], self.draft_cache).float().cpu()
|
||||
token, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
|
||||
token, _, _, prob, _ = ExLlamaV2Sampler.sample(logits, draft_gen_settings, draft_sequence_ids, random.random(), self.tokenizer, prefix_token if k == 0 else None)
|
||||
|
||||
if prob < self.speculative_prob_threshold:
|
||||
self.draft_cache.current_seq_len -= 1
|
||||
@@ -430,7 +488,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
# Sample the first future logits
|
||||
|
||||
logits = self.future_logits[:, :1, :]
|
||||
token, ptokens, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_probabilities_k)
|
||||
token, ptokens, pprobs, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token, self.return_top_tokens)
|
||||
self.future_logits = self.future_logits[:, 1:, :]
|
||||
self.future_tokens = self.future_tokens[:, 1:]
|
||||
self.cache.current_seq_len += 1
|
||||
@@ -445,7 +503,7 @@ class ExLlamaV2StreamingGenerator(ExLlamaV2BaseGenerator):
|
||||
self.accepted_draft_tokens += 1
|
||||
self.total_tokens += 1
|
||||
|
||||
return token, ptokens, prob, eos, logits
|
||||
return token, ptokens, pprobs, prob, eos, logits
|
||||
|
||||
|
||||
def reset_sd_stats(self):
|
||||
|
||||
Reference in New Issue
Block a user