Tree: Switch to async generators

Async generators remove many of the roadblocks that come with managing tasks
in threads. They should allow for abortable requests and more modern concurrency paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It has simply been
wrapped into tabby's async design to allow for a fast and concurrent
API server. It is still being debated whether to run stream_ex in a separate
thread or to manage it manually with asyncio.sleep(0).
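
To illustrate the idea (a minimal sketch, not tabby's actual code — slow_sync_generator and heartbeat below are hypothetical names): a blocking, synchronous generator is driven from inside an async generator, and awaiting asyncio.sleep(0) on every step hands control back to the event loop so other tasks can make progress between chunks.

import asyncio


def slow_sync_generator(n: int):
    """Hypothetical stand-in for a blocking generator such as exllamav2's streaming."""
    for i in range(n):
        yield f"chunk {i}"  # pretend each step does blocking CPU/GPU work


async def async_wrapper(n: int):
    """Wrap the sync generator so other asyncio tasks can run between chunks."""
    for chunk in slow_sync_generator(n):
        # Manually suspend the task to allow for other stuff to run
        await asyncio.sleep(0)
        yield chunk


async def heartbeat():
    """Only makes progress if the wrapper actually yields control to the loop."""
    for _ in range(3):
        await asyncio.sleep(0)
        print("heartbeat")


async def main():
    async def consume():
        async for chunk in async_wrapper(3):
            print(chunk)

    await asyncio.gather(consume(), heartbeat())


asyncio.run(main())

The trade-off behind the debate above: asyncio.sleep(0) only yields between chunks, so a single long-blocking step still stalls the event loop, whereas a separate thread keeps the loop responsive at the cost of thread management.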

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-03-14 10:27:39 -04:00
Committed by: Brian Dashore
Commit: 7fded4f183 (parent 33e2df50b7)

10 changed files with 84 additions and 88 deletions


@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""

+import asyncio
 import gc
 from itertools import zip_longest
 import pathlib
@@ -325,7 +326,7 @@ class ExllamaV2Container:
         return model_params

-    def load(self, progress_callback=None):
+    async def load(self, progress_callback=None):
         """
         Load model
@@ -338,7 +339,7 @@
         for _ in self.load_gen(progress_callback):
             pass

-    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
         Load loras
         """
@@ -361,7 +362,7 @@
             logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name

+            # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
@@ -371,7 +372,7 @@
         # Return success and failure names
         return {"success": success, "failure": failure}

-    def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None):
         """
         Load model, generator function
@@ -400,12 +401,16 @@
             logger.info("Loading draft model: " + self.draft_config.model_dir)

             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
-            yield from self.draft_model.load_autosplit_gen(
+            for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
                 last_id_only=True,
                 callback_gen=progress_callback,
-            )
+            ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
+                if value:
+                    yield value

         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
@@ -424,6 +429,8 @@
                 self.gpu_split,
                 callback_gen=progress_callback,
             ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
                 if value:
                     yield value
@@ -452,6 +459,8 @@
                 last_id_only=True,
                 callback_gen=progress_callback,
             ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
                 if value:
                     yield value
@@ -565,9 +574,11 @@
         return dict(zip_longest(top_tokens, cleaned_values))

-    def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
-        generations = list(self.generate_gen(prompt, **kwargs))
+        generations = []
+        async for generation in self.generate_gen(prompt, **kwargs):
+            generations.append(generation)

         joined_generation = {
             "text": "",
@@ -615,8 +626,7 @@
         return kwargs

-    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-    def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion
@@ -889,6 +899,9 @@
         chunk_tokens = 0

         while True:
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)

             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)


@@ -1,4 +1,4 @@
-"""Generator handling"""
+"""Concurrency handling"""
 import asyncio
 import inspect


@@ -52,7 +52,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     progress.start()

     try:
-        for module, modules in load_status:
+        async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
                     f"[cyan]Loading {model_type} modules", total=modules
@@ -76,12 +76,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
         pass

-def load_loras(lora_dir, **kwargs):
+async def load_loras(lora_dir, **kwargs):
     """Wrapper to load loras."""
     if len(container.active_loras) > 0:
         unload_loras()

-    return container.load_loras(lora_dir, **kwargs)
+    return await container.load_loras(lora_dir, **kwargs)

 def unload_loras():


@@ -6,6 +6,8 @@ from loguru import logger
 from pydantic import BaseModel
 from typing import Optional

+from common.concurrency import release_semaphore
+
 def load_progress(module, modules):
     """Wrapper callback for load progress."""
@@ -51,6 +53,13 @@
     return request_error

+def handle_request_disconnect(message: str):
+    """Wrapper for handling for request disconnection."""
+
+    release_semaphore()
+    logger.error(message)
+
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:


@@ -1,7 +1,6 @@
import pathlib import pathlib
import uvicorn import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Request from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from functools import partial from functools import partial
from loguru import logger from loguru import logger
@@ -10,7 +9,7 @@ from sys import maxsize
from common import config, model, gen_logging, sampling from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key from common.auth import check_admin_key, check_api_key
from common.generators import ( from common.concurrency import (
call_with_semaphore, call_with_semaphore,
generate_with_semaphore, generate_with_semaphore,
) )
@@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not model_path.exists(): if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?") raise HTTPException(400, "model_path does not exist. Check model_name?")
load_callback = partial( load_callback = partial(stream_model_load, data, model_path, draft_model_path)
stream_model_load, request, data, model_path, draft_model_path
)
# Wrap in a semaphore if the queue isn't being skipped # Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue: if data.skip_queue:
@@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest):
"A parent lora directory does not exist. Check your config.yml?", "A parent lora directory does not exist. Check your config.yml?",
) )
load_callback = partial( load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
)
# Wrap in a semaphore if the queue isn't being skipped # Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue: if data.skip_queue:
@@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest):
) )
if data.stream and not disable_request_streaming: if data.stream and not disable_request_streaming:
generator_callback = partial( generator_callback = partial(stream_generate_completion, data, model_path)
stream_generate_completion, request, data, model_path
)
return EventSourceResponse( return EventSourceResponse(
generate_with_semaphore(generator_callback), generate_with_semaphore(generator_callback),
@@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
if data.stream and not disable_request_streaming: if data.stream and not disable_request_streaming:
generator_callback = partial( generator_callback = partial(
stream_generate_chat_completion, prompt, request, data, model_path stream_generate_chat_completion, prompt, data, model_path
) )
return EventSourceResponse( return EventSourceResponse(
@@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
) )
else: else:
response = await call_with_semaphore( response = await call_with_semaphore(
partial(generate_chat_completion, prompt, request, data, model_path) partial(generate_chat_completion, prompt, data, model_path)
) )
return response return response
def start_api(host: str, port: int): async def start_api(host: str, port: int):
"""Isolated function to start the API server""" """Isolated function to start the API server"""
# TODO: Move OAI API to a separate folder # TODO: Move OAI API to a separate folder
@@ -475,9 +468,12 @@ def start_api(host: str, port: int):
logger.info(f"Completions: http://{host}:{port}/v1/completions") logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
uvicorn.run( config = uvicorn.Config(
app, app,
host=host, host=host,
port=port, port=port,
log_config=UVICORN_LOG_CONFIG, log_config=UVICORN_LOG_CONFIG,
) )
server = uvicorn.Server(config)
await server.serve()
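
For context on the last hunk: uvicorn.run() creates and owns its own event loop, so it cannot be awaited from a loop that is already running. Building a uvicorn.Config and awaiting uvicorn.Server(...).serve() is uvicorn's supported way to host the server inside an existing asyncio program. A minimal standalone sketch (the bare FastAPI app below is a placeholder, not tabby's actual app or log config):

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


async def start_api(host: str, port: int):
    # Configure and await the server from inside an existing event loop
    # instead of calling the blocking uvicorn.run().
    config = uvicorn.Config(app, host=host, port=port)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    asyncio.run(start_api("127.0.0.1", 8000))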


@@ -1,18 +1,21 @@
 """Chat completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
 from typing import Optional
 from uuid import uuid4

-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
+from fastapi import HTTPException
 from jinja2 import TemplateError
-from loguru import logger

 from common import model
-from common.generators import release_semaphore
 from common.templating import get_prompt_from_template
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,
@@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):
 async def stream_generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     """Generator for the generation process."""

     try:
         const_id = f"chatcmpl-{uuid4().hex}"
         new_generation = model.container.generate_gen(prompt, **data.to_gen_params())

-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Chat completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)

             yield response.model_dump_json()
@@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
         finish_response = _create_stream_chunk(const_id, finish_reason="stop")

         yield finish_response.model_dump_json()
+
+    except CancelledError:
+        # Get out if the request gets disconnected
+        handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Chat completion aborted. Please check the server console."
@@ -179,11 +180,10 @@ async def stream_generate_chat_completion(
 async def generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     try:
-        generation = await run_in_threadpool(
-            model.container.generate,
+        generation = await model.container.generate(
             prompt,
             **data.to_gen_params(),
         )
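
The disconnect handling in these streaming utilities relies on asyncio cancellation: when the client drops, the server cancels the task driving the async generator, and CancelledError surfaces at its current await point. Because CancelledError inherits from BaseException (Python 3.8+), the generic except Exception branch would never see it, hence the dedicated handler that releases the semaphore. A rough standalone sketch of the mechanism, with hypothetical names (stream_chunks, on_disconnect) rather than tabby's actual helpers:

import asyncio
from asyncio import CancelledError


def on_disconnect(message: str):
    """Hypothetical stand-in for handle_request_disconnect(): free resources and log."""
    print(f"disconnect: {message}")


async def stream_chunks():
    try:
        for i in range(10):
            await asyncio.sleep(0.1)  # pretend to wait on generation
            yield f"chunk {i}"
    except CancelledError:
        # Raised here when the consuming task is cancelled (client went away)
        on_disconnect("stream cancelled by consumer")


async def main():
    async def consume():
        async for chunk in stream_chunks():
            print(chunk)

    task = asyncio.create_task(consume())
    await asyncio.sleep(0.25)
    task.cancel()  # simulate the client disconnecting mid-stream
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())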


@@ -1,14 +1,17 @@
 """Completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib

-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
-from loguru import logger
+from fastapi import HTTPException
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.completion import (
     CompletionRequest,
     CompletionResponse,
@@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response

-async def stream_generate_completion(
-    request: Request, data: CompletionRequest, model_path: pathlib.Path
-):
+async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Streaming generation for completions."""

     try:
         new_generation = model.container.generate_gen(
             data.prompt, **data.to_gen_params()
         )
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_response(generation, model_path.name)

             yield response.model_dump_json()

         # Yield a finish response on successful generation
         yield "[DONE]"
+
+    except CancelledError:
+        # Get out if the request gets disconnected
+        handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Completion aborted. Please check the server console."
@@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     """Non-streaming generate for completions"""

     try:
-        generation = await run_in_threadpool(
-            model.container.generate, data.prompt, **data.to_gen_params()
-        )
+        generation = await model.container.generate(data.prompt, **data.to_gen_params())

         response = _create_response(generation, model_path.name)
         return response


@@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path):
     for path in lora_path.iterdir():
         if path.is_dir():
             lora_card = LoraCard(id=path.name)
-            lora_list.data.append(lora_card)  # pylint: disable=no-member
+            lora_list.data.append(lora_card)

     return lora_list


@@ -1,12 +1,9 @@
 import pathlib
 from asyncio import CancelledError

-from fastapi import Request
-from loguru import logger
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error
+from common.utils import get_generator_error, handle_request_disconnect

 from endpoints.OAI.types.model import (
     ModelCard,
@@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
 async def stream_model_load(
-    request: Request,
     data: ModelLoadRequest,
     model_path: pathlib.Path,
     draft_model_path: str,
@@ -50,14 +46,6 @@ async def stream_model_load(
     load_status = model.load_model_gen(model_path, **load_data)
     try:
         async for module, modules, model_type in load_status:
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error(
-                    "Model load cancelled by user. "
-                    "Please make sure to run unload to free up resources."
-                )
-                return
-
             if module != 0:
                 response = ModelLoadResponse(
                     model_type=model_type,
@@ -78,7 +66,9 @@ async def stream_model_load(
             yield response.model_dump_json()

     except CancelledError:
-        logger.error(
+        # Get out if the request gets disconnected
+        handle_request_disconnect(
             "Model load cancelled by user. "
             "Please make sure to run unload to free up resources."
         )

main.py

@@ -5,9 +5,6 @@ import os
 import pathlib
 import signal
 import sys
-import threading
-import time
-from functools import partial

 from loguru import logger
 from typing import Optional
@@ -121,13 +118,7 @@ async def entrypoint(args: Optional[dict] = None):
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         model.container.load_loras(lora_dir.resolve(), **lora_config)

-    # TODO: Replace this with abortables, async via producer consumer, or something else
-    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
-    api_thread.start()
-
-    # Keep the program alive
-    while api_thread.is_alive():
-        time.sleep(0.5)
+    await start_api(host, port)

 if __name__ == "__main__":