diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index fb39a9f..5f8f68d 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -605,6 +605,9 @@ class ExllamaV2Container:
         joined_generation["generation_tokens"] = unwrap(
             generations[-1].get("generated_tokens"), 0
         )
+        joined_generation["finish_reason"] = unwrap(
+            generations[-1].get("finish_reason"), "stop"
+        )
 
         return joined_generation
 
@@ -1004,6 +1007,10 @@ class ExllamaV2Container:
                 last_chunk_time = now
 
             if eos or generated_tokens == max_tokens:
+                finish_reason = "length" if generated_tokens == max_tokens else "stop"
+                generation = {"finish_reason": finish_reason}
+                yield generation
+
                 break
 
         # Print response
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 75987d9..7d3138e 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -24,7 +24,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     message: ChatCompletionMessage
     logprobs: Optional[ChatCompletionLogprobs] = None
 
@@ -32,7 +32,7 @@ class ChatCompletionStreamChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: Optional[str]
+    finish_reason: Optional[str] = None
     delta: Union[ChatCompletionMessage, dict] = {}
     logprobs: Optional[ChatCompletionLogprobs] = None
 
diff --git a/endpoints/OAI/types/completion.py b/endpoints/OAI/types/completion.py
index 84b7519..d0a7187 100644
--- a/endpoints/OAI/types/completion.py
+++ b/endpoints/OAI/types/completion.py
@@ -22,7 +22,7 @@ class CompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
-    finish_reason: str
+    finish_reason: Optional[str] = None
     logprobs: Optional[CompletionLogProbs] = None
     text: str
 
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index c88c258..cb25ab2 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -60,7 +60,9 @@ def _create_response(generation: dict, model_name: Optional[str]):
         logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
     choice = ChatCompletionRespChoice(
-        finish_reason="Generated", message=message, logprobs=logprob_response
+        finish_reason=generation.get("finish_reason"),
+        message=message,
+        logprobs=logprob_response,
     )
 
     prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
@@ -83,14 +85,15 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
-    finish_reason: Optional[str] = None,
 ):
     """Create a chat completion stream chunk from the provided text."""
 
     logprob_response = None
 
-    if finish_reason:
-        message = {}
+    if "finish_reason" in generation:
+        choice = ChatCompletionStreamChoice(
+            finish_reason=generation.get("finish_reason")
+        )
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
@@ -113,10 +116,10 @@ def _create_stream_chunk(
 
             logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
 
-    # The finish reason can be None
-    choice = ChatCompletionStreamChoice(
-        finish_reason=finish_reason, delta=message, logprobs=logprob_response
-    )
+        choice = ChatCompletionStreamChoice(
+            delta=message,
+            logprobs=logprob_response,
+        )
 
     chunk = ChatCompletionStreamChunk(
         id=const_id, choices=[choice], model=unwrap(model_name, "")
@@ -165,10 +168,14 @@ async def stream_generate_chat_completion(
 
             yield response.model_dump_json()
 
-        # Yield a finish response on successful generation
-        finish_response = _create_stream_chunk(const_id, finish_reason="stop")
+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                break
 
-        yield finish_response.model_dump_json()
+        # Yield a finish response on successful generation
+        # finish_response = _create_stream_chunk(const_id, finish_reason="stop")
+
+        # yield finish_response.model_dump_json()
 
     except CancelledError:
         # Get out if the request gets disconnected
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 8d56b09..c690493 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -39,7 +39,7 @@ def _create_response(generation: dict, model_name: Optional[str]):
         )
 
     choice = CompletionRespChoice(
-        finish_reason="Generated",
+        finish_reason=generation.get("finish_reason"),
        text=unwrap(generation.get("text"), ""),
         logprobs=logprob_response,
     )
@@ -69,11 +69,15 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
-
             yield response.model_dump_json()
 
+            # Break if the generation is finished
+            if "finish_reason" in generation:
+                yield "[DONE]"
+                break
+
         # Yield a finish response on successful generation
-        yield "[DONE]"
+        # yield "[DONE]"
 
     except CancelledError:
         # Get out if the request gets disconnected
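
Note (not part of the patch): a minimal sketch of the control flow this diff introduces, under the assumption that the backend now yields plain text chunks followed by a final dict containing only "finish_reason". The names fake_generator and consume below are hypothetical stand-ins for ExllamaV2Container's streaming generator and the endpoint helpers; they are illustrative only.

# Illustrative sketch only -- mirrors the behavior added in this diff, not actual tabbyAPI code.
import asyncio


async def fake_generator(max_tokens: int = 3):
    """Stand-in for the backend generator: yields text chunks, then a
    final dict carrying only the finish reason."""
    for i in range(max_tokens):
        yield {"text": f"token{i} "}
    # Mirrors the new backend behavior: "length" when the token budget
    # is exhausted, "stop" when an EOS token ends generation.
    yield {"finish_reason": "length"}


async def consume():
    async for generation in fake_generator():
        if "finish_reason" in generation:
            # The streaming endpoints now emit a terminal chunk carrying this
            # value (and "[DONE]" for completions) and break, instead of
            # appending a hardcoded "stop" after the loop.
            print("finished:", generation["finish_reason"])
            break
        print("delta:", generation["text"])


asyncio.run(consume())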