mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-26 01:08:52 +00:00
Model: Add logprobs support
Returns token offsets, selected tokens, probabilities of tokens post-sampling, and normalized probability of selecting a token pre-sampling (for efficiency purposes). Only for text completions. Chat completions in a later commit. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,19 +1,10 @@
|
||||
""" Common types for OAI. """
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from common.sampling import BaseSamplerRequest
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
"""Represents log probabilities."""
|
||||
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
"""Represents usage stats."""
|
||||
|
||||
@@ -29,6 +20,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
||||
# This parameter is not used, the loaded model is used instead
|
||||
model: Optional[str] = None
|
||||
|
||||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||
stream: Optional[bool] = False
|
||||
logprobs: Optional[int] = 0
|
||||
|
||||
# Extra OAI request stuff
|
||||
best_of: Optional[int] = Field(
|
||||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
@@ -36,9 +31,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
||||
echo: Optional[bool] = Field(
|
||||
description="Not parsed. Only used for OAI compliance.", default=False
|
||||
)
|
||||
logprobs: Optional[int] = Field(
|
||||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
)
|
||||
n: Optional[int] = Field(
|
||||
description="Not parsed. Only used for OAI compliance.", default=1
|
||||
)
|
||||
@@ -49,5 +41,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
||||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
)
|
||||
|
||||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||
stream: Optional[bool] = False
|
||||
def to_gen_params(self):
|
||||
extra_gen_params = {"logprobs": self.logprobs}
|
||||
|
||||
return super().to_gen_params(**extra_gen_params)
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
""" Completion API protocols """
|
||||
from pydantic import BaseModel, Field
|
||||
from time import time
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
|
||||
from OAI.types.common import CommonCompletionRequest, UsageStats
|
||||
|
||||
|
||||
class CompletionLogProbs(BaseModel):
|
||||
"""Represents log probabilities for a completion request."""
|
||||
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CompletionRespChoice(BaseModel):
|
||||
@@ -13,7 +22,7 @@ class CompletionRespChoice(BaseModel):
|
||||
# Index is 0 since we aren't using multiple choices
|
||||
index: int = 0
|
||||
finish_reason: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
logprobs: Optional[CompletionLogProbs] = None
|
||||
text: str
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user