Model: Add logprobs support

Returns text offsets, the selected tokens, each selected token's
post-sampling probability, and the normalized probabilities of the top
candidate tokens (taken pre-sampling for efficiency).

Only implemented for text completions for now; chat completion support will follow in a later commit.
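
As an illustrative sketch (not part of the diff): how a client might request and read these fields once this change is deployed. The base URL, port, and lack of auth headers are assumptions about a local OAI-compatible instance; the field names mirror the CompletionLogProbs model added in this commit.

import requests

# Hypothetical local instance; adjust the base URL and add auth headers as needed.
BASE_URL = "http://127.0.0.1:5000/v1"

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 10,
    "logprobs": 3,  # ask for the top 3 candidate tokens per position
}

resp = requests.post(f"{BASE_URL}/completions", json=payload, timeout=60)
resp.raise_for_status()

logprobs = resp.json()["choices"][0].get("logprobs") or {}

# Fields populated by this commit for non-chat completions:
#   tokens         -> sampled token strings
#   token_logprobs -> post-sampling probability of each sampled token
#   text_offset    -> character offset of each streamed chunk in the output
#   top_logprobs   -> pre-sampling normalized probabilities of the top candidates
for token, prob in zip(logprobs.get("tokens", []), logprobs.get("token_logprobs", [])):
    print(repr(token), prob)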

Signed-off-by: kingbri <bdashore3@proton.me>

Author: kingbri
Date: 2024-02-07 21:41:15 -05:00
Committed by: Brian Dashore
Parent: 2642ef7156
Commit: 0af6a38af3
6 changed files with 145 additions and 52 deletions

View File

@@ -1,19 +1,10 @@
 """ Common types for OAI. """
 from pydantic import BaseModel, Field
-from typing import List, Dict, Optional
+from typing import Optional
 from common.sampling import BaseSamplerRequest
-
-class LogProbs(BaseModel):
-    """Represents log probabilities."""
-
-    text_offset: List[int] = Field(default_factory=list)
-    token_logprobs: List[Optional[float]] = Field(default_factory=list)
-    tokens: List[str] = Field(default_factory=list)
-    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
-
 class UsageStats(BaseModel):
     """Represents usage stats."""
@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # This parameter is not used, the loaded model is used instead
     model: Optional[str] = None

+    # Generation info (remainder is in BaseSamplerRequest superclass)
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = 0
+
     # Extra OAI request stuff
     best_of: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     echo: Optional[bool] = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
     )
-    logprobs: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=None
-    )
     n: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=1
     )
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
         description="Not parsed. Only used for OAI compliance.", default=None
     )

-    # Generation info (remainder is in BaseSamplerRequest superclass)
-    stream: Optional[bool] = False
+    def to_gen_params(self):
+        extra_gen_params = {"logprobs": self.logprobs}
+
+        return super().to_gen_params(**extra_gen_params)

View File

@@ -1,10 +1,19 @@
 """ Completion API protocols """
 from pydantic import BaseModel, Field
 from time import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 from uuid import uuid4
-from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
+from OAI.types.common import CommonCompletionRequest, UsageStats
+
+class CompletionLogProbs(BaseModel):
+    """Represents log probabilities for a completion request."""
+
+    text_offset: List[int] = Field(default_factory=list)
+    token_logprobs: List[Optional[float]] = Field(default_factory=list)
+    tokens: List[str] = Field(default_factory=list)
+    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)

 class CompletionRespChoice(BaseModel):
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
     finish_reason: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     text: str

View File

@@ -9,22 +9,40 @@ from OAI.types.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamChoice,
 )
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import (
+    CompletionResponse,
+    CompletionRespChoice,
+    CompletionLogProbs,
+)
 from OAI.types.common import UsageStats

-def create_completion_response(
-    text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
-    model_name: Optional[str],
-):
+def create_completion_response(**kwargs):
     """Create a completion response from the provided text."""

-    choice = CompletionRespChoice(finish_reason="Generated", text=text)
+    token_probs = unwrap(kwargs.get("token_probs"), {})
+    logprobs = unwrap(kwargs.get("logprobs"), [])
+    offset = unwrap(kwargs.get("offset"), [])
+
+    logprob_response = CompletionLogProbs(
+        text_offset=offset if isinstance(offset, list) else [offset],
+        token_logprobs=token_probs.values(),
+        tokens=token_probs.keys(),
+        top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
+    )
+
+    choice = CompletionRespChoice(
+        finish_reason="Generated",
+        text=unwrap(kwargs.get("text"), ""),
+        logprobs=logprob_response,
+    )
+
+    prompt_tokens = unwrap(kwargs.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(kwargs.get("completion_tokens"), 0)

     response = CompletionResponse(
         choices=[choice],
-        model=unwrap(model_name, ""),
+        model=unwrap(kwargs.get("model_name"), ""),
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -37,12 +55,12 @@ def create_completion_response(
 def create_chat_completion_response(
     text: str,
-    prompt_tokens: int,
-    completion_tokens: int,
+    prompt_tokens: Optional[int],
+    completion_tokens: Optional[int],
     model_name: Optional[str],
 ):
     """Create a chat completion response from the provided text."""

-    message = ChatCompletionMessage(role="assistant", content=text)
+    message = ChatCompletionMessage(role="assistant", content=unwrap(text, ""))
     choice = ChatCompletionRespChoice(finish_reason="Generated", message=message)

View File

@@ -472,20 +472,72 @@ class ExllamaV2Container:
             "unk_token": self.tokenizer.unk_token,
         }

+    def get_logprobs(self, logits: torch.Tensor, max_logprobs: int):
+        normalized_logits = torch.log_softmax(logits, dim=-1)
+        top_values, top_ids = torch.topk(normalized_logits, max_logprobs, dim=-1)
+        top_tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                top_ids[0].tolist(),
+            )
+        )
+        top_values = top_values[0].tolist()
+
+        return dict(zip(top_tokens, top_values, strict=True))
+
+    def get_token_probs(self, token_ids: torch.tensor, token_probs: torch.Tensor):
+        tokens = list(
+            map(
+                lambda index: self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                ),
+                token_ids[0].tolist(),
+            )
+        )
+
+        return dict(zip(tokens, token_probs[0].tolist(), strict=True))
+
+    def generate(self, prompt: str, **kwargs):
+        """Generate a response to a prompt"""
+        generations = list(self.generate_gen(prompt, **kwargs))
+
+        joined_generation = {
+            "chunk": "",
+            "prompt_tokens": 0,
+            "generation_tokens": 0,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            for generation in generations:
+                joined_generation["chunk"] += unwrap(generation.get("chunk"), "")
+                joined_generation["offset"].append(unwrap(generation.get("offset"), []))
+                joined_generation["token_probs"].update(
+                    unwrap(generation.get("token_probs"), {})
+                )
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
+
+            joined_generation["prompt_tokens"] = unwrap(
+                generations[-1].get("prompt_tokens"), 0
+            )
+            joined_generation["generation_tokens"] = unwrap(
+                generations[-1].get("generated_tokens"), 0
+            )
+
+        return joined_generation
+
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""

         pass

-    def generate(self, prompt: str, **kwargs):
-        """Generate a response to a prompt"""
-        generation = list(self.generate_gen(prompt, **kwargs))
-
-        if generation:
-            response = "".join(map(lambda chunk: chunk[0], generation))
-            return response, generation[-1][1], generation[-1][2]
-
-        return "", 0, 0
-
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
         """
@@ -639,6 +691,7 @@ class ExllamaV2Container:
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
+        request_logprobs = unwrap(kwargs.get("logprobs"), 0)

         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -657,6 +710,7 @@ class ExllamaV2Container:
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
         )
@@ -758,7 +812,7 @@ class ExllamaV2Container:
                 gen_settings.token_repetition_range = generated_tokens

             # Generate
-            chunk, eos, tokens, _, _ = self.generator.stream()
+            chunk, eos, tokens, token_probs, logits = self.generator.stream()

             if token_healing:
                 # Extract healed token
@@ -780,7 +834,27 @@ class ExllamaV2Container:
             if chunk_buffer != "" and (
                 elapsed > stream_interval or eos or generated_tokens == max_tokens
             ):
-                yield chunk_buffer, prompt_tokens, generated_tokens
+                generation = {
+                    "chunk": chunk_buffer,
+                    "prompt_tokens": prompt_tokens,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+
+                if request_logprobs > 0:
+                    # Get sampled token probs
+                    if token_probs.numel() > 0 and tokens.numel() > 0:
+                        generation["token_probs"] = self.get_token_probs(
+                            tokens, token_probs
+                        )
+
+                    # Get logprob choices
+                    if logits.numel() > 0:
+                        generation["logprobs"] = self.get_logprobs(
+                            logits, request_logprobs
+                        )
+
+                yield generation

                 full_response += chunk_buffer
                 chunk_buffer = ""
                 last_chunk_time = now

View File

@@ -159,7 +159,7 @@ class BaseSamplerRequest(BaseModel):
         examples=[1.0],
     )

-    def to_gen_params(self):
+    def to_gen_params(self, **kwargs):
         """Converts samplers to internal generation params"""

         # Add forced overrides if present
@@ -201,7 +201,7 @@ class BaseSamplerRequest(BaseModel):
             "negative_prompt": self.negative_prompt,
         }

-        return gen_params
+        return {**gen_params, **kwargs}

 # Global for default overrides

main.py (16 changed lines)
View File

@@ -458,12 +458,13 @@ async def generate_completion(request: Request, data: CompletionRequest):
             new_generation = MODEL_CONTAINER.generate_gen(
                 data.prompt, **data.to_gen_params()
             )
-            for part, prompt_tokens, completion_tokens in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break

                 response = create_completion_response(
-                    part, prompt_tokens, completion_tokens, model_path.name
+                    **generation,
+                    model_name=model_path.name,
                 )

                 yield get_sse_packet(response.model_dump_json())
@@ -479,13 +480,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
             generate_with_semaphore(generator), media_type="text/event-stream"
         )

-    response_text, prompt_tokens, completion_tokens = await call_with_semaphore(
+    generation = await call_with_semaphore(
         partial(MODEL_CONTAINER.generate, data.prompt, **data.to_gen_params())
     )
-
-    response = create_completion_response(
-        response_text, prompt_tokens, completion_tokens, model_path.name
-    )
+    response = create_completion_response(**generation)

     return response
@@ -545,12 +543,12 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             new_generation = MODEL_CONTAINER.generate_gen(
                 prompt, **data.to_gen_params()
             )
-            for part, _, _ in new_generation:
+            for generation in new_generation:
                 if await request.is_disconnected():
                     break

                 response = create_chat_completion_stream_chunk(
-                    const_id, part, model_path.name
+                    const_id, generation.get("chunk"), model_path.name
                 )

                 yield get_sse_packet(response.model_dump_json())