diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 4ffded7..e0c7cce 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -31,7 +31,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from ruamel.yaml import YAML
 
@@ -106,6 +106,7 @@ class ExllamaV2Container:
     # Load synchronization
     # The lock keeps load tasks sequential
     # The condition notifies any waiting tasks
+    active_job_ids: Dict[str, ExLlamaV2DynamicJobAsync] = {}
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
@@ -887,12 +888,7 @@ class ExllamaV2Container:
         self.model = None
 
         if self.vision_model:
-            # TODO: Remove this with newer exl2 versions
-            # Required otherwise unload function won't finish
-            try:
-                self.vision_model.unload()
-            except AttributeError:
-                pass
+            self.vision_model.unload()
 
         self.vision_model = None
 
@@ -950,7 +946,6 @@ class ExllamaV2Container:
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
-    # TODO: Maybe support generation_config for eos_token
     def get_special_tokens(
         self, add_bos_token: bool = True, ban_eos_token: bool = False
     ):
@@ -1042,13 +1037,6 @@ class ExllamaV2Container:
         Meant for dev wheels!
         """
 
-        if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "xtc_probability"
-        ):
-            logger.warning(
-                "XTC is not supported by the currently " "installed ExLlamaV2 version."
-            )
-
         return kwargs
 
     async def generate_gen(
@@ -1191,7 +1179,6 @@ class ExllamaV2Container:
         if dry_multiplier > 0:
             gen_settings.dry_multiplier = dry_multiplier
 
-            # TODO: Maybe set the "sane" defaults instead?
             gen_settings.dry_allowed_length = unwrap(
                 kwargs.get("dry_allowed_length"), 0
             )
@@ -1261,18 +1248,10 @@ class ExllamaV2Container:
             gen_settings.typical = 0
 
             logger.warning(
-                "".join(
-                    [
-                        "Temperature is set to 0. Overriding temp, ",
-                        "top_k, top_p, and typical to 1.0, 1, 0, and 0.",
-                    ]
-                )
+                "Temperature is set to 0. Overriding temp, "
+                "top_k, top_p, and typical to 1.0, 1, 0, and 0."
             )
 
-        # Store the gen settings for logging purposes
-        # Deepcopy to save a snapshot of vars
-        gen_settings_log_dict = deepcopy(vars(gen_settings))
-
         # Set banned tokens
         banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
         if banned_tokens:
@@ -1522,26 +1501,15 @@ class ExllamaV2Container:
         # Some options are too large, so log the args instead
         log_generation_params(
             request_id=request_id,
-            max_tokens=max_tokens,
-            min_tokens=min_tokens,
-            stream=kwargs.get("stream"),
-            **gen_settings_log_dict,
-            token_healing=token_healing,
-            auto_scale_penalty_range=auto_scale_penalty_range,
-            generate_window=generate_window,
             bos_token_id=self.tokenizer.bos_token_id,
             eos_token_id=eos_tokens,
-            add_bos_token=add_bos_token,
-            ban_eos_token=ban_eos_token,
-            skip_special_tokens=not decode_special_tokens,
-            speculative_ngram=self.generator.speculative_ngram,
-            logprobs=request_logprobs,
-            stop_conditions=stop_conditions,
-            banned_tokens=banned_tokens,
-            allowed_tokens=allowed_tokens,
-            banned_strings=banned_strings,
-            logit_bias=logit_bias,
-            filters=grammar_handler.filters,
+            # Merge into one dict so the resolved values replace any raw
+            # request values of the same name instead of raising TypeError
+            **{
+                **kwargs,
+                "generate_window": generate_window,
+                "auto_scale_penalty_range": auto_scale_penalty_range,
+            },
         )
 
         # Log the metrics if present
diff --git a/common/sampling.py b/common/sampling.py
index 7e5ded4..d2c230c 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -165,7 +165,7 @@ class BaseSamplerRequest(BaseModel):
             "rep_pen_range",
         ),
         description=(
-            "Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
+            "Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
         ),
     )
 
diff --git a/common/utils.py b/common/utils.py
index 97ecaf7..9fd46de 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -1,5 +1,8 @@
 """Common utility functions"""
 
+import contextlib
+import functools
+import inspect
 from types import NoneType
 from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
 
@@ -85,3 +88,61 @@ def unwrap_optional_type(type_hint) -> Type:
                     return arg
 
     return type_hint
+
+
+def with_defer(func):
+    """
+    Decorator for a go-style defer
+
+    The decorated function is called with an extra 'defer' keyword argument.
+    defer(fn, *args, **kwargs) schedules fn(*args, **kwargs) to run once the
+    decorated function returns or raises. Deferred calls run last-in,
+    first-out, and every one of them runs even if an earlier one raises.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        # ExitStack unwinds callbacks in reverse registration order and keeps
+        # unwinding after a callback raises, so no deferred call is skipped
+        with contextlib.ExitStack() as stack:
+            # 'defer' is what you call inside the decorated function
+            def defer(fn, *fn_args, **fn_kwargs):
+                stack.callback(fn, *fn_args, **fn_kwargs)
+
+            # Inject 'defer' into the kwargs of the original function
+            return func(*args, defer=defer, **kwargs)
+
+    return wrapper
+
+
+def with_defer_async(func):
+    """
+    Decorator for running async functions in go-style defer blocks
+
+    Works like with_defer for a coroutine function. A deferred coroutine
+    function or coroutine object is awaited; any other callable is called
+    normally.
+    """
+
+    @functools.wraps(func)
+    async def wrapper(*args, **kwargs):
+        # AsyncExitStack unwinds callbacks in reverse registration order and
+        # keeps unwinding after a callback raises
+        async with contextlib.AsyncExitStack() as stack:
+            # 'defer' is what you call inside the decorated function
+            def defer(fn, *fn_args, **fn_kwargs):
+                if inspect.iscoroutinefunction(fn):
+                    stack.push_async_callback(fn, *fn_args, **fn_kwargs)
+                elif inspect.iscoroutine(fn):
+                    # A coroutine object takes no arguments, just await it
+                    async def await_deferred():
+                        await fn
+
+                    stack.push_async_callback(await_deferred)
+                else:
+                    stack.callback(fn, *fn_args, **fn_kwargs)
+
+            # Inject 'defer' into the kwargs of the original function
+            return await func(*args, defer=defer, **kwargs)
+
+    return wrapper