diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5cf7b60..64d2c27 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,8 +1,6 @@ """The model container class for ExLlamaV2 models.""" -import asyncio import gc -from itertools import zip_longest import pathlib import time @@ -17,10 +15,12 @@ from exllamav2 import ( ExLlamaV2Lora, ) from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler +from itertools import zip_longest from loguru import logger from typing import List, Optional, Union from backends.exllamav2.grammar import ExLlamaV2Grammar +from common.concurrency import iterate_in_threadpool from common.gen_logging import ( log_generation_params, log_metrics, @@ -336,7 +336,7 @@ class ExllamaV2Container: def progress(loaded_modules: int, total_modules: int) """ - for _ in self.load_gen(progress_callback): + async for _ in self.load_gen(progress_callback): pass async def load_loras(self, lora_directory: pathlib.Path, **kwargs): @@ -373,6 +373,13 @@ class ExllamaV2Container: return {"success": success, "failure": failure} async def load_gen(self, progress_callback=None): + """Basic async wrapper around the loading generator""" + + load_generator = self.load_gen_sync(progress_callback) + async for value in iterate_in_threadpool(load_generator): + yield value + + def load_gen_sync(self, progress_callback=None): """ Load model, generator function @@ -407,8 +414,6 @@ class ExllamaV2Container: last_id_only=True, callback_gen=progress_callback, ): - # Manually suspend the task to allow for other stuff to run - await asyncio.sleep(0) if value: yield value @@ -429,8 +434,6 @@ class ExllamaV2Container: self.gpu_split, callback_gen=progress_callback, ): - # Manually suspend the task to allow for other stuff to run - await asyncio.sleep(0) if value: yield value @@ -459,8 +462,6 @@ class ExllamaV2Container: last_id_only=True, callback_gen=progress_callback, ): - # Manually suspend the task to allow 
for other stuff to run - await asyncio.sleep(0) if value: yield value @@ -627,6 +628,13 @@ class ExllamaV2Container: return kwargs async def generate_gen(self, prompt: str, **kwargs): + """Basic async wrapper for completion generator""" + + sync_generator = self.generate_gen_sync(prompt, **kwargs) + async for value in iterate_in_threadpool(sync_generator): + yield value + + def generate_gen_sync(self, prompt: str, **kwargs): """ Create generator function for prompt completion @@ -899,9 +907,6 @@ class ExllamaV2Container: chunk_tokens = 0 while True: - # Manually suspend the task to allow for other stuff to run - await asyncio.sleep(0) - # Ingest prompt if chunk_tokens == 0: ids = torch.cat((ids, save_tokens), dim=-1) diff --git a/common/concurrency.py b/common/concurrency.py index 6cf9fed..675fbf1 100644 --- a/common/concurrency.py +++ b/common/concurrency.py @@ -2,12 +2,40 @@ import asyncio import inspect +from fastapi.concurrency import run_in_threadpool # noqa from functools import partialmethod from typing import AsyncGenerator, Generator, Union generate_semaphore = asyncio.Semaphore(1) +# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py +# Uses generators instead of generics +class _StopIteration(Exception): + """Wrapper for StopIteration because it doesn't send across threads.""" + + pass + + +def gen_next(generator: Generator): + """Threaded function to get the next value in an iterator.""" + + try: + return next(generator) + except StopIteration as e: + raise _StopIteration from e + + +async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator: + """Iterates a generator within a threadpool.""" + + while True: + try: + yield await asyncio.to_thread(gen_next, generator) + except _StopIteration: + break + + def release_semaphore(): generate_semaphore.release() @@ -16,19 +44,18 @@ async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]): """Generate with a semaphore.""" async with 
generate_semaphore: - if inspect.isasyncgenfunction: - async for result in generator(): - yield result - else: - for result in generator(): - yield result + is_async_gen = inspect.isasyncgenfunction(generator) + generator = generator() if is_async_gen else iterate_in_threadpool(generator()) + + async for result in generator: + yield result async def call_with_semaphore(callback: partialmethod): """Call with a semaphore.""" async with generate_semaphore: - if inspect.iscoroutinefunction(callback): - return await callback() - else: - return callback() + if not inspect.iscoroutinefunction(callback): + return await run_in_threadpool(callback) + + return await callback() diff --git a/common/model.py b/common/model.py index 1e1d432..5608c9f 100644 --- a/common/model.py +++ b/common/model.py @@ -72,7 +72,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): async def load_model(model_path: pathlib.Path, **kwargs): - async for _, _, _ in load_model_gen(model_path, **kwargs): + async for _ in load_model_gen(model_path, **kwargs): pass diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 7fc993f..8d56b09 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -1,7 +1,7 @@ """Completion utilities for OAI server.""" -from asyncio import CancelledError import pathlib +from asyncio import CancelledError from fastapi import HTTPException from typing import Optional