Model: Fix frequency penalty fallback

The appropriate branches weren't firing when frequency penalty is
0.0. Also fix repetition penalty overriding.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-31 11:12:14 -05:00
parent 47744fe9f7
commit 72bc30343c

View File

@@ -574,26 +574,29 @@ class ModelContainer:
)
auto_scale_penalty_range = False
# Frequency penalty = repetition penalty if the user is on an older exl2 version
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
if (frequency_penalty) != 0.0 and not hasattr(
gen_settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version. Setting this value to repetition penalty "
"instead."
)
gen_settings.token_repetition_penalty = frequency_penalty
else:
if hasattr(gen_settings, "token_frequency_penalty"):
gen_settings.token_frequency_penalty = frequency_penalty
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
elif frequency_penalty != 0.0:
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)
# Override the repetition penalty value if it isn't set already
# if the user is on an older exl2 version
if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0:
gen_settings.token_repetition_penalty = frequency_penalty
logger.warning("Setting this value to repetition penalty instead.")
gen_settings.token_frequency_penalty = frequency_penalty
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed