API: Add temperature_last support

Documented in previous commits. Also, when checking for version support,
check the value in kwargs rather than whether the key is present, since
requests pass default values.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-11-21 21:20:59 -05:00
parent 3337fe6acc
commit 71b9a53336
2 changed files with 6 additions and 4 deletions

View File

@@ -40,6 +40,7 @@ class CommonCompletionRequest(BaseModel):
# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
typical: Optional[float] = 0.0
@@ -71,6 +72,7 @@ class CommonCompletionRequest(BaseModel):
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"typical": self.typical,
@@ -81,5 +83,5 @@ class CommonCompletionRequest(BaseModel):
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta
"mirostat_eta": self.mirostat_eta,
}

View File

@@ -271,9 +271,9 @@ class ModelContainer:
gen_settings = ExLlamaV2Sampler.Settings()
# Warn if unsupported settings supplied
# Warn of unsupported settings if the setting is enabled
if "mirostat" in kwargs and not hasattr(gen_settings, "mirostat"):
if kwargs.get("mirostat", False) and not hasattr(gen_settings, "mirostat"):
print(" !! Warning: Currently installed ExLlamaV2 does not support Mirostat sampling")
if kwargs.get("min_p", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "min_p"):
@@ -282,7 +282,7 @@ class ModelContainer:
if kwargs.get("tfs", 0.0) not in [0.0, 1.0] and not hasattr(gen_settings, "tfs"):
print(" !! Warning: Currently installed ExLlamaV2 does not support tail-free sampling (TFS)")
if "temperature_last" in kwargs and not hasattr(gen_settings, "temperature_last"):
if kwargs.get("temperature_last", False) and not hasattr(gen_settings, "temperature_last"):
print(" !! Warning: Currently installed ExLlamaV2 does not support temperature_last")
#Apply settings