Model: Add fallback for freq and presence pen

Previous behavior aliased freq pen for rep pen. Keep this behavior
when using the freq pen parameter with a legacy exllamav2 version
rather than ignoring both entirely.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-30 00:24:15 -05:00
parent 79a57588d5
commit 0dc12d82d5

View File

@@ -474,14 +474,6 @@ class ModelContainer:
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "token_presence_penalty"
):
@@ -568,9 +560,7 @@ class ModelContainer:
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
gen_settings.token_frequency_penalty = unwrap(
kwargs.get("frequency_penalty"), 0.0
)
gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
@@ -582,13 +572,28 @@ class ModelContainer:
gen_settings.token_repetition_range = unwrap(
kwargs.get("penalty_range"), self.config.max_seq_len
)
auto_scale_penalty_range = False
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
# Frequency penalty = repetition penalty if the user is on an older exl2 version
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
if (frequency_penalty) != 0.0 and not hasattr(
gen_settings, "token_frequency_penalty"
):
logger.warning(
"Frequency penalty is not supported by the currently "
"installed ExLlamaV2 version. Setting this value to repetition penalty "
"instead."
)
gen_settings.token_repetition_penalty = frequency_penalty
else:
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
gen_settings.token_frequency_penalty = frequency_penalty
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed