Samplers: Add dynamic temperature

Does not work if max_temp is less than or equal to min_temp. Sampler
validation will have to be refactored in the future, so the dynamic
temperature check will also be changed.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-01-31 01:20:59 -05:00
parent 3605067898
commit 4a7b8b1b7a
3 changed files with 50 additions and 0 deletions

View File

@@ -515,6 +515,14 @@ class ExllamaV2Container:
"installed ExLlamaV2 version."
)
if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "max_temp"
):
logger.warning(
"DynaTemp parameters are not supported by the currently "
"installed ExLlamaV2 version."
)
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
@@ -579,6 +587,7 @@ class ExllamaV2Container:
# Sampler settings
gen_settings = ExLlamaV2Sampler.Settings()
# TODO: Migrate settings validation to different function
self.check_unsupported_settings(**kwargs)
# Apply settings
@@ -592,6 +601,22 @@ class ExllamaV2Container:
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
# DynaTemp settings
if hasattr(gen_settings, "max_temp"):
max_temp = unwrap(kwargs.get("max_temp"), 0.0)
min_temp = unwrap(kwargs.get("min_temp"), 0.0)
if max_temp < min_temp or (
0 not in {min_temp, max_temp} and max_temp == min_temp
):
logger.warning(
"Max temp is less than or equal to min temp, skipping DynaTemp."
)
gen_settings.max_temp = max_temp
gen_settings.min_temp = min_temp
gen_settings.temp_exponent = kwargs.get("temp_exponent")
# Default tau and eta fallbacks don't matter if mirostat is off
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)