From 6a71890d4580645305ecbf081725f1b350f2a655 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 6 Dec 2023 17:25:13 -0500 Subject: [PATCH] Model: Fix sampler bugs Lots of bugs were unearthed when switching to the new fallback changes. Fix them and make sure samplers are being set properly. Signed-off-by: kingbri --- OAI/types/common.py | 3 +-- model.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/OAI/types/common.py b/OAI/types/common.py index 2c2660e..cbdc44e 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -43,11 +43,10 @@ class CommonCompletionRequest(BaseModel): temperature_last: Optional[bool] = False top_k: Optional[int] = 0 top_p: Optional[float] = 1.0 - typical: Optional[float] = 0.0 + typical: Optional[float] = 1.0 min_p: Optional[float] = 0.0 tfs: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0 - repetition_penalty_range: Optional[int] = 0 repetition_decay: Optional[int] = 0 mirostat_mode: Optional[int] = 0 mirostat_tau: Optional[float] = 1.5 diff --git a/model.py b/model.py index 4f72b4d..647f34b 100644 --- a/model.py +++ b/model.py @@ -284,27 +284,27 @@ class ModelContainer: # Warn of unsupported settings if the setting is enabled - if kwargs.get("mirostat") or False and not hasattr(gen_settings, "mirostat"): + if (kwargs.get("mirostat") or False) and not hasattr(gen_settings, "mirostat"): print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling") - if kwargs.get("min_p") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): + if (kwargs.get("min_p") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): print(" !! Warning: Currently installed ExLlamaV2 does not support min-P sampling") - if kwargs.get("tfs") or 0.0 not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): + if (kwargs.get("tfs") or 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)") - if kwargs.get("temperature_last") or False and not hasattr(gen_settings, "temperature_last"): + if (kwargs.get("temperature_last") or False) and not hasattr(gen_settings, "temperature_last"): print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last") #Apply settings gen_settings.temperature = kwargs.get("temperature") or 1.0 gen_settings.temperature_last = kwargs.get("temperature_last") or False - gen_settings.top_k = kwargs.get("top_k") or 1 + gen_settings.top_k = kwargs.get("top_k") or 0 gen_settings.top_p = kwargs.get("top_p") or 1.0 gen_settings.min_p = kwargs.get("min_p") or 0.0 - gen_settings.tfs = kwargs.get("tfs") or 0.0 - gen_settings.typical = kwargs.get("typical") or 0.0 + gen_settings.tfs = kwargs.get("tfs") or 1.0 + gen_settings.typical = kwargs.get("typical") or 1.0 gen_settings.mirostat = kwargs.get("mirostat") or False # Default tau and eta fallbacks don't matter if mirostat is off @@ -312,11 +312,11 @@ class ModelContainer: gen_settings.mirostat_eta = kwargs.get("mirostat_eta") or 0.1 gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") or 1.0 gen_settings.token_repetition_range = kwargs.get("repetition_range") or self.config.max_seq_len - + print(gen_settings.token_repetition_range) # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed fallback - fallback_decay = 0 if gen_settings.token_repetition_penalty <= 0 else gen_settings.token_repetition_range + fallback_decay = 0 if gen_settings.token_repetition_range <= 0 else gen_settings.token_repetition_range gen_settings.token_repetition_decay = kwargs.get("repetition_decay") or fallback_decay or 0 stop_conditions: List[Union[str, int]] = kwargs.get("stop") or []