mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
OAI: Add chat completions endpoint
Chat completions is the endpoint that will be used by OAI in the future. Makes sense to support it even though the completions endpoint will be used more often. Also unify common parameters between the chat completion and completion requests since they're very similar. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -2,102 +2,26 @@ from uuid import uuid4
|
||||
from time import time
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Union
|
||||
from OAI.types.common import LogProbs, UsageStats
|
||||
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
|
||||
|
||||
class CompletionRespChoice(BaseModel):
|
||||
# Index is 0 since we aren't using multiple choices
|
||||
index: int = 0
|
||||
finish_reason: str
|
||||
index: int
|
||||
logprobs: Optional[LogProbs] = None
|
||||
text: str
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
# Model information
|
||||
model: str
|
||||
|
||||
# Inherited from common request
|
||||
class CompletionRequest(CommonCompletionRequest):
|
||||
# Prompt can also contain token ids, but that's out of scope for this project.
|
||||
prompt: Union[str, List[str]]
|
||||
|
||||
# Extra OAI request stuff
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
suffix: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# Generation info
|
||||
seed: Optional[int] = -1
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
|
||||
# Default to 150 as 16 makes no sense as a default
|
||||
max_tokens: Optional[int] = 150
|
||||
|
||||
# Not supported sampling params
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
|
||||
# Aliased to repetition_penalty
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
|
||||
# Sampling params
|
||||
token_healing: Optional[bool] = False
|
||||
temperature: Optional[float] = 1.0
|
||||
top_k: Optional[int] = 0
|
||||
top_p: Optional[float] = 1.0
|
||||
typical: Optional[float] = 0.0
|
||||
min_p: Optional[float] = 0.0
|
||||
tfs: Optional[float] = 1.0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
repetition_penalty_range: Optional[int] = 0
|
||||
repetition_decay: Optional[int] = 0
|
||||
mirostat_mode: Optional[int] = 0
|
||||
mirostat_tau: Optional[float] = 1.5
|
||||
mirostat_eta: Optional[float] = 0.1
|
||||
add_bos_token: Optional[bool] = True
|
||||
ban_eos_token: Optional[bool] = False
|
||||
|
||||
# Converts to internal generation parameters
|
||||
def to_gen_params(self):
|
||||
# Convert prompt to a string
|
||||
if isinstance(self.prompt, list):
|
||||
self.prompt = "\n".join(self.prompt)
|
||||
|
||||
# Convert stop to an array of strings
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
|
||||
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
|
||||
self.repetition_penalty = self.frequency_penalty
|
||||
|
||||
return {
|
||||
"prompt": self.prompt,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"ban_eos_token": self.ban_eos_token,
|
||||
"token_healing": self.token_healing,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"typical": self.typical,
|
||||
"min_p": self.min_p,
|
||||
"tfs": self.tfs,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"repetition_penalty_range": self.repetition_penalty_range,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": True if self.mirostat_mode == 2 else False,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
"mirostat_eta": self.mirostat_eta
|
||||
}
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
|
||||
choices: List[CompletionRespChoice]
|
||||
created: int = Field(default_factory=lambda: int(time()))
|
||||
model: str
|
||||
object: str = "text-completion"
|
||||
object: str = "text_completion"
|
||||
|
||||
# TODO: Add usage stats
|
||||
usage: Optional[UsageStats] = None
|
||||
|
||||
Reference in New Issue
Block a user