mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
At any point for any request cancellation, the semaphore will be decremented. This is an issue since an arbitrary request can desync the semaphore, causing multiple tasks to be processed at once and break generation. Remove this from the networking handlers and therefore, remove the release_semaphore function itself. Signed-off-by: kingbri <bdashore3@proton.me>
58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
"""Concurrency handling"""
|
|
|
|
import asyncio
|
|
import inspect
|
|
from fastapi.concurrency import run_in_threadpool # noqa
|
|
from functools import partialmethod
|
|
from typing import AsyncGenerator, Generator, Union
|
|
|
|
generate_semaphore = asyncio.Semaphore(1)
|
|
|
|
|
|
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
|
|
# Uses generators instead of generics
|
|
class _StopIteration(Exception):
|
|
"""Wrapper for StopIteration because it doesn't send across threads."""
|
|
|
|
pass
|
|
|
|
|
|
def gen_next(generator: Generator):
|
|
"""Threaded function to get the next value in an iterator."""
|
|
|
|
try:
|
|
return next(generator)
|
|
except StopIteration as e:
|
|
raise _StopIteration from e
|
|
|
|
|
|
async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
|
|
"""Iterates a generator within a threadpool."""
|
|
|
|
while True:
|
|
try:
|
|
yield await asyncio.to_thread(gen_next, generator)
|
|
except _StopIteration:
|
|
break
|
|
|
|
|
|
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
|
|
"""Generate with a semaphore."""
|
|
|
|
async with generate_semaphore:
|
|
if not inspect.isasyncgenfunction:
|
|
generator = iterate_in_threadpool(generator())
|
|
|
|
async for result in generator():
|
|
yield result
|
|
|
|
|
|
async def call_with_semaphore(callback: partialmethod):
|
|
"""Call with a semaphore."""
|
|
|
|
async with generate_semaphore:
|
|
if not inspect.iscoroutinefunction:
|
|
callback = run_in_threadpool(callback)
|
|
|
|
return await callback()
|