From 71b9a53336b0451045240291a4e3bba4910f96d8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 21 Nov 2023 21:20:59 -0500 Subject: [PATCH] API: Add temperature_last support Documented in previous commits. Also, for version checking, check the value in kwargs rather than whether the key is present, since requests pass default values. Signed-off-by: kingbri --- OAI/types/common.py | 4 +++- model.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/OAI/types/common.py b/OAI/types/common.py index 919bf79..fcdeede 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -40,6 +40,7 @@ class CommonCompletionRequest(BaseModel): # Sampling params token_healing: Optional[bool] = False temperature: Optional[float] = 1.0 + temperature_last: Optional[bool] = False top_k: Optional[int] = 0 top_p: Optional[float] = 1.0 typical: Optional[float] = 0.0 @@ -71,6 +72,7 @@ class CommonCompletionRequest(BaseModel): "ban_eos_token": self.ban_eos_token, "token_healing": self.token_healing, "temperature": self.temperature, + "temperature_last": self.temperature_last, "top_k": self.top_k, "top_p": self.top_p, "typical": self.typical, @@ -81,5 +83,5 @@ class CommonCompletionRequest(BaseModel): "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, - "mirostat_eta": self.mirostat_eta + "mirostat_eta": self.mirostat_eta, } diff --git a/model.py b/model.py index 8d3480b..604a314 100644 --- a/model.py +++ b/model.py @@ -271,9 +271,9 @@ class ModelContainer: gen_settings = ExLlamaV2Sampler.Settings() - # Warn if unsupported settings supplied + # Warn of unsupported settings if the setting is enabled - if "mirostat" in kwargs and not hasattr(gen_settings, "mirostat"): + if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"): print(" !! 
Warning: Currently installed ExLlamaV2 does not support Mirostat sampling") if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"): @@ -282,7 +282,7 @@ class ModelContainer: if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"): print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)") - if "temperature_last" in kwargs and not hasattr(gen_settings, "temperature_last"): + if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"): print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last") #Apply settings