mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Model: Fix sampler bugs
Lots of bugs were unearthed when switching to the new fallback logic. Fix them and make sure samplers are being set properly. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -43,11 +43,10 @@ class CommonCompletionRequest(BaseModel):
|
||||
temperature_last: Optional[bool] = False
|
||||
top_k: Optional[int] = 0
|
||||
top_p: Optional[float] = 1.0
|
||||
typical: Optional[float] = 0.0
|
||||
typical: Optional[float] = 1.0
|
||||
min_p: Optional[float] = 0.0
|
||||
tfs: Optional[float] = 1.0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
repetition_penalty_range: Optional[int] = 0
|
||||
repetition_decay: Optional[int] = 0
|
||||
mirostat_mode: Optional[int] = 0
|
||||
mirostat_tau: Optional[float] = 1.5
|
||||
|
||||
18
model.py
18
model.py
@@ -284,27 +284,27 @@ class ModelContainer:
|
||||
|
||||
# Warn of unsupported settings if the setting is enabled
|
||||
|
||||
if kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"):
|
||||
if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
|
||||
|
||||
if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling")
|
||||
|
||||
if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
|
||||
|
||||
if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"):
|
||||
if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"):
|
||||
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
|
||||
|
||||
#Apply settings
|
||||
|
||||
gen_settings.temperature = kwargs.get("temperature") or 1.0
|
||||
gen_settings.temperature_last = kwargs.get("temperature_last") or False
|
||||
gen_settings.top_k = kwargs.get("top_k") or 1
|
||||
gen_settings.top_k = kwargs.get("top_k") or 0
|
||||
gen_settings.top_p = kwargs.get("top_p") or 1.0
|
||||
gen_settings.min_p = kwargs.get("min_p") or 0.0
|
||||
gen_settings.tfs = kwargs.get("tfs") or 0.0
|
||||
gen_settings.typical = kwargs.get("typical") or 0.0
|
||||
gen_settings.tfs = kwargs.get("tfs") or 1.0
|
||||
gen_settings.typical = kwargs.get("typical") or 1.0
|
||||
gen_settings.mirostat = kwargs.get("mirostat") or False
|
||||
|
||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||
@@ -312,11 +312,11 @@ class ModelContainer:
|
||||
gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1
|
||||
gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0
|
||||
gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len
|
||||
|
||||
print(gen_settings.token_repetition_range)
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed fallback
|
||||
fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range
|
||||
fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0
|
||||
|
||||
stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []
|
||||
|
||||
Reference in New Issue
Block a user