OAI: Strictly type chat completions

Previously, the messages were a list of dicts. These are untyped
and don't provide strict hinting. Add types for chat completion
messages and reformat existing code.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-11-19 23:15:47 -05:00
parent 0fadb1e5e8
commit 8ffc636dce
3 changed files with 37 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Union, List, Optional, Dict
from typing import Literal, Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
content: List[ChatCompletionLogprob] = Field(default_factory=list)
class ChatCompletionImageUrl(BaseModel):
url: str
class ChatCompletionMessagePart(BaseModel):
type: Literal["text", "image_url"] = "text"
text: Optional[str] = None
image_url: Optional[ChatCompletionImageUrl] = None
class ChatCompletionMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
role: str = "user"
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
tool_calls: Optional[List[ToolCall]] = None
tool_calls_json: SkipJsonSchema[Optional[str]] = None
class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
# WIP this can probably be tightened, or maybe match the OAI lib type
# in openai\types\chat\chat_completion_message_param.py
messages: Union[str, List[Dict]]
messages: List[ChatCompletionMessage] = Field(default_factory=list)
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}