mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Model: Move unsupported sampler check
The inline unsupported-sampler checks overbloated the generation function, so they now live in their own method. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
113
model.py
113
model.py
@@ -430,6 +430,64 @@ class ModelContainer:
|
||||
"unk_token": self.tokenizer.unk_token,
|
||||
}
|
||||
|
||||
def check_unsupported_settings(self, **kwargs):
    """Log a warning for each enabled sampler option that is unavailable.

    Every request option listed below is read from ``kwargs`` (falling back
    to its "disabled" default through ``unwrap``). When the option is
    switched on but ``ExLlamaV2Sampler.Settings`` has no attribute of the
    matching name, a warning is logged. Nothing is returned and ``kwargs``
    is not modified.
    """

    sampler_settings = ExLlamaV2Sampler.Settings

    def is_truthy(value):
        # Flag-style options: any truthy value counts as enabled
        return bool(value)

    def is_not_neutral(value):
        # Options for which both 0.0 and 1.0 are treated as "off"
        return value not in [0.0, 1.0]

    def is_nonzero(value):
        # Penalty-style options: only exactly 0.0 is "off"
        return value != 0.0

    # (request key, disabled default, enabled test, Settings attribute, warning)
    # Order matters only for the order in which warnings are logged.
    checks = (
        (
            "mirostat",
            False,
            is_truthy,
            "mirostat",
            "Mirostat sampling is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "min_p",
            0.0,
            is_not_neutral,
            "min_p",
            "Min-P sampling is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "tfs",
            0.0,
            is_not_neutral,
            "tfs",
            "Tail-free sampling (TFS) is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "temperature_last",
            False,
            is_truthy,
            "temperature_last",
            "Temperature last is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "top_a",
            False,
            is_truthy,
            "top_a",
            "Top-A is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "frequency_penalty",
            0.0,
            is_nonzero,
            "token_frequency_penalty",
            "Frequency penalty is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
        (
            "presence_penalty",
            0.0,
            is_nonzero,
            "token_presence_penalty",
            "Presence penalty is not supported by the currently "
            "installed ExLlamaV2 version.",
        ),
    )

    for request_key, fallback, is_enabled, attr_name, message in checks:
        requested = unwrap(kwargs.get(request_key), fallback)

        # Only look at the sampler class once the option is known to be on
        if is_enabled(requested) and not hasattr(sampler_settings, attr_name):
            logger.warning(message)
|
||||
|
||||
def generate(self, prompt: str, **kwargs):
|
||||
"""Generate a response to a prompt"""
|
||||
generation = list(self.generate_gen(prompt, **kwargs))
|
||||
@@ -492,60 +550,7 @@ class ModelContainer:
|
||||
# Sampler settings
|
||||
gen_settings = ExLlamaV2Sampler.Settings()
|
||||
|
||||
# Warn of unsupported settings if the setting is enabled
|
||||
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
|
||||
gen_settings, "mirostat"
|
||||
):
|
||||
logger.warning(
|
||||
"Mirostat sampling is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
|
||||
gen_settings, "min_p"
|
||||
):
|
||||
logger.warning(
|
||||
"Min-P sampling is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
|
||||
gen_settings, "tfs"
|
||||
):
|
||||
logger.warning(
|
||||
"Tail-free sampling (TFS) is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
|
||||
gen_settings, "temperature_last"
|
||||
):
|
||||
logger.warning(
|
||||
"Temperature last is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("top_a"), False)) and not hasattr(gen_settings, "top_a"):
|
||||
logger.warning(
|
||||
"Top-A is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr(
|
||||
gen_settings, "token_frequency_penalty"
|
||||
):
|
||||
logger.warning(
|
||||
"Frequency penalty is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
|
||||
gen_settings, "token_presence_penalty"
|
||||
):
|
||||
logger.warning(
|
||||
"Presence penalty is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
self.check_unsupported_settings(**kwargs)
|
||||
|
||||
# Apply settings
|
||||
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
|
||||
|
||||
Reference in New Issue
Block a user