Model: Add logprobs support

Returns token offsets, the selected tokens, the post-sampling
probabilities of those tokens, and the normalized pre-sampling
probability of selecting a token (for efficiency purposes).

Only text completions are supported for now. Chat completions will follow in a later commit.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-02-07 21:41:15 -05:00
committed by Brian Dashore
parent 2642ef7156
commit 0af6a38af3
6 changed files with 145 additions and 52 deletions

View File

@@ -1,19 +1,10 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from typing import Optional
from common.sampling import BaseSamplerRequest
class LogProbs(BaseModel):
"""Represents log probabilities."""
# Every field below defaults to a fresh empty list (default_factory=list).
# NOTE(review): the four lists look parallel, one entry per token -- confirm.
# Integer offset per entry (presumably the token's position in the text).
text_offset: List[int] = Field(default_factory=list)
# Float per entry (presumably the token's log probability); may be None.
token_logprobs: List[Optional[float]] = Field(default_factory=list)
# Token strings.
tokens: List[str] = Field(default_factory=list)
# Per-entry str -> float mapping (presumably candidate token -> logprob --
# TODO confirm); an entry may be None.
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class UsageStats(BaseModel):
"""Represents usage stats."""
@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
# This parameter is not used; the loaded model is used instead
model: Optional[str] = None
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
logprobs: Optional[int] = 0
# Extra OAI request stuff
best_of: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
logprobs: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info (remainder is in BaseSamplerRequest superclass)
stream: Optional[bool] = False
def to_gen_params(self):
    """Build generation params, adding this request's ``logprobs`` value.

    The ``logprobs`` field is handed to the superclass ``to_gen_params``
    as a keyword argument, and that call's result is returned unchanged.
    """
    return super().to_gen_params(logprobs=self.logprobs)

View File

@@ -1,10 +1,19 @@
""" Completion API protocols """
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from uuid import uuid4
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
from OAI.types.common import CommonCompletionRequest, UsageStats
class CompletionLogProbs(BaseModel):
"""Represents log probabilities for a completion request."""
# All four fields are lists that default to a new empty list
# (default_factory=list), so instances never share a default.
# NOTE(review): entries presumably line up index-by-index across the
# lists, one per generated token -- confirm against the producer.
# Integer offsets (presumably where each token starts in the text).
text_offset: List[int] = Field(default_factory=list)
# Floats (presumably each token's log probability); entries may be None.
token_logprobs: List[Optional[float]] = Field(default_factory=list)
# The token strings themselves.
tokens: List[str] = Field(default_factory=list)
# Dicts of str -> float (presumably alternative token -> logprob --
# TODO confirm); entries may be None.
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
class CompletionRespChoice(BaseModel):
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: str
logprobs: Optional[LogProbs] = None
logprobs: Optional[CompletionLogProbs] = None
text: str