Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-28 18:21:42 +00:00)
Commit: Switch to async generators
Async generation removes many of the roadblocks that come with managing tasks via threads, and it opens the door to abortable requests and other modern asyncio patterns. NOTE: ExLlamaV2 itself is not an asynchronous library; its generators are simply driven from within tabby's async event loop to allow for a fast and concurrent API server. Whether to run stream_ex in a separate thread or to manually cooperate with the loop via asyncio.sleep(0) is still being debated.

Signed-off-by: kingbri <bdashore3@proton.me>
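To illustrate the idea (this sketch is not code from the repository, and slow_sync_gen is a hypothetical stand-in for an ExLlamaV2 generator such as stream_ex or load_autosplit_gen): a blocking generator can be driven from an async generator that hands control back to the event loop on every iteration with await asyncio.sleep(0).

# Illustrative sketch only: wrap a blocking generator so it cooperates with the
# asyncio event loop. "slow_sync_gen" is a hypothetical stand-in for an
# ExLlamaV2 generator such as stream_ex or load_autosplit_gen.
import asyncio
import time


def slow_sync_gen(n: int):
    """A synchronous generator that does blocking work per step."""
    for i in range(n):
        time.sleep(0.01)  # stand-in for a forward pass / module load
        yield i


async def async_wrapper(n: int):
    """Drive the sync generator while letting other tasks run between steps."""
    for value in slow_sync_gen(n):
        # Manually suspend the task so other coroutines get scheduled
        await asyncio.sleep(0)
        yield value


async def main():
    async for value in async_wrapper(3):
        print(value)


if __name__ == "__main__":
    asyncio.run(main())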
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""

+import asyncio
 import gc
 from itertools import zip_longest
 import pathlib
@@ -325,7 +326,7 @@ class ExllamaV2Container:

         return model_params

-    def load(self, progress_callback=None):
+    async def load(self, progress_callback=None):
         """
         Load model

@@ -338,7 +339,7 @@ class ExllamaV2Container:
         for _ in self.load_gen(progress_callback):
             pass

-    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
         Load loras
         """
@@ -361,7 +362,7 @@ class ExllamaV2Container:

             logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
-            # FIXME(alpin): Does self.model need to be passed here?
+
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
@@ -371,7 +372,7 @@ class ExllamaV2Container:
         # Return success and failure names
         return {"success": success, "failure": failure}

-    def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None):
         """
         Load model, generator function

@@ -400,12 +401,16 @@ class ExllamaV2Container:
             logger.info("Loading draft model: " + self.draft_config.model_dir)

             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
-            yield from self.draft_model.load_autosplit_gen(
+            for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
                 last_id_only=True,
                 callback_gen=progress_callback,
-            )
+            ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
+                if value:
+                    yield value

         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
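The same suspend-and-yield pattern is repeated for each loader loop below. As a hedged aside, it could be factored into a generic helper along the lines of the sketch that follows; the commit itself does not add such a helper.

# Hypothetical helper (not part of this commit): turn any blocking generator
# into an async generator that yields to the event loop between items.
import asyncio
from typing import AsyncIterator, Iterable, TypeVar

T = TypeVar("T")


async def iterate_cooperatively(iterable: Iterable[T]) -> AsyncIterator[T]:
    """Iterate a synchronous iterable while letting other tasks run."""
    for item in iterable:
        # Give other tasks a chance to run before handing back each item
        await asyncio.sleep(0)
        yield item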
@@ -424,6 +429,8 @@ class ExllamaV2Container:
             self.gpu_split,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value

@@ -452,6 +459,8 @@ class ExllamaV2Container:
             last_id_only=True,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value

@@ -565,9 +574,11 @@ class ExllamaV2Container:

         return dict(zip_longest(top_tokens, cleaned_values))

-    def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
-        generations = list(self.generate_gen(prompt, **kwargs))
+        generations = []
+        async for generation in self.generate_gen(prompt, **kwargs):
+            generations.append(generation)

         joined_generation = {
             "text": "",
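For context on the generate() change above: the built-in list() cannot consume an async generator, so the chunks have to be collected with async for (or an asynchronous comprehension). A minimal, standalone illustration:

# list() cannot consume an async generator; use async for or an async
# comprehension to collect its output (illustrative example only).
import asyncio


async def counter(n: int):
    for i in range(n):
        await asyncio.sleep(0)
        yield i


async def main():
    values = [item async for item in counter(3)]
    print(values)  # [0, 1, 2]


asyncio.run(main())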
@@ -615,8 +626,7 @@ class ExllamaV2Container:

         return kwargs

-    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-    def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion

@@ -889,6 +899,9 @@ class ExllamaV2Container:
             chunk_tokens = 0

             while True:
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
+
                 # Ingest prompt
                 if chunk_tokens == 0:
                     ids = torch.cat((ids, save_tokens), dim=-1)
@@ -1,4 +1,4 @@
-"""Generator handling"""
+"""Concurrency handling"""

 import asyncio
 import inspect
@@ -52,7 +52,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     progress.start()

     try:
-        for module, modules in load_status:
+        async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
                     f"[cyan]Loading {model_type} modules", total=modules
@@ -76,12 +76,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
         pass


-def load_loras(lora_dir, **kwargs):
+async def load_loras(lora_dir, **kwargs):
     """Wrapper to load loras."""
     if len(container.active_loras) > 0:
         unload_loras()

-    return container.load_loras(lora_dir, **kwargs)
+    return await container.load_loras(lora_dir, **kwargs)


 def unload_loras():
@@ -6,6 +6,8 @@ from loguru import logger
 from pydantic import BaseModel
 from typing import Optional

+from common.concurrency import release_semaphore
+

 def load_progress(module, modules):
     """Wrapper callback for load progress."""
@@ -51,6 +53,13 @@ def handle_request_error(message: str, exc_info: bool = True):
     return request_error


+def handle_request_disconnect(message: str):
+    """Wrapper for handling for request disconnection."""
+
+    release_semaphore()
+    logger.error(message)
+
+
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:
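handle_request_disconnect releases the generation semaphore before logging, so the next queued request can proceed. The semaphore helpers themselves (call_with_semaphore, generate_with_semaphore, release_semaphore) live in common/concurrency; the sketch below shows an assumed shape for such wrappers and is not copied from the repository.

# Assumed sketch of semaphore-gated helpers, similar in spirit to
# call_with_semaphore / generate_with_semaphore / release_semaphore.
# Names and behavior here are illustrative, not the actual implementation.
import asyncio

generate_semaphore = asyncio.Semaphore(1)


async def call_with_semaphore(callback):
    """Run an awaitable callback while holding the generation semaphore."""
    async with generate_semaphore:
        return await callback()


async def generate_with_semaphore(generator):
    """Run an async generator callback while holding the semaphore."""
    try:
        await generate_semaphore.acquire()
        async for result in generator():
            yield result
    finally:
        # Avoid a double release if a disconnect handler already freed it
        if generate_semaphore.locked():
            generate_semaphore.release()


def release_semaphore():
    """Force-release the semaphore, e.g. after a client disconnect."""
    if generate_semaphore.locked():
        generate_semaphore.release()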
@@ -1,7 +1,6 @@
 import pathlib
 import uvicorn
 from fastapi import FastAPI, Depends, HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
 from loguru import logger
@@ -10,7 +9,7 @@ from sys import maxsize

 from common import config, model, gen_logging, sampling
 from common.auth import check_admin_key, check_api_key
-from common.generators import (
+from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )
@@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    load_callback = partial(
-        stream_model_load, request, data, model_path, draft_model_path
-    )
+    load_callback = partial(stream_model_load, data, model_path, draft_model_path)

     # Wrap in a semaphore if the queue isn't being skipped
     if data.skip_queue:
@@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest):
             "A parent lora directory does not exist. Check your config.yml?",
         )

-    load_callback = partial(
-        run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
-    )
+    load_callback = partial(model.load_loras, lora_dir, **data.model_dump())

     # Wrap in a semaphore if the queue isn't being skipped
     if data.skip_queue:
@@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest):
         )

     if data.stream and not disable_request_streaming:
-        generator_callback = partial(
-            stream_generate_completion, request, data, model_path
-        )
+        generator_callback = partial(stream_generate_completion, data, model_path)

         return EventSourceResponse(
             generate_with_semaphore(generator_callback),
@@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)

     if data.stream and not disable_request_streaming:
         generator_callback = partial(
-            stream_generate_chat_completion, prompt, request, data, model_path
+            stream_generate_chat_completion, prompt, data, model_path
         )

         return EventSourceResponse(
@@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         )
     else:
         response = await call_with_semaphore(
-            partial(generate_chat_completion, prompt, request, data, model_path)
+            partial(generate_chat_completion, prompt, data, model_path)
         )

         return response


-def start_api(host: str, port: int):
+async def start_api(host: str, port: int):
     """Isolated function to start the API server"""

     # TODO: Move OAI API to a separate folder
@@ -475,9 +468,12 @@ def start_api(host: str, port: int):
     logger.info(f"Completions: http://{host}:{port}/v1/completions")
     logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

-    uvicorn.run(
+    config = uvicorn.Config(
         app,
         host=host,
         port=port,
         log_config=UVICORN_LOG_CONFIG,
     )
+    server = uvicorn.Server(config)
+
+    await server.serve()
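For background on the start_api change above: uvicorn.run() starts its own event loop and cannot be called from a coroutine that is already running inside one, while building a uvicorn.Config and awaiting Server.serve() lets the API share the caller's loop. A minimal standalone sketch (the app, host, and port here are placeholders):

# Minimal sketch of running uvicorn inside an existing asyncio event loop.
# The FastAPI app and host/port values are placeholders for illustration.
import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


async def start_api(host: str = "127.0.0.1", port: int = 8000):
    config = uvicorn.Config(app, host=host, port=port)
    server = uvicorn.Server(config)

    # Runs until the server is shut down, cooperating with other tasks
    await server.serve()


if __name__ == "__main__":
    asyncio.run(start_api())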
@@ -1,18 +1,21 @@
 """Chat completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
 from typing import Optional
 from uuid import uuid4

-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
+from fastapi import HTTPException
 from jinja2 import TemplateError
-from loguru import logger

 from common import model
-from common.generators import release_semaphore
 from common.templating import get_prompt_from_template
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,
@@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):


 async def stream_generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"

         new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Chat completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)

             yield response.model_dump_json()
@@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
         finish_response = _create_stream_chunk(const_id, finish_reason="stop")

         yield finish_response.model_dump_json()
+    except CancelledError:
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Chat completion aborted. Please check the server console."
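Instead of polling request.is_disconnected() on every chunk, the streaming generators now rely on the server cancelling the request task when the client drops, which surfaces inside the async generator as asyncio.CancelledError. A standalone sketch of that mechanism (names and timings are illustrative, not taken from the repository):

# Illustrative sketch: cancelling a task raises CancelledError inside the
# coroutine that is currently awaiting, giving it a chance to clean up
# (release a semaphore, log the disconnect, etc.).
import asyncio


async def stream_chunks():
    try:
        while True:
            await asyncio.sleep(0.1)  # stand-in for waiting on the next token
            print("chunk")
    except asyncio.CancelledError:
        print("client disconnected, cleaning up")
        raise  # re-raise so the task is still marked as cancelled


async def main():
    task = asyncio.create_task(stream_chunks())
    await asyncio.sleep(0.35)
    task.cancel()  # simulates the client dropping the connection
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())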
@@ -179,11 +180,10 @@ async def stream_generate_chat_completion(


 async def generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     try:
-        generation = await run_in_threadpool(
-            model.container.generate,
+        generation = await model.container.generate(
             prompt,
             **data.to_gen_params(),
         )
@@ -1,14 +1,17 @@
 """Completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
-from loguru import logger
+from fastapi import HTTPException
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.completion import (
     CompletionRequest,
     CompletionResponse,
@@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response


-async def stream_generate_completion(
-    request: Request, data: CompletionRequest, model_path: pathlib.Path
-):
+async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Streaming generation for completions."""

     try:
         new_generation = model.container.generate_gen(
             data.prompt, **data.to_gen_params()
         )
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_response(generation, model_path.name)

             yield response.model_dump_json()

         # Yield a finish response on successful generation
         yield "[DONE]"
+    except CancelledError:
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Completion aborted. Please check the server console."
@@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     """Non-streaming generate for completions"""

     try:
-        generation = await run_in_threadpool(
-            model.container.generate, data.prompt, **data.to_gen_params()
-        )
+        generation = await model.container.generate(data.prompt, **data.to_gen_params())

         response = _create_response(generation, model_path.name)
         return response
@@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path):
     for path in lora_path.iterdir():
         if path.is_dir():
             lora_card = LoraCard(id=path.name)
-            lora_list.data.append(lora_card)  # pylint: disable=no-member
+            lora_list.data.append(lora_card)

     return lora_list
@@ -1,12 +1,9 @@
 import pathlib
 from asyncio import CancelledError
-from fastapi import Request
-from loguru import logger
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error
+from common.utils import get_generator_error, handle_request_disconnect

 from endpoints.OAI.types.model import (
     ModelCard,
@@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N


 async def stream_model_load(
-    request: Request,
     data: ModelLoadRequest,
     model_path: pathlib.Path,
     draft_model_path: str,
@@ -50,14 +46,6 @@ async def stream_model_load(
     load_status = model.load_model_gen(model_path, **load_data)
     try:
         async for module, modules, model_type in load_status:
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error(
-                    "Model load cancelled by user. "
-                    "Please make sure to run unload to free up resources."
-                )
-                return
-
             if module != 0:
                 response = ModelLoadResponse(
                     model_type=model_type,
@@ -78,7 +66,9 @@ async def stream_model_load(

             yield response.model_dump_json()
     except CancelledError:
-        logger.error(
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect(
             "Model load cancelled by user. "
             "Please make sure to run unload to free up resources."
         )
main.py
@@ -5,9 +5,6 @@ import os
 import pathlib
 import signal
 import sys
-import threading
-import time
-from functools import partial
 from loguru import logger
 from typing import Optional

@@ -121,13 +118,7 @@ async def entrypoint(args: Optional[dict] = None):
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         model.container.load_loras(lora_dir.resolve(), **lora_config)

-    # TODO: Replace this with abortables, async via producer consumer, or something else
-    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
-
-    api_thread.start()
-    # Keep the program alive
-    while api_thread.is_alive():
-        time.sleep(0.5)
+    await start_api(host, port)


 if __name__ == "__main__":
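With start_api awaited directly, the whole entrypoint runs under a single event loop, and the daemon thread plus keep-alive sleep loop becomes unnecessary. A rough standalone sketch of that shape (not the repository's actual main.py):

# Illustrative sketch: once the server is awaited directly, the entrypoint
# can run under one asyncio event loop instead of a daemon thread.
import asyncio


async def start_api(host: str, port: int):
    # Stand-in for building a uvicorn.Config and awaiting Server.serve()
    print(f"serving on {host}:{port}")
    await asyncio.sleep(1)


async def entrypoint():
    # ... config, model, and lora loading would happen here ...
    await start_api("127.0.0.1", 8000)


if __name__ == "__main__":
    asyncio.run(entrypoint())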