mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
API: Split functions into their own files
Previously, generation function were bundled with the request function causing the overall code structure and API to look ugly and unreadable. Split these up and cleanup a lot of the methods that were previously overlooked in the API itself. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -74,3 +74,16 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||
async def load_model(model_path: pathlib.Path, **kwargs):
|
||||
async for _, _, _ in load_model_gen(model_path, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def load_loras(lora_dir, **kwargs):
|
||||
"""Wrapper to load loras."""
|
||||
if len(container.active_loras) > 0:
|
||||
unload_loras()
|
||||
|
||||
return container.load_loras(lora_dir, **kwargs)
|
||||
|
||||
|
||||
def unload_loras():
|
||||
"""Wrapper to unload loras"""
|
||||
container.unload(loras_only=True)
|
||||
|
||||
@@ -1,30 +1,24 @@
|
||||
import pathlib
|
||||
from sse_starlette import EventSourceResponse
|
||||
import uvicorn
|
||||
from asyncio import CancelledError
|
||||
from uuid import uuid4
|
||||
from jinja2 import TemplateError
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from functools import partial
|
||||
from loguru import logger
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from common import config, model, gen_logging, sampling
|
||||
from common.auth import check_admin_key, check_api_key
|
||||
from common.generators import (
|
||||
call_with_semaphore,
|
||||
generate_with_semaphore,
|
||||
release_semaphore,
|
||||
)
|
||||
from common.logger import UVICORN_LOG_CONFIG
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.utils import (
|
||||
get_generator_error,
|
||||
handle_request_error,
|
||||
unwrap,
|
||||
)
|
||||
@@ -39,7 +33,6 @@ from endpoints.OAI.types.lora import (
|
||||
from endpoints.OAI.types.model import (
|
||||
ModelCard,
|
||||
ModelLoadRequest,
|
||||
ModelLoadResponse,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from endpoints.OAI.types.sampler_overrides import SamplerOverrideSwitchRequest
|
||||
@@ -50,12 +43,16 @@ from endpoints.OAI.types.token import (
|
||||
TokenDecodeRequest,
|
||||
TokenDecodeResponse,
|
||||
)
|
||||
from endpoints.OAI.utils.completion import (
|
||||
create_completion_response,
|
||||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk,
|
||||
from endpoints.OAI.utils.chat_completion import (
|
||||
format_prompt_with_template,
|
||||
generate_chat_completion,
|
||||
stream_generate_chat_completion,
|
||||
)
|
||||
from endpoints.OAI.utils.model import get_model_list
|
||||
from endpoints.OAI.utils.completion import (
|
||||
generate_completion,
|
||||
stream_generate_completion,
|
||||
)
|
||||
from endpoints.OAI.utils.model import get_model_list, stream_model_load
|
||||
from endpoints.OAI.utils.lora import get_lora_list
|
||||
|
||||
app = FastAPI(
|
||||
@@ -169,73 +166,34 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
||||
model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
|
||||
model_path = model_path / data.name
|
||||
|
||||
load_data = data.model_dump()
|
||||
|
||||
draft_model_path = None
|
||||
if data.draft:
|
||||
if not data.draft.draft_model_name:
|
||||
raise HTTPException(
|
||||
400, "draft_model_name was not found inside the draft object."
|
||||
)
|
||||
|
||||
load_data["draft"]["draft_model_dir"] = unwrap(
|
||||
draft_model_path = unwrap(
|
||||
config.draft_model_config().get("draft_model_dir"), "models"
|
||||
)
|
||||
|
||||
if not model_path.exists():
|
||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||
|
||||
async def generator():
|
||||
"""Request generation wrapper for the loading process."""
|
||||
load_callback = partial(
|
||||
stream_model_load, request, data, model_path, draft_model_path
|
||||
)
|
||||
|
||||
load_status = model.load_model_gen(model_path, **load_data)
|
||||
try:
|
||||
async for module, modules, model_type in load_status:
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error(
|
||||
"Model load cancelled by user. "
|
||||
"Please make sure to run unload to free up resources."
|
||||
)
|
||||
return
|
||||
|
||||
if module != 0:
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
if module == modules:
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="finished",
|
||||
)
|
||||
|
||||
yield response.model_dump_json()
|
||||
except CancelledError:
|
||||
logger.error(
|
||||
"Model load cancelled by user. "
|
||||
"Please make sure to run unload to free up resources."
|
||||
)
|
||||
except Exception as exc:
|
||||
yield get_generator_error(str(exc))
|
||||
|
||||
# Determine whether to use or skip the queue
|
||||
# Wrap in a semaphore if the queue isn't being skipped
|
||||
if data.skip_queue:
|
||||
logger.warning(
|
||||
"Model load request is skipping the completions queue. "
|
||||
"Unexpected results may occur."
|
||||
)
|
||||
generator_callback = generator
|
||||
else:
|
||||
generator_callback = partial(generate_with_semaphore, generator)
|
||||
load_callback = partial(generate_with_semaphore, load_callback)
|
||||
|
||||
return EventSourceResponse(generator_callback())
|
||||
return EventSourceResponse(load_callback())
|
||||
|
||||
|
||||
# Unload model endpoint
|
||||
@@ -363,6 +321,7 @@ async def get_active_loras():
|
||||
)
|
||||
async def load_lora(data: LoraLoadRequest):
|
||||
"""Loads a LoRA into the model container."""
|
||||
|
||||
if not data.loras:
|
||||
raise HTTPException(400, "List of loras to load is not found.")
|
||||
|
||||
@@ -373,28 +332,25 @@ async def load_lora(data: LoraLoadRequest):
|
||||
"A parent lora directory does not exist. Check your config.yml?",
|
||||
)
|
||||
|
||||
# Clean-up existing loras if present
|
||||
def load_loras_internal():
|
||||
if len(model.container.active_loras) > 0:
|
||||
unload_loras()
|
||||
load_callback = partial(
|
||||
run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
|
||||
)
|
||||
|
||||
result = model.container.load_loras(lora_dir, **data.model_dump())
|
||||
return LoraLoadResponse(
|
||||
success=unwrap(result.get("success"), []),
|
||||
failure=unwrap(result.get("failure"), []),
|
||||
)
|
||||
|
||||
internal_callback = partial(run_in_threadpool, load_loras_internal)
|
||||
|
||||
# Determine whether to skip the queue
|
||||
# Wrap in a semaphore if the queue isn't being skipped
|
||||
if data.skip_queue:
|
||||
logger.warning(
|
||||
"Lora load request is skipping the completions queue. "
|
||||
"Unexpected results may occur."
|
||||
)
|
||||
return await internal_callback()
|
||||
else:
|
||||
return await call_with_semaphore(internal_callback)
|
||||
load_callback = partial(call_with_semaphore, load_callback)
|
||||
|
||||
load_result = await load_callback()
|
||||
|
||||
return LoraLoadResponse(
|
||||
success=unwrap(load_result.get("success"), []),
|
||||
failure=unwrap(load_result.get("failure"), []),
|
||||
)
|
||||
|
||||
|
||||
# Unload lora endpoint
|
||||
@@ -404,7 +360,8 @@ async def load_lora(data: LoraLoadRequest):
|
||||
)
|
||||
async def unload_loras():
|
||||
"""Unloads the currently loaded loras."""
|
||||
model.container.unload(loras_only=True)
|
||||
|
||||
model.unload_loras()
|
||||
|
||||
|
||||
# Encode tokens endpoint
|
||||
@@ -439,7 +396,7 @@ async def decode_tokens(data: TokenDecodeRequest):
|
||||
"/v1/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def generate_completion(request: Request, data: CompletionRequest):
|
||||
async def completion_request(request: Request, data: CompletionRequest):
|
||||
"""Generates a completion from a prompt."""
|
||||
model_path = model.container.get_model_path()
|
||||
|
||||
@@ -451,52 +408,17 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
|
||||
async def generator():
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
data.prompt, **data.to_gen_params()
|
||||
)
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
response = create_completion_response(generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
yield "[DONE]"
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
return EventSourceResponse(generate_with_semaphore(generator))
|
||||
|
||||
try:
|
||||
generation = await call_with_semaphore(
|
||||
partial(
|
||||
run_in_threadpool,
|
||||
model.container.generate,
|
||||
data.prompt,
|
||||
**data.to_gen_params(),
|
||||
)
|
||||
generator_callback = partial(
|
||||
stream_generate_completion, request, data, model_path
|
||||
)
|
||||
|
||||
response = create_completion_response(generation, model_path.name)
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
return EventSourceResponse(generate_with_semaphore(generator_callback))
|
||||
else:
|
||||
response = await call_with_semaphore(
|
||||
partial(generate_completion, data, model_path)
|
||||
)
|
||||
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
return response
|
||||
|
||||
|
||||
# Chat completions endpoint
|
||||
@@ -504,7 +426,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
||||
"/v1/chat/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
||||
async def chat_completion_request(request: Request, data: ChatCompletionRequest):
|
||||
"""Generates a chat completion from a prompt."""
|
||||
|
||||
if model.container.prompt_template is None:
|
||||
@@ -518,90 +440,24 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
|
||||
if isinstance(data.messages, str):
|
||||
prompt = data.messages
|
||||
else:
|
||||
try:
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
unwrap(data.add_bos_token, True),
|
||||
unwrap(data.ban_eos_token, False),
|
||||
)
|
||||
|
||||
prompt = get_prompt_from_template(
|
||||
data.messages,
|
||||
model.container.prompt_template,
|
||||
data.add_generation_prompt,
|
||||
special_tokens_dict,
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Could not find a Conversation from prompt template "
|
||||
f"'{model.container.prompt_template.name}'. "
|
||||
"Check your spelling?",
|
||||
) from exc
|
||||
except TemplateError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"TemplateError: {str(exc)}",
|
||||
) from exc
|
||||
prompt = format_prompt_with_template(data)
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
)
|
||||
|
||||
if data.stream and not disable_request_streaming:
|
||||
const_id = f"chatcmpl-{uuid4().hex}"
|
||||
|
||||
async def generator():
|
||||
"""Generator for the generation process."""
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, **data.to_gen_params()
|
||||
)
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Chat completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
response = create_chat_completion_stream_chunk(
|
||||
const_id, generation, model_path.name
|
||||
)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
finish_response = create_chat_completion_stream_chunk(
|
||||
const_id, finish_reason="stop"
|
||||
)
|
||||
|
||||
yield finish_response.model_dump_json()
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Chat completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
return EventSourceResponse(generate_with_semaphore(generator))
|
||||
|
||||
try:
|
||||
generation = await call_with_semaphore(
|
||||
partial(
|
||||
run_in_threadpool,
|
||||
model.container.generate,
|
||||
prompt,
|
||||
**data.to_gen_params(),
|
||||
)
|
||||
generator_callback = partial(
|
||||
stream_generate_chat_completion, prompt, request, data, model_path
|
||||
)
|
||||
|
||||
return EventSourceResponse(generate_with_semaphore(generator_callback))
|
||||
else:
|
||||
response = await call_with_semaphore(
|
||||
partial(generate_chat_completion, prompt, request, data, model_path)
|
||||
)
|
||||
response = create_chat_completion_response(generation, model_path.name)
|
||||
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Chat completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
|
||||
|
||||
def start_api(host: str, port: int):
|
||||
|
||||
199
endpoints/OAI/utils/chat_completion.py
Normal file
199
endpoints/OAI/utils/chat_completion.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Chat completion utilities for OAI server."""
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from jinja2 import TemplateError
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common.generators import release_semaphore
|
||||
from common.templating import get_prompt_from_template
|
||||
from common.utils import get_generator_error, handle_request_error, unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
ChatCompletionLogprob,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionRespChoice,
|
||||
ChatCompletionStreamChunk,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamChoice,
|
||||
)
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def _create_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a chat completion response from the provided text."""
|
||||
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
|
||||
collected_token_probs = []
|
||||
for index, token in enumerate(token_probs.keys()):
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs[index].items()
|
||||
]
|
||||
|
||||
collected_token_probs.append(
|
||||
ChatCompletionLogprob(
|
||||
token=token,
|
||||
logprob=token_probs[token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||
|
||||
choice = ChatCompletionRespChoice(
|
||||
finish_reason="Generated", message=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generation.get("completion_tokens"), 0)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
choices=[choice],
|
||||
model=unwrap(model_name, ""),
|
||||
usage=UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _create_stream_chunk(
|
||||
const_id: str,
|
||||
generation: Optional[dict] = None,
|
||||
model_name: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
):
|
||||
"""Create a chat completion stream chunk from the provided text."""
|
||||
|
||||
logprob_response = None
|
||||
|
||||
if finish_reason:
|
||||
message = {}
|
||||
else:
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), {})
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs.items()
|
||||
]
|
||||
|
||||
generated_token = next(iter(token_probs))
|
||||
token_prob_response = ChatCompletionLogprob(
|
||||
token=generated_token,
|
||||
logprob=token_probs[generated_token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
|
||||
|
||||
# The finish reason can be None
|
||||
choice = ChatCompletionStreamChoice(
|
||||
finish_reason=finish_reason, delta=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
chunk = ChatCompletionStreamChunk(
|
||||
id=const_id, choices=[choice], model=unwrap(model_name, "")
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
||||
|
||||
def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
try:
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
unwrap(data.add_bos_token, True),
|
||||
unwrap(data.ban_eos_token, False),
|
||||
)
|
||||
|
||||
return get_prompt_from_template(
|
||||
data.messages,
|
||||
model.container.prompt_template,
|
||||
data.add_generation_prompt,
|
||||
special_tokens_dict,
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Could not find a Conversation from prompt template "
|
||||
f"'{model.container.prompt_template.name}'. "
|
||||
"Check your spelling?",
|
||||
) from exc
|
||||
except TemplateError as exc:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"TemplateError: {str(exc)}",
|
||||
) from exc
|
||||
|
||||
|
||||
async def stream_generate_chat_completion(
|
||||
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
"""Generator for the generation process."""
|
||||
try:
|
||||
const_id = f"chatcmpl-{uuid4().hex}"
|
||||
|
||||
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Chat completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
response = _create_stream_chunk(const_id, generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
finish_response = _create_stream_chunk(const_id, finish_reason="stop")
|
||||
|
||||
yield finish_response.model_dump_json()
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Chat completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
|
||||
async def generate_chat_completion(
|
||||
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
try:
|
||||
generation = await run_in_threadpool(
|
||||
model.container.generate,
|
||||
prompt,
|
||||
**data.to_gen_params(),
|
||||
)
|
||||
response = _create_response(generation, model_path.name)
|
||||
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Chat completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
@@ -1,17 +1,15 @@
|
||||
""" Utility functions for the OpenAI server. """
|
||||
"""Completion utilities for OAI server."""
|
||||
import pathlib
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
ChatCompletionLogprob,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRespChoice,
|
||||
ChatCompletionStreamChunk,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamChoice,
|
||||
)
|
||||
from common import model
|
||||
from common.generators import release_semaphore
|
||||
from common.utils import get_generator_error, handle_request_error, unwrap
|
||||
from endpoints.OAI.types.completion import (
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionRespChoice,
|
||||
CompletionLogProbs,
|
||||
@@ -19,7 +17,7 @@ from endpoints.OAI.types.completion import (
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def create_completion_response(generation: dict, model_name: Optional[str]):
|
||||
def _create_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a completion response from the provided text."""
|
||||
|
||||
logprob_response = None
|
||||
@@ -58,97 +56,49 @@ def create_completion_response(generation: dict, model_name: Optional[str]):
|
||||
return response
|
||||
|
||||
|
||||
def create_chat_completion_response(generation: dict, model_name: Optional[str]):
|
||||
"""Create a chat completion response from the provided text."""
|
||||
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), [])
|
||||
|
||||
collected_token_probs = []
|
||||
for index, token in enumerate(token_probs.keys()):
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs[index].items()
|
||||
]
|
||||
|
||||
collected_token_probs.append(
|
||||
ChatCompletionLogprob(
|
||||
token=token,
|
||||
logprob=token_probs[token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||
|
||||
choice = ChatCompletionRespChoice(
|
||||
finish_reason="Generated", message=message, logprobs=logprob_response
|
||||
)
|
||||
|
||||
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
|
||||
completion_tokens = unwrap(generation.get("completion_tokens"), 0)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
choices=[choice],
|
||||
model=unwrap(model_name, ""),
|
||||
usage=UsageStats(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_chat_completion_stream_chunk(
|
||||
const_id: str,
|
||||
generation: Optional[dict] = None,
|
||||
model_name: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
async def stream_generate_completion(
|
||||
request: Request, data: CompletionRequest, model_path: pathlib.Path
|
||||
):
|
||||
"""Create a chat completion stream chunk from the provided text."""
|
||||
"""Streaming generation for completions."""
|
||||
|
||||
logprob_response = None
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
data.prompt, **data.to_gen_params()
|
||||
)
|
||||
for generation in new_generation:
|
||||
# Get out if the request gets disconnected
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error("Completion generation cancelled by user.")
|
||||
return
|
||||
|
||||
if finish_reason:
|
||||
message = {}
|
||||
else:
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
response = _create_response(generation, model_path.name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Yield a finish response on successful generation
|
||||
yield "[DONE]"
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Completion aborted. Please check the server console."
|
||||
)
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
if token_probs:
|
||||
logprobs = unwrap(generation.get("logprobs"), {})
|
||||
top_logprobs = [
|
||||
ChatCompletionLogprob(token=token, logprob=logprob)
|
||||
for token, logprob in logprobs.items()
|
||||
]
|
||||
|
||||
generated_token = next(iter(token_probs))
|
||||
token_prob_response = ChatCompletionLogprob(
|
||||
token=generated_token,
|
||||
logprob=token_probs[generated_token],
|
||||
top_logprobs=top_logprobs,
|
||||
)
|
||||
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
|
||||
"""Non-streaming generate for completions"""
|
||||
|
||||
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
|
||||
try:
|
||||
generation = await run_in_threadpool(
|
||||
model.container.generate, data.prompt, **data.to_gen_params()
|
||||
)
|
||||
|
||||
# The finish reason can be None
|
||||
choice = ChatCompletionStreamChoice(
|
||||
finish_reason=finish_reason, delta=message, logprobs=logprob_response
|
||||
)
|
||||
response = _create_response(generation, model_path.name)
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Completion aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
chunk = ChatCompletionStreamChunk(
|
||||
id=const_id, choices=[choice], model=unwrap(model_name, "")
|
||||
)
|
||||
|
||||
return chunk
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import Request
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from endpoints.OAI.types.model import ModelCard, ModelList
|
||||
from common import model
|
||||
from common.generators import release_semaphore
|
||||
from common.utils import get_generator_error
|
||||
|
||||
from endpoints.OAI.types.model import (
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelLoadRequest,
|
||||
ModelLoadResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
|
||||
@@ -20,3 +32,55 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
|
||||
model_card_list.data.append(model_card) # pylint: disable=no-member
|
||||
|
||||
return model_card_list
|
||||
|
||||
|
||||
async def stream_model_load(
|
||||
request: Request,
|
||||
data: ModelLoadRequest,
|
||||
model_path: pathlib.Path,
|
||||
draft_model_path: str,
|
||||
):
|
||||
"""Request generation wrapper for the loading process."""
|
||||
|
||||
# Set the draft model path if it exists
|
||||
load_data = data.model_dump()
|
||||
if draft_model_path:
|
||||
load_data["draft"]["draft_model_dir"] = draft_model_path
|
||||
|
||||
load_status = model.load_model_gen(model_path, **load_data)
|
||||
try:
|
||||
async for module, modules, model_type in load_status:
|
||||
if await request.is_disconnected():
|
||||
release_semaphore()
|
||||
logger.error(
|
||||
"Model load cancelled by user. "
|
||||
"Please make sure to run unload to free up resources."
|
||||
)
|
||||
return
|
||||
|
||||
if module != 0:
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="processing",
|
||||
)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
if module == modules:
|
||||
response = ModelLoadResponse(
|
||||
model_type=model_type,
|
||||
module=module,
|
||||
modules=modules,
|
||||
status="finished",
|
||||
)
|
||||
|
||||
yield response.model_dump_json()
|
||||
except CancelledError:
|
||||
logger.error(
|
||||
"Model load cancelled by user. "
|
||||
"Please make sure to run unload to free up resources."
|
||||
)
|
||||
except Exception as exc:
|
||||
yield get_generator_error(str(exc))
|
||||
|
||||
Reference in New Issue
Block a user