Model: Fix sampler bugs

Lots of bugs were unearthed when switching to the new fallback changes.
Fix them and make sure samplers are being set properly.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-06 17:25:13 -05:00
parent 9f34af4906
commit 6a71890d45
2 changed files with 10 additions and 11 deletions

View File

@@ -43,11 +43,10 @@ class CommonCompletionRequest(BaseModel):
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
-typical: Optional[float] = 0.0
+typical: Optional[float] = 1.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
repetition_penalty: Optional[float] = 1.0
repetition_penalty_range: Optional[int] = 0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
mirostat_tau: Optional[float] = 1.5

View File

@@ -284,27 +284,27 @@ class ModelContainer:
# Warn of unsupported settings if the setting is enabled
-if kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"):
+if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
-if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
+if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
-if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
+if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
-if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"):
+if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
#Apply settings
gen_settings.temperature = kwargs.get("temperature") or 1.0
gen_settings.temperature_last = kwargs.get("temperature_last") or False
-gen_settings.top_k = kwargs.get("top_k") or 1
+gen_settings.top_k = kwargs.get("top_k") or 0
gen_settings.top_p = kwargs.get("top_p") or 1.0
gen_settings.min_p = kwargs.get("min_p") or 0.0
-gen_settings.tfs = kwargs.get("tfs") or 0.0
-gen_settings.typical = kwargs.get("typical") or 0.0
+gen_settings.tfs = kwargs.get("tfs") or 1.0
+gen_settings.typical = kwargs.get("typical") or 1.0
gen_settings.mirostat = kwargs.get("mirostat") or False
# Default tau and eta fallbacks don't matter if mirostat is off
@@ -312,11 +312,11 @@ class ModelContainer:
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
print(gen_settings.token_repetition_range)
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed fallback
-fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range
+fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []