OAI: Implement completion API endpoint

Add support for /v1/completions with the option to use streaming
if needed. Also rewrite API endpoints to use async when possible
since that improves request performance.

Model container parameter names also needed rewrites as well and
set fallback cases to their disabled values.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-13 18:24:12 -05:00
parent 4fa4386275
commit eee8b642bd
6 changed files with 190 additions and 57 deletions

13
OAI/models/common.py Normal file
View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Dict
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[float] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
class UsageStats(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int

99
OAI/models/completions.py Normal file
View File

@@ -0,0 +1,99 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from OAI.models.common import LogProbs, UsageStats
class CompletionRespChoice(BaseModel):
finish_reason: str
index: int
logprobs: Optional[LogProbs] = None
text: str
class CompletionRequest(BaseModel):
# Model information
model: str
# Prompt can also contain token ids, but that's out of scope for this project.
prompt: Union[str, List[str]]
# Extra OAI request stuff
best_of: Optional[int] = None
echo: Optional[bool] = False
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
n: Optional[int] = 1
suffix: Optional[str] = None
user: Optional[str] = None
# Generation info
seed: Optional[int] = -1
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
# Default to 150 as 16 makes no sense as a default
max_tokens: Optional[int] = 150
# Not supported sampling params
presence_penalty: Optional[int] = 0
# Aliased to repetition_penalty
frequency_penalty: int = 0
# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
typical: Optional[float] = 0.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
repetition_penalty: Optional[float] = 1.0
repetition_penalty_range: Optional[int] = 0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
mirostat_tau: Optional[float] = 1.5
mirostat_eta: Optional[float] = 0.1
# Converts to internal generation parameters
def to_gen_params(self):
# Convert prompt to a string
if isinstance(self.prompt, list):
self.prompt = "\n".join(self.prompt)
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
self.repetition_penalty = self.frequency_penalty
return {
"prompt": self.prompt,
"stop": self.stop,
"max_tokens": self.max_tokens,
"token_healing": self.token_healing,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"repetition_penalty": self.repetition_penalty,
"repetition_penalty_range": self.repetition_penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": True if self.mirostat_mode == 2 else False,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta
}
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
choices: List[CompletionRespChoice]
created: int = Field(default_factory=lambda: int(time()))
model: str
object: str = "text-completion"
# TODO: Add usage stats
usage: Optional[UsageStats] = None

19
OAI/utils.py Normal file
View File

@@ -0,0 +1,19 @@
from OAI.models.completions import CompletionResponse, CompletionRespChoice
from OAI.models.common import UsageStats
from typing import Optional
def create_completion_response(text: str, index: int, model_name: Optional[str]):
# TODO: Add method to get token amounts in model for UsageStats
choice = CompletionRespChoice(
finish_reason="Generated",
index = index,
text = text
)
response = CompletionResponse(
choices = [choice],
model = model_name or ""
)
return response