Mirror of https://github.com/theroyallab/tabbyAPI.git, synced 2026-04-20 14:28:54 +00:00
API + Model: Add blocks and checks for various load requests
Add a sequential lock and wait until pending jobs are completed before executing any loading request that directly alters the model. New requests that arrive while a load is in progress also need to be blocked, so add a condition that notifies waiting tasks once the lock is free.

Signed-off-by: kingbri <bdashore3@proton.me>
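Roughly, the pattern this commit describes is: load, LoRA load, and unload tasks serialize on an asyncio.Lock, while inference requests wait on an asyncio.Condition until no load holds that lock. Below is a minimal standalone sketch of that synchronization pattern only; the ContainerSketch class and its methods are illustrative stand-ins, not tabbyAPI's actual API.

import asyncio


class ContainerSketch:
    """Toy container; only the load synchronization is the point here."""

    def __init__(self):
        # The lock keeps load/unload tasks sequential
        self.load_lock = asyncio.Lock()
        # The condition wakes requests that arrived while a load was running
        self.load_condition = asyncio.Condition()

    async def load(self):
        try:
            await self.load_lock.acquire()
            await asyncio.sleep(0.2)  # stand-in for loading weights
        finally:
            self.load_lock.release()
            # Wake every request waiting for the load to finish
            async with self.load_condition:
                self.load_condition.notify_all()

    async def generate(self, prompt: str) -> str:
        # Block new generation requests until no load holds the lock
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
        return f"echo: {prompt}"


async def main():
    container = ContainerSketch()
    results = await asyncio.gather(
        container.load(),
        container.generate("hello"),
        container.generate("world"),
    )
    print(results[1:])  # generations complete only after the lock is released


asyncio.run(main())

The separate condition matters because generation requests only need to observe that the lock is free; they never acquire it themselves, so load tasks still run one at a time while any number of queued requests are released together by notify_all().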
@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""

import asyncio
import gc
import math
import pathlib
@@ -54,7 +55,6 @@ class ExllamaV2Container:
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
active_loras: List[ExLlamaV2Lora] = []
paged: bool = True

# Internal config vars
@@ -71,6 +71,12 @@ class ExllamaV2Container:
model_is_loading: bool = False
model_loaded: bool = False

# Load synchronization
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()

def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
@@ -348,6 +354,22 @@ class ExllamaV2Container:

return model_params

async def wait_for_jobs(self, skip_wait: bool = False):
"""Polling mechanism to wait for pending generation jobs."""

if not self.generator:
return

# Immediately abort all jobs if asked
if skip_wait:
# Requires a copy to avoid errors during iteration
jobs_copy = self.generator.jobs.copy()
for job in jobs_copy.values():
await job.cancel()

while self.generator.jobs:
await asyncio.sleep(0.01)

async def load(self, progress_callback=None):
"""
Load model
@@ -361,89 +383,67 @@ class ExllamaV2Container:
async for _ in self.load_gen(progress_callback):
pass

async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""

loras = unwrap(kwargs.get("loras"), [])
success: List[str] = []
failure: List[str] = []

for lora in loras:
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)

if lora_name is None:
logger.warning(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue

logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name

self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
logger.info(f"Lora successfully loaded: {lora_name}")
success.append(lora_name)

# Return success and failure names
return {"success": success, "failure": failure}

async def load_gen(self, progress_callback=None):
async def load_gen(self, progress_callback=None, **kwargs):
"""Loads a model and streams progress via a generator."""

# Indicate that model load has started
self.model_is_loading = True
# Do this operation under the load lock's context
try:
await self.load_lock.acquire()
self.model_is_loading = True

# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(model_load_generator):
yield value
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))

# Disable paged mode if the user's min GPU is unsupported (below Ampere)
min_compute_capability = min(
set(
[
torch.cuda.get_device_capability(device=module.device_idx)[0]
for module in self.model.modules
if module.device_idx >= 0
]
# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(model_load_generator):
yield value

# Disable paged mode if the user's min GPU is unsupported (below Ampere)
min_compute_capability = min(
set(
[
torch.cuda.get_device_capability(device=module.device_idx)[0]
for module in self.model.modules
if module.device_idx >= 0
]
)
)
)

if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. This disables parallel batching."
if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. This disables parallel batching."
)
self.paged = False
self.max_batch_size = 1

# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
paged=self.paged,
)
self.paged = False
self.max_batch_size = 1

# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
paged=self.paged,
)
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()

# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Cleanup and update model load state
self.model_loaded = True
logger.info("Model successfully loaded.")
finally:
self.load_lock.release()
self.model_is_loading = False

# Cleanup and update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
async with self.load_condition:
self.load_condition.notify_all()

@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
@@ -538,39 +538,108 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

def unload(self, loras_only: bool = False):
def get_loras(self):
"""Convenience function to get all loras."""

return unwrap(self.generator.generator.current_loras, [])

async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""

loras = unwrap(kwargs.get("loras"), [])

try:
await self.load_lock.acquire()

# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))

loras_to_load: List[ExLlamaV2Lora] = []
success: List[str] = []
failure: List[str] = []

for lora in loras:
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)

if lora_name is None:
logger.warning(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue

logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name

loras_to_load.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
logger.info(f"Lora successfully added: {lora_name}")
success.append(lora_name)

self.generator.generator.set_loras(loras_to_load)
logger.info("All loras successfully loaded")

# Return success and failure names
return {"success": success, "failure": failure}
finally:
self.load_lock.release()

async with self.load_condition:
self.load_condition.notify_all()

async def unload(self, loras_only: bool = False, **kwargs):
"""
Free all VRAM resources used by this model
"""

for lora in self.active_loras:
lora.unload()
try:
await self.load_lock.acquire()

self.active_loras = []
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))

# Unload the entire model if not just unloading loras
if not loras_only:
if self.model:
self.model.unload()
self.model = None
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
lora.unload()

if self.draft_model:
self.draft_model.unload()
self.draft_model = None
self.generator.generator.set_loras([])

self.config = None
self.cache = None
self.tokenizer = None
self.generator = None
# Unload the entire model if not just unloading loras
if not loras_only:
if self.model:
self.model.unload()
self.model = None

# Set all model state variables to False
self.model_is_loading = False
self.model_loaded = False
if self.draft_model:
self.draft_model.unload()
self.draft_model = None

gc.collect()
torch.cuda.empty_cache()
self.config = None
self.cache = None
self.tokenizer = None

logger.info("Loras unloaded." if loras_only else "Model unloaded.")
# Cleanup the generator from any pending jobs
await self.generator.close()
self.generator = None

# Set all model state variables to False
self.model_is_loading = False
self.model_loaded = False

gc.collect()
torch.cuda.empty_cache()

logger.info("Loras unloaded." if loras_only else "Model unloaded.")
finally:
self.load_lock.release()

async with self.load_condition:
self.load_condition.notify_all()

def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
@@ -683,6 +752,10 @@ class ExllamaV2Container:
for kwargs, check common/sampling.py
"""

# Wait for load lock to be freed before processing
async with self.load_condition:
await self.load_condition.wait_for(lambda: not self.load_lock.locked())

prompts = [prompt]

token_healing = unwrap(kwargs.get("token_healing"), False)
@@ -951,79 +1024,84 @@ class ExllamaV2Container:
)

# Save generated tokens and full response
# Copy over max seq len in case the model is unloaded and stored jobs can complete
# Full response is required for offset calculation
max_seq_len = self.config.max_seq_len
generated_tokens = 0
full_response = ""

# Get the generation status once it's ready
async for result in job:
stage = result.get("stage")
result_id = result.get("identifier")
try:
# Get the generation status once it's ready
async for result in job:
stage = result.get("stage")
result_id = result.get("identifier")

if stage == "streaming" and result_id == job_id:
chunk = unwrap(result.get("text"), "")
full_response += chunk
if stage == "streaming" and result_id == job_id:
chunk = unwrap(result.get("text"), "")
full_response += chunk

chunk_tokens = result.get("token_ids")
if chunk_tokens is not None:
generated_tokens += chunk_tokens.size(dim=0)
chunk_tokens = result.get("token_ids")
if chunk_tokens is not None:
generated_tokens += chunk_tokens.size(dim=0)

generation = {
"text": chunk,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}

if request_logprobs > 0:
# Get top tokens and probs
top_tokens = unwrap(
result.get("top_k_tokens"),
torch.empty((1, 0, 1), dtype=torch.long),
)

top_probs = unwrap(
result.get("top_k_probs"),
torch.empty((1, 0, 1), dtype=torch.float),
)

if top_tokens.numel() > 0 and top_probs.numel() > 0:
logprobs = self.get_logprobs(top_tokens, top_probs)
generation["logprobs"] = logprobs

# The first logprob is the selected token prob
generation["token_probs"] = {
token: logprobs[token]
for token in list(logprobs.keys())[:1]
}

yield generation

# Second yield if eos is true
if result.get("eos"):
log_response(full_response)

eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)

log_metrics(
result.get("time_enqueued"),
result.get("prompt_tokens"),
result.get("time_prefill"),
result.get("new_tokens"),
result.get("time_generate"),
context_len,
self.config.max_seq_len,
)

# Remove the token text
generation = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
"text": chunk,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}

if request_logprobs > 0:
# Get top tokens and probs
top_tokens = unwrap(
result.get("top_k_tokens"),
torch.empty((1, 0, 1), dtype=torch.long),
)

top_probs = unwrap(
result.get("top_k_probs"),
torch.empty((1, 0, 1), dtype=torch.float),
)

if top_tokens.numel() > 0 and top_probs.numel() > 0:
logprobs = self.get_logprobs(top_tokens, top_probs)
generation["logprobs"] = logprobs

# The first logprob is the selected token prob
generation["token_probs"] = {
token: logprobs[token]
for token in list(logprobs.keys())[:1]
}

yield generation
break

# Second yield if eos is true
if result.get("eos"):
log_response(full_response)

eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)

log_metrics(
result.get("time_enqueued"),
result.get("prompt_tokens"),
result.get("time_prefill"),
result.get("new_tokens"),
result.get("time_generate"),
context_len,
max_seq_len,
)

# Remove the token text
generation = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
}

yield generation
break
except asyncio.CancelledError:
await job.cancel()