mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-29 02:31:48 +00:00
Model: Cleanup logging and remove extraneous declarations
Log the parameters passed into the `generate_gen` function rather than the generation settings to reduce complexity. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""The model container class for ExLlamaV2 models."""
|
"""The model container class for ExLlamaV2 models."""
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
@@ -31,7 +32,7 @@ from exllamav2.generator import (
|
|||||||
)
|
)
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from ruamel.yaml import YAML
|
from ruamel.yaml import YAML
|
||||||
|
|
||||||
@@ -106,6 +107,7 @@ class ExllamaV2Container:
|
|||||||
# Load synchronization
|
# Load synchronization
|
||||||
# The lock keeps load tasks sequential
|
# The lock keeps load tasks sequential
|
||||||
# The condition notifies any waiting tasks
|
# The condition notifies any waiting tasks
|
||||||
|
active_job_ids: Dict[str, ExLlamaV2DynamicJobAsync] = {}
|
||||||
load_lock: asyncio.Lock = asyncio.Lock()
|
load_lock: asyncio.Lock = asyncio.Lock()
|
||||||
load_condition: asyncio.Condition = asyncio.Condition()
|
load_condition: asyncio.Condition = asyncio.Condition()
|
||||||
|
|
||||||
@@ -887,12 +889,7 @@ class ExllamaV2Container:
|
|||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
if self.vision_model:
|
if self.vision_model:
|
||||||
# TODO: Remove this with newer exl2 versions
|
self.vision_model.unload()
|
||||||
# Required otherwise unload function won't finish
|
|
||||||
try:
|
|
||||||
self.vision_model.unload()
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.vision_model = None
|
self.vision_model = None
|
||||||
|
|
||||||
@@ -950,7 +947,6 @@ class ExllamaV2Container:
|
|||||||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
# TODO: Maybe support generation_config for eos_token
|
|
||||||
def get_special_tokens(
|
def get_special_tokens(
|
||||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||||
):
|
):
|
||||||
@@ -1042,13 +1038,6 @@ class ExllamaV2Container:
|
|||||||
Meant for dev wheels!
|
Meant for dev wheels!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
|
|
||||||
ExLlamaV2Sampler.Settings, "xtc_probability"
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"XTC is not supported by the currently " "installed ExLlamaV2 version."
|
|
||||||
)
|
|
||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
async def generate_gen(
|
async def generate_gen(
|
||||||
@@ -1082,6 +1071,7 @@ class ExllamaV2Container:
|
|||||||
kwargs = self.check_unsupported_settings(**kwargs)
|
kwargs = self.check_unsupported_settings(**kwargs)
|
||||||
|
|
||||||
# Apply settings
|
# Apply settings
|
||||||
|
partial(gen_settings.temperature, 1.0)
|
||||||
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
|
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
|
||||||
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
|
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
|
||||||
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
|
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
|
||||||
@@ -1191,7 +1181,6 @@ class ExllamaV2Container:
|
|||||||
if dry_multiplier > 0:
|
if dry_multiplier > 0:
|
||||||
gen_settings.dry_multiplier = dry_multiplier
|
gen_settings.dry_multiplier = dry_multiplier
|
||||||
|
|
||||||
# TODO: Maybe set the "sane" defaults instead?
|
|
||||||
gen_settings.dry_allowed_length = unwrap(
|
gen_settings.dry_allowed_length = unwrap(
|
||||||
kwargs.get("dry_allowed_length"), 0
|
kwargs.get("dry_allowed_length"), 0
|
||||||
)
|
)
|
||||||
@@ -1261,18 +1250,10 @@ class ExllamaV2Container:
|
|||||||
gen_settings.typical = 0
|
gen_settings.typical = 0
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"".join(
|
"Temperature is set to 0. Overriding temp, "
|
||||||
[
|
"top_k, top_p, and typical to 1.0, 1, 0, and 0."
|
||||||
"Temperature is set to 0. Overriding temp, ",
|
|
||||||
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the gen settings for logging purposes
|
|
||||||
# Deepcopy to save a snapshot of vars
|
|
||||||
gen_settings_log_dict = deepcopy(vars(gen_settings))
|
|
||||||
|
|
||||||
# Set banned tokens
|
# Set banned tokens
|
||||||
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
|
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
|
||||||
if banned_tokens:
|
if banned_tokens:
|
||||||
@@ -1522,26 +1503,11 @@ class ExllamaV2Container:
|
|||||||
# Some options are too large, so log the args instead
|
# Some options are too large, so log the args instead
|
||||||
log_generation_params(
|
log_generation_params(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
max_tokens=max_tokens,
|
|
||||||
min_tokens=min_tokens,
|
|
||||||
stream=kwargs.get("stream"),
|
|
||||||
**gen_settings_log_dict,
|
|
||||||
token_healing=token_healing,
|
|
||||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
|
||||||
generate_window=generate_window,
|
|
||||||
bos_token_id=self.tokenizer.bos_token_id,
|
bos_token_id=self.tokenizer.bos_token_id,
|
||||||
eos_token_id=eos_tokens,
|
eos_token_id=eos_tokens,
|
||||||
add_bos_token=add_bos_token,
|
**kwargs,
|
||||||
ban_eos_token=ban_eos_token,
|
generate_window=generate_window,
|
||||||
skip_special_tokens=not decode_special_tokens,
|
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||||
speculative_ngram=self.generator.speculative_ngram,
|
|
||||||
logprobs=request_logprobs,
|
|
||||||
stop_conditions=stop_conditions,
|
|
||||||
banned_tokens=banned_tokens,
|
|
||||||
allowed_tokens=allowed_tokens,
|
|
||||||
banned_strings=banned_strings,
|
|
||||||
logit_bias=logit_bias,
|
|
||||||
filters=grammar_handler.filters,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log the metrics if present
|
# Log the metrics if present
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ class BaseSamplerRequest(BaseModel):
|
|||||||
"rep_pen_range",
|
"rep_pen_range",
|
||||||
),
|
),
|
||||||
description=(
|
description=(
|
||||||
"Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
|
"Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Common utility functions"""
|
"""Common utility functions"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
||||||
|
|
||||||
@@ -85,3 +86,54 @@ def unwrap_optional_type(type_hint) -> Type:
|
|||||||
return arg
|
return arg
|
||||||
|
|
||||||
return type_hint
|
return type_hint
|
||||||
|
|
||||||
|
|
||||||
|
def with_defer(func):
    """
    Decorator for a go-style defer.

    The decorated function receives an extra keyword argument named
    ``defer``. Calling ``defer(fn, *args, **kwargs)`` inside the function
    registers ``fn`` to be invoked with those arguments once the function
    finishes, whether it returns normally or raises.

    Deferred calls run in reverse order of registration (last in, first
    out). Every registered call is attempted even if an earlier one raises;
    the most recently raised exception is the one that propagates to the
    caller, with the previous exceptions chained as its context.
    """

    # Local imports keep this helper self-contained within the module.
    from contextlib import ExitStack
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # ExitStack unwinds its callbacks LIFO when the with-block exits and
        # keeps unwinding when one of them raises.
        with ExitStack() as stack:

            # This 'defer' function is what you'll call inside your
            # decorated function
            def defer(fn, *fn_args, **fn_kwargs):
                stack.callback(fn, *fn_args, **fn_kwargs)

            # Inject 'defer' into the kwargs of the original function
            return func(*args, defer=defer, **kwargs)

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def with_defer_async(func):
    """
    Decorator for running async functions in go-style defer blocks.

    The decorated coroutine function receives an extra keyword argument
    named ``defer``. ``defer(fn, *args, **kwargs)`` registers a cleanup that
    runs once the coroutine finishes, whether it returns normally or raises.
    ``fn`` may be:

    - a regular callable, which is called with the given arguments,
    - a coroutine function (or any callable returning an awaitable), whose
      result is awaited,
    - an already-created coroutine object, which is awaited as-is (any
      extra arguments are ignored).

    Deferred calls run in reverse order of registration (last in, first
    out). Every registered call is attempted even if an earlier one raises;
    the most recently raised exception is the one that propagates to the
    caller, with the previous exceptions chained as its context.
    """

    # Local imports keep this helper self-contained within the module.
    from contextlib import AsyncExitStack
    from functools import wraps

    async def _run_deferred(fn, fn_args, fn_kwargs):
        """Run a single deferred entry, awaiting it when it is awaitable."""

        if inspect.iscoroutine(fn):
            await fn
            return

        result = fn(*fn_args, **fn_kwargs)

        # Covers coroutine functions as well as lambdas/partials that wrap
        # one; without this the returned coroutine would never be run.
        if inspect.isawaitable(result):
            await result

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # AsyncExitStack unwinds its callbacks LIFO when the block exits and
        # keeps unwinding when one of them raises.
        async with AsyncExitStack() as stack:

            # This 'defer' function is what you'll call inside your
            # decorated function
            def defer(fn, *fn_args, **fn_kwargs):
                stack.push_async_callback(_run_deferred, fn, fn_args, fn_kwargs)

            # Inject 'defer' into the kwargs of the original function
            return await func(*args, defer=defer, **kwargs)

    return wrapper
|
||||||
|
|||||||
Reference in New Issue
Block a user