diff --git a/model.py b/model.py index d10cfe8..35fb02c 100644 --- a/model.py +++ b/model.py @@ -574,26 +574,29 @@ class ModelContainer: ) auto_scale_penalty_range = False - # Frequency penalty = repetition penalty if the user is on an older exl2 version frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0) - if (frequency_penalty) != 0.0 and not hasattr( - gen_settings, "token_frequency_penalty" - ): - logger.warning( - "Frequency penalty is not supported by the currently " - "installed ExLlamaV2 version. Setting this value to repetition penalty " - "instead." - ) - gen_settings.token_repetition_penalty = frequency_penalty - else: + if hasattr(gen_settings, "token_frequency_penalty"): + gen_settings.token_frequency_penalty = frequency_penalty + # Dynamically scale penalty range to output tokens - # Only do this if freq/pres pen is enabled and the repetition range is -1 + # Only do this if freq/pres pen is enabled + # and the repetition range is -1 auto_scale_penalty_range = ( gen_settings.token_frequency_penalty != 0 or gen_settings.token_presence_penalty != 0 ) and gen_settings.token_repetition_range == -1 + elif frequency_penalty != 0.0: + logger.warning( + "Frequency penalty is not supported by the currently " + "installed ExLlamaV2 version." + ) + + # Override the repetition penalty value if it isn't set already + # if the user is on an older exl2 version + if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0: + gen_settings.token_repetition_penalty = frequency_penalty + logger.warning("Setting this value to repetition penalty instead.") - gen_settings.token_frequency_penalty = frequency_penalty # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed