Mirror of https://github.com/theroyallab/tabbyAPI.git
Tree: Switch to async generators
Async generation removes many of the roadblocks that come with managing tasks via threads, and it should allow for abortable generation and other modern async paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It has simply been folded into tabby's async design to allow for a fast and concurrent API server. It is still being debated whether to run stream_ex in a separate thread or to manually manage it using asyncio.sleep(0).

Signed-off-by: kingbri <bdashore3@proton.me>
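For context, here is a minimal, hypothetical sketch of the two approaches mentioned above. blocking_steps is a stand-in for a blocking exllamav2 generator step (something like stream_ex), not the real API surface; the first wrapper mirrors what this commit does (drive the blocking generator on the event loop and suspend with asyncio.sleep(0) after each step), while the second shows the alternative still under debate (offload each blocking step to a worker thread).

import asyncio

def blocking_steps():
    """Hypothetical stand-in for a blocking exllamav2 generator (not the real API)."""
    for i in range(3):
        yield {"chunk": f"token-{i}"}

async def stream_cooperative():
    """Approach used in this commit: run the blocking generator on the event
    loop and manually suspend after each step so other tasks get a turn."""
    for value in blocking_steps():
        # Hand control back to the event loop before yielding the chunk
        await asyncio.sleep(0)
        if value:
            yield value

async def stream_threaded():
    """Alternative being debated: push each blocking step onto a worker thread
    so the event loop itself is never blocked by the generator."""
    gen = blocking_steps()
    while True:
        value = await asyncio.to_thread(next, gen, None)
        if value is None:
            break
        yield value

async def main():
    async for chunk in stream_cooperative():
        print("cooperative:", chunk)
    async for chunk in stream_threaded():
        print("threaded:", chunk)

asyncio.run(main())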
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""

+import asyncio
 import gc
 from itertools import zip_longest
 import pathlib
@@ -325,7 +326,7 @@ class ExllamaV2Container:
         return model_params

-    def load(self, progress_callback=None):
+    async def load(self, progress_callback=None):
         """
         Load model
@@ -338,7 +339,7 @@ class ExllamaV2Container:
         for _ in self.load_gen(progress_callback):
             pass

-    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
         Load loras
         """
@@ -361,7 +362,7 @@ class ExllamaV2Container:
             logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
@@ -371,7 +372,7 @@ class ExllamaV2Container:
         # Return success and failure names
         return {"success": success, "failure": failure}

-    def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None):
         """
         Load model, generator function
@@ -400,12 +401,16 @@ class ExllamaV2Container:
             logger.info("Loading draft model: " + self.draft_config.model_dir)

             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
-            yield from self.draft_model.load_autosplit_gen(
+            for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
                 last_id_only=True,
                 callback_gen=progress_callback,
-            )
+            ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
+                if value:
+                    yield value

         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
@@ -424,6 +429,8 @@ class ExllamaV2Container:
             self.gpu_split,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value
@@ -452,6 +459,8 @@ class ExllamaV2Container:
                 last_id_only=True,
                 callback_gen=progress_callback,
             ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
                 if value:
                     yield value
@@ -565,9 +574,11 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

-    def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
-        generations = list(self.generate_gen(prompt, **kwargs))
+        generations = []
+        async for generation in self.generate_gen(prompt, **kwargs):
+            generations.append(generation)

         joined_generation = {
             "text": "",
@@ -615,8 +626,7 @@ class ExllamaV2Container:
         return kwargs

-    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-    def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion
@@ -889,6 +899,9 @@ class ExllamaV2Container:
        chunk_tokens = 0

        while True:
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
+
            # Ingest prompt
            if chunk_tokens == 0:
                ids = torch.cat((ids, save_tokens), dim=-1)
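As a rough usage sketch (not code from this commit), the container's public surface now has to be consumed from async code. This assumes an already-constructed ExllamaV2Container named container; the chunk/result shapes ("text" keys, based on the joined_generation dict above) and the max_tokens keyword are illustrative assumptions.

import asyncio

async def run(container, prompt: str):
    # load() is now a coroutine and must be awaited
    await container.load()

    # Streaming path: generate_gen() is now an async generator
    async for chunk in container.generate_gen(prompt, max_tokens=200):
        print(chunk.get("text", ""), end="", flush=True)

    # Non-streaming path: generate() collects the stream into one response
    result = await container.generate(prompt, max_tokens=200)
    print(result["text"])

# asyncio.run(run(container, "Hello"))  # container construction not shown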