Files
tabbyAPI/common/concurrency.py
kingbri c474076b22 Concurrency: Remove release_semaphore method
At any point for any request cancellation, the semaphore will be
decremented. This is an issue since an arbitrary request can desync
the semaphore, causing multiple tasks to be processed at once and
break generation.

Remove this from the networking handlers and therefore, remove the
release_semaphore function itself.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-05-19 10:42:26 -04:00

58 lines
1.5 KiB
Python

"""Concurrency handling"""
import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool # noqa
from functools import partialmethod
from typing import AsyncGenerator, Generator, Union
generate_semaphore = asyncio.Semaphore(1)
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
# Uses generators instead of generics
class _StopIteration(Exception):
"""Wrapper for StopIteration because it doesn't send across threads."""
pass
def gen_next(generator: Generator):
"""Threaded function to get the next value in an iterator."""
try:
return next(generator)
except StopIteration as e:
raise _StopIteration from e
async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
"""Iterates a generator within a threadpool."""
while True:
try:
yield await asyncio.to_thread(gen_next, generator)
except _StopIteration:
break
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
"""Generate with a semaphore."""
async with generate_semaphore:
if not inspect.isasyncgenfunction:
generator = iterate_in_threadpool(generator())
async for result in generator():
yield result
async def call_with_semaphore(callback: partialmethod):
"""Call with a semaphore."""
async with generate_semaphore:
if not inspect.iscoroutinefunction:
callback = run_in_threadpool(callback)
return await callback()