Sampling: Cleanup and update

Cleanup how overrides are handled, class naming, and adopt exllamav2's
model class to enforce latest stable version methods rather than
adding multiple backwards compatability checks.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-02-01 12:58:55 -05:00
parent 2ea063cea9
commit b827bcbb44
4 changed files with 34 additions and 86 deletions

View File

@@ -468,56 +468,9 @@ class ExllamaV2Container:
}
def check_unsupported_settings(self, **kwargs):
# Warn of unsupported settings if the setting is enabled
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
ExLlamaV2Sampler.Settings, "mirostat"
):
logger.warning(
"Mirostat sampling is not supported by the currently "
"installed ExLlamaV2 version."
)
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
ExLlamaV2Sampler.Settings, "min_p"
):
logger.warning(
"Min-P sampling is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
ExLlamaV2Sampler.Settings, "tfs"
):
logger.warning(
"Tail-free sampling (TFS) is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
ExLlamaV2Sampler.Settings, "temperature_last"
):
logger.warning(
"Temperature last is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("top_a"), False)) and not hasattr(
ExLlamaV2Sampler.Settings, "top_a"
):
logger.warning(
"Top-A is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "token_presence_penalty"
):
logger.warning(
"Presence penalty is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
if kwargs.get("max_temp") > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "max_temp"
):
logger.warning(
@@ -597,7 +550,7 @@ class ExllamaV2Container:
# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()
# TODO: Migrate settings validation to different function
# Check unsupported settings for dev wheels
self.check_unsupported_settings(**kwargs)
# Apply settings
@@ -646,44 +599,31 @@ class ExllamaV2Container:
else:
logger.warn(
"CFG is currently disabled. "
+ "Please reload your model with use_cfg = True.",
"Please reload your model with use_cfg = True.",
)
gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
gen_settings.token_frequency_penalty = unwrap(
kwargs.get("frequency_penalty"), 0.0
)
gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
# Applies for all penalties despite being called token_repetition_range
gen_settings.token_repetition_range = unwrap(
kwargs.get("penalty_range"), self.config.max_seq_len
)
auto_scale_penalty_range = False
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
if hasattr(gen_settings, "token_frequency_penalty"):
gen_settings.token_frequency_penalty = frequency_penalty
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
elif frequency_penalty != 0.0:
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)
# Override the repetition penalty value if it isn't set already
# if the user is on an older exl2 version
if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0:
gen_settings.token_repetition_penalty = frequency_penalty
logger.warning("Setting this value to repetition penalty instead.")
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
@@ -820,7 +760,7 @@ class ExllamaV2Container:
gen_settings.token_repetition_range = generated_tokens
# Generate
chunk, eos, tokens = self.generator.stream()
chunk, eos, tokens, _, *extra_parts = self.generator.stream()
if token_healing:
# Extract healed token