Adding token usage support

Author: Mehran Ziadloo
Date: 2023-11-27 20:05:05 -08:00
parent 44e7f7b0ee
commit ead503c75b
6 changed files with 34 additions and 23 deletions

View File

@@ -32,8 +32,6 @@ class ChatCompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion"
-
-    # TODO: Add usage stats
     usage: Optional[UsageStats] = None
 
 class ChatCompletionStreamChunk(BaseModel):

View File

@@ -8,8 +8,8 @@ class LogProbs(BaseModel):
     top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
 
 class UsageStats(BaseModel):
-    completion_tokens: int
     prompt_tokens: int
+    completion_tokens: int
     total_tokens: int
 
 class CommonCompletionRequest(BaseModel):
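
Not part of the commit, but for context: with the fields ordered as above, a populated UsageStats serializes into the usual OpenAI-style usage block. A minimal sketch, assuming Pydantic v1 (which the .json(ensure_ascii=False) call in main.py suggests):

from pydantic import BaseModel

class UsageStats(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

usage = UsageStats(prompt_tokens = 12, completion_tokens = 34, total_tokens = 46)
print(usage.json())
# {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}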

View File

@@ -22,6 +22,4 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "text_completion"
-
-    # TODO: Add usage stats
     usage: Optional[UsageStats] = None

View File

@@ -1,5 +1,5 @@
 import os, pathlib
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,
@@ -20,9 +20,7 @@ try:
 except ImportError:
     _fastchat_available = False
 
-def create_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     choice = CompletionRespChoice(
         finish_reason = "Generated",
         text = text
@@ -30,14 +28,15 @@ def create_completion_response(text: str, model_name: Optional[str]):
     response = CompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )
 
     return response
 
-def create_chat_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     message = ChatCompletionMessage(
         role = "assistant",
         content = text
@@ -50,7 +49,10 @@ def create_chat_completion_response(text: str, model_name: Optional[str]):
     response = ChatCompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )
 
     return response
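
Illustrative call site, not taken from the diff: after this change both helpers expect the caller to pass the two token counts, and the total is derived inside the helper. Assuming the helpers live in OAI/utils.py (the file name is not shown above), usage looks roughly like:

from OAI.utils import create_completion_response

# Token counts here are made-up numbers for illustration.
response = create_completion_response(
    "Hello there!",
    prompt_tokens = 12,
    completion_tokens = 3,
    model_name = "my-model"
)
assert response.usage.total_tokens == 15  # prompt_tokens + completion_tokens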

main.py (15 changes)
View File

@@ -179,14 +179,20 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(part, model_path.name)
+                response = create_completion_response(part,
+                                                      model_container.prompt_token_size,
+                                                      model_container.completion_token_size,
+                                                      model_path.name)
 
                 yield response.json(ensure_ascii=False)
 
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(data.prompt, **data.to_gen_params())
-        response = create_completion_response(response_text, model_path.name)
+        response = create_completion_response(response_text,
+                                              model_container.prompt_token_size,
+                                              model_container.completion_token_size,
+                                              model_path.name)
 
         return response
@@ -219,7 +225,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(prompt, **data.to_gen_params())
-        response = create_chat_completion_response(response_text, model_path.name)
+        response = create_chat_completion_response(response_text,
+                                                   model_container.prompt_token_size,
+                                                   model_container.completion_token_size,
+                                                   model_path.name)
 
         return response
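
For reference only (the endpoint path, port, and payload below are assumptions, not part of this commit): since both the streaming and non-streaming branches now pass the counters through, an OpenAI-compatible client should see a usage object on the response, roughly:

import requests

# Hypothetical request; adjust the URL and any auth headers to the actual deployment.
resp = requests.post(
    "http://localhost:5000/v1/completions",
    json = {"prompt": "Hello", "max_tokens": 16},
).json()

print(resp["usage"])
# e.g. {"prompt_tokens": 2, "completion_tokens": 16, "total_tokens": 18}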

View File

@@ -32,6 +32,8 @@ class ModelContainer:
     draft_enabled: bool = False
     gpu_split_auto: bool = True
     gpu_split: list or None = None
+    prompt_token_size: int = 0
+    completion_token_size: int = 0
 
     def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
         """
@@ -333,9 +335,11 @@ class ModelContainer:
             encode_special_tokens = True
         )
 
+        self.prompt_token_size = ids.shape[-1]
+
         # Begin
-        generated_tokens = 0
+        self.completion_token_size = 0
         full_response = ""
         start_time = time.time()
         last_chunk_time = start_time
@@ -369,7 +373,7 @@ class ModelContainer:
                 save_tokens = torch.cat((save_tokens, tokens), dim=-1)
                 chunk_buffer += chunk
-                generated_tokens += 1
+                self.completion_token_size += 1
                 chunk_tokens -= 1
 
                 # Yield output
@@ -377,21 +381,21 @@ class ModelContainer:
                 now = time.time()
                 elapsed = now - last_chunk_time
 
-                if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
+                if chunk_buffer != "" and (elapsed > stream_interval or eos or self.completion_token_size == max_tokens):
                     yield chunk_buffer
                     full_response += chunk_buffer
                     chunk_buffer = ""
                     last_chunk_time = now
 
-            if eos or generated_tokens == max_tokens: break
+            if eos or self.completion_token_size == max_tokens: break
 
         elapsed_time = last_chunk_time - start_time
 
-        initial_response = f"Response: {round(generated_tokens, 2)} tokens generated in {round(elapsed_time, 2)} seconds"
+        initial_response = f"Response: {round(self.completion_token_size)} tokens generated in {round(elapsed_time, 2)} seconds"
         extra_responses = []
 
         # Add tokens per second
-        extra_responses.append(f"{'Indeterminate' if elapsed_time == 0 else round(generated_tokens / elapsed_time, 2)} T/s")
+        extra_responses.append(f"{'Indeterminate' if elapsed_time == 0 else round(self.completion_token_size / elapsed_time, 2)} T/s")
 
         # Add context (original token count)
         if ids is not None:
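
A note on the counting approach (my reading of the diff, not text from the commit): prompt_token_size is taken from the shape of the encoded prompt ids, completion_token_size is bumped once per emitted token, and both live on ModelContainer so main.py can read them after generate()/generate_gen() produces output. The bookkeeping itself reduces to:

# Toy illustration of the same bookkeeping, independent of TabbyAPI code.
prompt_ids = [101, 2023, 2003, 1037, 3231]   # pretend token ids for the prompt
prompt_token_size = len(prompt_ids)          # prompt side of the usage stats

completion_token_size = 0
for _ in ["Hello", ",", " world", "!"]:      # pretend generated tokens
    completion_token_size += 1               # one increment per emitted token

total_tokens = prompt_token_size + completion_token_size
print(prompt_token_size, completion_token_size, total_tokens)  # 5 4 9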