diff --git a/OAI/types/common.py b/OAI/types/common.py index 919bf79..fcdeede 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -40,6 +40,7 @@ class CommonCompletionRequest(BaseModel): # Sampling params token_healing: Optional[bool] = False temperature: Optional[float] = 1.0 + temperature_last: Optional[bool] = False top_k: Optional[int] = 0 top_p: Optional[float] = 1.0 typical: Optional[float] = 0.0 @@ -71,6 +72,7 @@ class CommonCompletionRequest(BaseModel): "ban_eos_token": self.ban_eos_token, "token_healing": self.token_healing, "temperature": self.temperature, + "temperature_last": self.temperature_last, "top_k": self.top_k, "top_p": self.top_p, "typical": self.typical, @@ -81,5 +83,5 @@ class CommonCompletionRequest(BaseModel): "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, - "mirostat_eta": self.mirostat_eta + "mirostat_eta": self.mirostat_eta, } diff --git a/model.py b/model.py index 8d3480b..604a314 100644 --- a/model.py +++ b/model.py @@ -271,9 +271,9 @@ class ModelContainer: gen_settings = ExLlamaV2Sampler.Settings() - # Warn if unsupported settings supplied + # Warn of unsupported settings if the setting is enabled - if "mirostat" in kwargs and not hasattr(gen_settings, "mirostat"): + if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"): print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling") if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): @@ -282,7 +282,7 @@ class ModelContainer: if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)") - if "temperature_last" in kwargs and not hasattr(gen_settings, "temperature_last"): + if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"): print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last") #Apply settings