Tree: Switch to async generators

Async generators remove many of the roadblocks that come with managing
tasks through threads. They should allow for abortable generations and
other modern async paradigms.

NOTE: ExLlamaV2 itself is not an asynchronous library. It has simply
been wrapped into tabby's async architecture to allow for a fast and
concurrent API server. Whether to run stream_ex in a separate thread or
to manage it manually with asyncio.sleep(0) is still being debated.
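
For reference, a minimal sketch of the two approaches under debate.
blocking_stream is a hypothetical stand-in for a synchronous ExLlamaV2
generator such as stream_ex; none of these names are tabby's actual API:

import asyncio


def blocking_stream():
    """Hypothetical stand-in for a synchronous ExLlamaV2 generator."""
    for i in range(5):
        yield f"chunk {i}"


async def stream_with_sleep():
    """Option 1: iterate on the event loop, manually suspending each step."""
    for chunk in blocking_stream():
        # asyncio.sleep(0) yields control so other tasks can run
        await asyncio.sleep(0)
        yield chunk


async def stream_with_thread():
    """Option 2: run each blocking step in a worker thread."""
    gen = blocking_stream()
    while True:
        # next(gen, None) runs off the event loop; None marks exhaustion
        chunk = await asyncio.to_thread(next, gen, None)
        if chunk is None:
            break
        yield chunk


async def main():
    async for chunk in stream_with_sleep():
        print(chunk)


asyncio.run(main())

The sleep(0) route keeps everything on one event loop but lets a slow
forward pass block it; the thread route keeps the loop responsive at the
cost of thread-hop overhead, which is presumably the trade-off in question.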

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-03-14 10:27:39 -04:00
Committed by: Brian Dashore
Parent: 33e2df50b7
Commit: 7fded4f183
10 changed files with 84 additions and 88 deletions


@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import asyncio
import gc
from itertools import zip_longest
import pathlib
@@ -325,7 +326,7 @@ class ExllamaV2Container:
         return model_params

-    def load(self, progress_callback=None):
+    async def load(self, progress_callback=None):
         """
         Load model
@@ -338,7 +339,7 @@ class ExllamaV2Container:
-        for _ in self.load_gen(progress_callback):
+        async for _ in self.load_gen(progress_callback):
             pass

-    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
         Load loras
         """
@@ -361,7 +362,7 @@ class ExllamaV2Container:
logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
# FIXME(alpin): Does self.model need to be passed here?
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
@@ -371,7 +372,7 @@ class ExllamaV2Container:
         # Return success and failure names
         return {"success": success, "failure": failure}

-    def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None):
         """
         Load model, generator function
@@ -400,12 +401,16 @@ class ExllamaV2Container:
logger.info("Loading draft model: " + self.draft_config.model_dir)
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
yield from self.draft_model.load_autosplit_gen(
for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
last_id_only=True,
callback_gen=progress_callback,
)
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
@@ -424,6 +429,8 @@ class ExllamaV2Container:
             self.gpu_split,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value
@@ -452,6 +459,8 @@ class ExllamaV2Container:
             last_id_only=True,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value
@@ -565,9 +574,11 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

-    def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""

-        generations = list(self.generate_gen(prompt, **kwargs))
+        generations = []
+        async for generation in self.generate_gen(prompt, **kwargs):
+            generations.append(generation)

         joined_generation = {
             "text": "",
@@ -615,8 +626,7 @@ class ExllamaV2Container:
         return kwargs

-    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-    def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion
@@ -889,6 +899,9 @@ class ExllamaV2Container:
         chunk_tokens = 0

         while True:
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
+
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
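
With generate_gen exposed as an async generator, aborting a generation
reduces to cancelling the consuming task. A rough sketch of that pattern,
using an illustrative stand-in for the container rather than tabby's
actual objects:

import asyncio


async def fake_generate_gen(prompt: str):
    """Illustrative stand-in for ExllamaV2Container.generate_gen."""
    i = 0
    while True:
        # Same cooperative suspension the real loop performs
        await asyncio.sleep(0)
        yield {"text": f"token{i} "}
        i += 1


async def main():
    chunks = []

    async def consume():
        async for generation in fake_generate_gen("hello"):
            chunks.append(generation["text"])

    task = asyncio.create_task(consume())
    await asyncio.sleep(0.01)

    # Aborting the stream is now a one-line cancellation
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

    print(f"collected {len(chunks)} chunks before abort")


asyncio.run(main())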