Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-03-14 15:57:27 +00:00
Merge pull request #13 from ziadloo/main
Adding usage stat support (prompt_tokens, completion_tokens, and total_tokens)
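For context, the usage block this PR adds mirrors the token accounting in OpenAI-style completion responses. A minimal sketch of the target payload shape; only the "usage" keys come from this diff, the surrounding fields are abbreviated OpenAI-style boilerplate and the values are made up:

# Hypothetical example payload; only the "usage" keys come from this PR.
response = {
    "object": "chat.completion",
    "model": "my-model",          # placeholder model name
    "choices": [],                # omitted here
    "usage": {
        "prompt_tokens": 24,      # tokens in the input prompt
        "completion_tokens": 56,  # tokens generated by the model
        "total_tokens": 80,       # prompt_tokens + completion_tokens
    },
}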
OAI/types/chat_completion.py

@@ -32,8 +32,6 @@ class ChatCompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion"
-
-    # TODO: Add usage stats
-
+    usage: Optional[UsageStats] = None

 class ChatCompletionStreamChunk(BaseModel):
OAI/types/completion.py

@@ -8,8 +8,8 @@ class LogProbs(BaseModel):
     top_logprobs: List[Dict[str, float]] = Field(default_factory=list)

 class UsageStats(BaseModel):
-    completion_tokens: int
     prompt_tokens: int
+    completion_tokens: int
     total_tokens: int

 class CommonCompletionRequest(BaseModel):

@@ -22,6 +22,4 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "text_completion"
-
-    # TODO: Add usage stats
-
+    usage: Optional[UsageStats] = None
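Putting the type changes together, here is a self-contained sketch of the resulting models. Field names and defaults are taken from the diff's context lines; the imports and the omission of unrelated fields are assumptions:

from time import time
from typing import Optional

from pydantic import BaseModel, Field

class UsageStats(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class CompletionResponse(BaseModel):
    # Other fields omitted; these three appear as context lines in the diff.
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "text_completion"
    usage: Optional[UsageStats] = None  # replaces the old "TODO: Add usage stats"

Declaring usage as Optional with a None default keeps existing callers working: a response still validates even when no token counts are attached.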
OAI/utils.py (20 changed lines)
@@ -1,5 +1,5 @@
 import os, pathlib
-from OAI.types.completion import CompletionResponse, CompletionRespChoice
+from OAI.types.completion import CompletionResponse, CompletionRespChoice, UsageStats
 from OAI.types.chat_completion import (
     ChatCompletionMessage,
     ChatCompletionRespChoice,

@@ -20,9 +20,7 @@ try:
 except ImportError:
     _fastchat_available = False

-def create_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     choice = CompletionRespChoice(
         finish_reason = "Generated",
         text = text

@@ -30,14 +28,15 @@ def create_completion_response(text: str, model_name: Optional[str]):

     response = CompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )

     return response

-def create_chat_completion_response(text: str, model_name: Optional[str]):
-    # TODO: Add method to get token amounts in model for UsageStats
-
+def create_chat_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
     message = ChatCompletionMessage(
         role = "assistant",
         content = text

@@ -50,7 +49,10 @@ def create_chat_completion_response(text: str, model_name: Optional[str]):

     response = ChatCompletionResponse(
         choices = [choice],
-        model = model_name or ""
+        model = model_name or "",
+        usage = UsageStats(prompt_tokens = prompt_tokens,
+                           completion_tokens = completion_tokens,
+                           total_tokens = prompt_tokens + completion_tokens)
     )

     return response
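A quick illustration of how the reworked helper is meant to be called. The function and field names come from the diff; the token values and model name are made up, and in the real code the counts come from the model container as wired up in main.py below:

from OAI.utils import create_completion_response

# Hypothetical call; token counts would come from ModelContainer.generate().
response = create_completion_response(
    "Hello, world!",         # generated text
    prompt_tokens = 24,
    completion_tokens = 56,
    model_name = "my-model"  # placeholder
)
# total_tokens is derived inside the helper, not passed by the caller.
assert response.usage.total_tokens == 80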
main.py (23 changed lines)
@@ -188,11 +188,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
     async def generator():
         try:
             new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
+            for (part, prompt_tokens, completion_tokens) in new_generation:
                 if await request.is_disconnected():
                     break

-                response = create_completion_response(part, model_path.name)
+                response = create_completion_response(part,
+                                                      prompt_tokens,
+                                                      completion_tokens,
+                                                      model_path.name)

                 yield response.json(ensure_ascii=False)
         except Exception as e:

@@ -200,8 +203,11 @@ async def generate_completion(request: Request, data: CompletionRequest):

         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(data.prompt, **data.to_gen_params())
-        response = create_completion_response(response_text, model_path.name)
+        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
+        response = create_completion_response(response_text,
+                                              prompt_tokens,
+                                              completion_tokens,
+                                              model_path.name)

         return response

@@ -219,7 +225,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     const_id = f"chatcmpl-{uuid4().hex}"
     async def generator():
         try:
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+            new_generation, prompt_tokens, completion_tokens = model_container.generate_gen(prompt, **data.to_gen_params())
             for part in new_generation:
                 if await request.is_disconnected():
                     break

@@ -236,8 +242,11 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest

         return EventSourceResponse(generator())
     else:
-        response_text = model_container.generate(prompt, **data.to_gen_params())
-        response = create_chat_completion_response(response_text, model_path.name)
+        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
+        response = create_chat_completion_response(response_text,
+                                                   prompt_tokens,
+                                                   completion_tokens,
+                                                   model_path.name)

         return response
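To make the streaming contract concrete: generate_gen now yields (text, prompt_tokens, completion_tokens) tuples, with completion_tokens growing as chunks arrive, so every server-sent event can carry up-to-date counts. A minimal sketch of that consumption pattern, using a toy stand-in for model_container.generate_gen:

def fake_generate_gen(prompt):
    # Stand-in for model_container.generate_gen: yields
    # (text_chunk, prompt_tokens, completion_tokens_so_far) tuples.
    prompt_tokens = len(prompt.split())  # toy "tokenizer" for illustration
    completion_tokens = 0
    for chunk in ("Hello", ", ", "world"):
        completion_tokens += 1
        yield chunk, prompt_tokens, completion_tokens

for part, prompt_tokens, completion_tokens in fake_generate_gen("Say hi"):
    # main.py wraps each tuple in a CompletionResponse; here we just print.
    print(part, prompt_tokens, completion_tokens)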
model.py (12 changed lines)
@@ -226,9 +226,9 @@ class ModelContainer:


     def generate(self, prompt: str, **kwargs):
-        gen = self.generate_gen(prompt, **kwargs)
-        reponse = "".join(gen)
-        return reponse
+        gen = list(self.generate_gen(prompt, **kwargs))
+        reponse = "".join(map(lambda o: o[0], gen))
+        return reponse, gen[-1][1], gen[-1][2]

     def generate_gen(self, prompt: str, **kwargs):
         """

@@ -345,6 +345,8 @@ class ModelContainer:
                 "Generation is truncated and metrics may not be accurate."
             )

+        prompt_tokens = ids.shape[-1]
+
         # Begin

         generated_tokens = 0

@@ -390,7 +392,7 @@ class ModelContainer:
             elapsed = now - last_chunk_time

             if chunk_buffer != "" and (elapsed > stream_interval or eos or generated_tokens == max_tokens):
-                yield chunk_buffer
+                yield chunk_buffer, prompt_tokens, generated_tokens
                 full_response += chunk_buffer
                 chunk_buffer = ""
                 last_chunk_time = now

@@ -399,7 +401,7 @@ class ModelContainer:

         elapsed_time = last_chunk_time - start_time

-        initial_response = f"Response: {round(generated_tokens, 2)} tokens generated in {round(elapsed_time, 2)} seconds"
+        initial_response = f"Response: {generated_tokens} tokens generated in {round(elapsed_time, 2)} seconds"
         itemization = []
         extra_parts = []
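The non-streaming generate() leans on one property of generate_gen: prompt_tokens is fixed for the whole run and generated_tokens only grows, so the last yielded tuple holds the final counts. A toy sketch of that pattern, with made-up numbers in place of the real ModelContainer:

def toy_gen():
    # Mimics generate_gen's output: (chunk, prompt_tokens, completion_tokens)
    yield "Hel", 10, 3
    yield "lo", 10, 5

gen = list(toy_gen())                        # drain the generator
response = "".join(item[0] for item in gen)  # concatenate the text chunks
prompt_tokens, completion_tokens = gen[-1][1], gen[-1][2]
print(response, prompt_tokens, completion_tokens)  # Hello 10 5

One caveat of this approach: if the model yields no chunks at all, gen is empty and gen[-1] raises IndexError; the diff as written assumes at least one chunk is produced.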