From b149d3398d4f62c282452121fc6057c4439eb123 Mon Sep 17 00:00:00 2001
From: Volodymyr Kuznetsov
Date: Mon, 8 Jul 2024 13:42:54 -0700
Subject: [PATCH] OAI: support stream_options argument

---
 endpoints/OAI/types/chat_completion.py |  1 +
 endpoints/OAI/types/common.py          |  5 +++++
 endpoints/OAI/utils/chat_completion.py | 19 ++++++++++++++++++-
 3 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index be5cfea..b50e646 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index d44e41a..6970adf 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
 
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 9e82b1b..9b91d1d 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -246,6 +246,7 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -275,10 +276,26 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json()
+            yield response.model_dump_json(exclude=None if need_usage else {"usage"})
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                if need_usage:
+                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+                    response = ChatCompletionStreamChunk(
+                        id=const_id,
+                        choices=[],
+                        model=unwrap(model_path.name, ""),
+                        usage=UsageStats(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+
+                    yield response.model_dump_json()
                 break
     except CancelledError:
         # Get out if the request gets disconnected