Model: Switch logprobs to use post-sampling
Previously, pre-sampling logprobs were computed from the raw logits, but newer versions of exl2 allow returning token probabilities post-sampling. Convert these to logprobs and send them to the user.

Signed-off-by: kingbri <bdashore3@proton.me>
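The conversion itself is small: take the sampler's per-token probabilities, log them, clamp -inf (from zero-probability entries) to a JSON-safe sentinel, and pair each value with its token. Below is a standalone sketch of that idea; the tensor shapes, example values, and raw-id dict keys are illustrative only, while the torch.log / zip_longest pattern and the -1000 clamp mirror the diff that follows.

import torch
from itertools import zip_longest

# Hypothetical post-sampling output for one generated position: top token ids
# and their probabilities after the sampler has run (shapes/values made up).
top_token_ids = torch.tensor([[[512], [1024], [77]]])
top_token_probs = torch.tensor([[[0.62], [0.31], [0.0]]])

# Convert probabilities to logprobs; log(0) becomes -inf
top_values = torch.log(top_token_probs).flatten().tolist()

# JSON cannot encode -inf, so clamp it to a large negative sentinel
cleaned_values = [-1000 if value == float("-inf") else value for value in top_values]

# The real container decodes ids to token pieces via the tokenizer;
# here the dict is simply keyed by the raw ids.
logprobs = dict(zip_longest(top_token_ids.flatten().tolist(), cleaned_values))
print(logprobs)  # approx. {512: -0.478, 1024: -1.171, 77: -1000}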
@@ -466,11 +466,15 @@ class ExllamaV2Container:
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
 
-        return self.tokenizer.encode(
-            text,
-            add_bos=unwrap(kwargs.get("add_bos_token"), True),
-            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-        )[0].tolist()
+        return (
+            self.tokenizer.encode(
+                text,
+                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+            )
+            .flatten()
+            .tolist()
+        )
 
     def decode_tokens(self, ids: List[int], **kwargs):
         """Wrapper to decode tokens from a list of IDs"""
@@ -489,35 +493,24 @@ class ExllamaV2Container:
             "unk_token": self.tokenizer.unk_token,
         }
 
-    def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
-        normalized_logits = torch.log_softmax(logits, dim=-1)
-        top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)
-
+    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
         top_tokens = list(
             map(
                 lambda index: self.tokenizer.extended_id_to_piece.get(
                     index, self.tokenizer.id_to_piece[index]
                 ),
-                top_ids[0].tolist(),
-            )
-        )
-        top_values = top_values[0].tolist()
-
-        return dict(zip_longest(top_tokens, top_values))
-
-    def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
-        normalized_probs = torch.log(token_probs)
-
-        tokens = list(
-            map(
-                lambda index: self.tokenizer.extended_id_to_piece.get(
-                    index, self.tokenizer.id_to_piece[index]
-                ),
-                token_ids[0].tolist(),
+                token_ids.flatten().tolist(),
             )
         )
 
-        return dict(zip_longest(tokens, normalized_probs[0].tolist()))
+        top_values = torch.log(token_probs).flatten().tolist()
+
+        # Cannot return -inf in JSON
+        cleaned_values = list(
+            map(lambda value: -1000 if value == float("-inf") else value, top_values)
+        )
+
+        return dict(zip_longest(top_tokens, cleaned_values))
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -712,7 +705,10 @@ class ExllamaV2Container:
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
 
+        # Logprobs
+        request_logprobs = unwrap(kwargs.get("logprobs"), 0)
+        self.generator.return_top_tokens = request_logprobs
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -832,14 +828,20 @@ class ExllamaV2Container:
             if auto_scale_penalty_range:
                 gen_settings.token_repetition_range = generated_tokens
 
-            # Generate
-            chunk, eos, tokens, token_probs, logits = self.generator.stream()
+            # Run dict generation
+            # Guarantees return of chunk, eos, and chunk_token_ids
+            raw_generation = self.generator.stream_ex()
 
             if token_healing:
                 # Extract healed token
                 ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False
 
+            # Get parameters that will always exist
+            chunk = raw_generation["chunk"]
+            eos = raw_generation["eos"]
+            tokens = raw_generation["chunk_token_ids"]
+
             save_tokens = torch.cat(
                 (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
             )
@@ -863,17 +865,26 @@ class ExllamaV2Container:
             }
 
             if request_logprobs > 0:
-                # Get sampled token probs
-                if token_probs.numel() > 0 and tokens.numel() > 0:
-                    generation["token_probs"] = self.get_token_probs(
-                        tokens, token_probs
-                    )
+                # Get top tokens and probs
+                top_tokens = unwrap(
+                    raw_generation.get("top_tokens"),
+                    torch.empty((1, 0, 1), dtype=torch.long),
+                )
 
-                # Get logprob choices
-                if logits.numel() > 0:
-                    generation["logprobs"] = self.get_logprobs(
-                        logits, request_logprobs
-                    )
+                top_probs = unwrap(
+                    raw_generation.get("top_probs"),
+                    torch.empty((1, 0, 1), dtype=torch.float),
+                )
+
+                if top_tokens.numel() > 0 and top_probs.numel() > 0:
+                    logprobs = self.get_logprobs(top_tokens, top_probs)
+                    generation["logprobs"] = logprobs
+
+                    # The first logprob is the selected token prob
+                    generation["token_probs"] = {
+                        token: logprobs[token]
+                        for token in list(logprobs.keys())[:1]
+                    }
 
             yield generation
             full_response += chunk_buffer
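For context, a streamed generation chunk that requested logprobs would then carry roughly this shape. Only the "logprobs" and "token_probs" field names come from the diff above; the token pieces and numbers are made up for illustration, and the remaining always-present fields are elided.

generation = {
    # ...fields that are always present, elided here...
    # top alternatives for the sampled position, already converted to logprobs
    "logprobs": {" world": -0.12, " there": -2.31, "!": -4.05},
    # only the first entry, i.e. the token that was actually sampled
    "token_probs": {" world": -0.12},
}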