From b1f3baad744b6f409b8af05347139c0fa6e713f3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 9 Apr 2024 21:23:01 -0400 Subject: [PATCH] OAI: Add response_format parameter response_format allows a user to request a valid, but arbitrary JSON object from the API. This is a new part of the OAI spec. Signed-off-by: kingbri --- endpoints/OAI/router.py | 8 ++++++++ endpoints/OAI/types/common.py | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 6d50e34..833945c 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -449,6 +449,10 @@ async def completion_request(request: Request, data: CompletionRequest): config.developer_config().get("disable_request_streaming"), False ) + # Set an empty JSON schema if the request wants a JSON response + if data.response_format.type == "json": + data.json_schema = {"type": "object"} + if data.stream and not disable_request_streaming: generator_callback = partial(stream_generate_completion, data, model_path) @@ -492,6 +496,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) else: prompt = format_prompt_with_template(data) + # Set an empty JSON schema if the request wants a JSON response + if data.response_format.type == "json": + data.json_schema = {"type": "object"} + disable_request_streaming = unwrap( config.developer_config().get("disable_request_streaming"), False ) diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py index 2241ad0..b9aac68 100644 --- a/endpoints/OAI/types/common.py +++ b/endpoints/OAI/types/common.py @@ -14,6 +14,10 @@ class UsageStats(BaseModel): total_tokens: int +class CompletionResponseFormat(BaseModel): + type: str = "text" + + class CommonCompletionRequest(BaseSamplerRequest): """Represents a common completion request.""" @@ -24,6 +28,9 @@ class CommonCompletionRequest(BaseSamplerRequest): # Generation info (remainder is in BaseSamplerRequest superclass) stream: Optional[bool] = False logprobs: Optional[int] = 0 + response_format: Optional[CompletionResponseFormat] = Field( + default_factory=CompletionResponseFormat + ) # Extra OAI request stuff best_of: Optional[int] = Field(