Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
turboderp
2024-07-05 23:57:40 +02:00

View File

@@ -22,7 +22,9 @@ class ExLlamaV2DynamicGeneratorAsync:
try:
while True:
async with self.condition:
await self.condition.wait_for(lambda: len(self.jobs) > 0)
# Unlock if there's no jobs or if the parent task is cancelled
await self.condition.wait_for(lambda: len(self.jobs) > 0 or self.iteration_task.cancelled())
results = self.generator.iterate()
for result in results:
job = result["job"]
@@ -31,6 +33,9 @@ class ExLlamaV2DynamicGeneratorAsync:
if result["eos"]:
del self.jobs[job]
await asyncio.sleep(0)
except asyncio.CancelledError:
# Silently return on cancel
return
except Exception as e:
# If the generator throws an exception it won't pertain to any one ongoing job, so push it to all of them
for async_job in self.jobs.values():
@@ -48,6 +53,9 @@ class ExLlamaV2DynamicGeneratorAsync:
async def close(self):
self.iteration_task.cancel()
# Force a re-check of the condition to unlock the loop
await self._notify_condition()
try:
await self.iteration_task
except asyncio.CancelledError: