Merge pull request #147 from ai-and-i/stream_options

Support stream_options argument to get usage info in streaming mode
Brian Dashore authored 2024-07-12 14:38:20 -04:00, committed by GitHub
3 changed files with 37 additions and 3 deletions
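The new parameter mirrors OpenAI's stream_options behaviour: when include_usage is true, the server appends one extra chunk carrying token counts right before the [DONE] sentinel. A minimal client sketch follows; the URL, API key, and model name are placeholders, and SSE framing with a "data: " prefix is assumed rather than taken from this diff.

import json

import requests

# Placeholder endpoint and credentials -- adjust for the actual deployment.
API_URL = "http://localhost:5000/v1/chat/completions"
HEADERS = {"Authorization": "Bearer sk-placeholder"}

payload = {
    "model": "example-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
    # New in this PR: ask the server for a trailing usage chunk.
    "stream_options": {"include_usage": True},
}

with requests.post(API_URL, json=payload, headers=HEADERS, stream=True) as resp:
    for line in resp.iter_lines():
        # Skip keep-alives and non-data lines; assumes "data: <json>" SSE events.
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        if chunk.get("usage"):
            # The trailing chunk added by this PR: empty choices, populated usage.
            print("\nusage:", chunk["usage"])
        else:
            for choice in chunk.get("choices", []):
                delta = choice.get("delta") or {}
                print(delta.get("content") or "", end="", flush=True)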


@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
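UsageStats is referenced here but not defined in this diff. Judging from how it is constructed further down (prompt_tokens, completion_tokens, total_tokens), a minimal stand-in would look roughly like this sketch:

from pydantic import BaseModel


class UsageStats(BaseModel):
    """Token accounting attached to the final stream chunk (illustrative sketch)."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0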


@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
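Because stream_options is a nested model rather than a bare flag, a body like {"stream": true, "stream_options": {"include_usage": true}} validates straight into the request. A self-contained sketch with a cut-down stand-in for CommonCompletionRequest (the real class inherits many sampler fields from BaseSamplerRequest):

from typing import Optional

from pydantic import BaseModel


class ChatCompletionStreamOptions(BaseModel):
    include_usage: Optional[bool] = False


class DemoCompletionRequest(BaseModel):
    # Stand-in for CommonCompletionRequest, reduced to the streaming fields.
    stream: Optional[bool] = False
    stream_options: Optional[ChatCompletionStreamOptions] = None


req = DemoCompletionRequest.model_validate(
    {"stream": True, "stream_options": {"include_usage": True}}
)
assert req.stream_options is not None and req.stream_options.include_usage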


@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
 
-    if "finish_reason" in generation:
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
+        logprob_response = None
+
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
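To illustrate what the trailing chunk serializes to, the usage branch can be exercised in isolation. The models below are local re-declarations for the sketch (the real ones live elsewhere in the project), and the token counts are invented:

from time import time
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageStats(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionStreamChunk(BaseModel):
    id: str
    choices: List[dict] = []
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
    usage: Optional[UsageStats] = None


# Final generation payload from the backend; key names match the diff above.
generation = {"prompt_tokens": 12, "generated_tokens": 34}

prompt_tokens = generation.get("prompt_tokens", 0)
completion_tokens = generation.get("generated_tokens", 0)

chunk = ChatCompletionStreamChunk(
    id="chatcmpl-example",
    model="example-model",
    usage=UsageStats(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    ),
)

# Empty "choices" plus a populated "usage" object is what the client sees
# right before the [DONE] sentinel.
print(chunk.model_dump_json())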
@@ -279,6 +299,14 @@ async def stream_generate_chat_completion(
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
+                    )
+
+                    yield usage_chunk.model_dump_json()
+
                 yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected