diff --git a/model.py b/model.py index 1c26239..0ddddb5 100644 --- a/model.py +++ b/model.py @@ -430,6 +430,64 @@ class ModelContainer: "unk_token": self.tokenizer.unk_token, } + def check_unsupported_settings(self, **kwargs): + # Warn of unsupported settings if the setting is enabled + if (unwrap(kwargs.get("mirostat"), False)) and not hasattr( + ExLlamaV2Sampler.Settings, "mirostat" + ): + logger.warning( + "Mirostat sampling is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr( + ExLlamaV2Sampler.Settings, "min_p" + ): + logger.warning( + "Min-P sampling is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr( + ExLlamaV2Sampler.Settings, "tfs" + ): + logger.warning( + "Tail-free sampling (TFS) is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr( + ExLlamaV2Sampler.Settings, "temperature_last" + ): + logger.warning( + "Temperature last is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("top_a"), False)) and not hasattr( + ExLlamaV2Sampler.Settings, "top_a" + ): + logger.warning( + "Top-A is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr( + ExLlamaV2Sampler.Settings, "token_frequency_penalty" + ): + logger.warning( + "Frequency penalty is not supported by the currently " + "installed ExLlamaV2 version." + ) + + if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr( + ExLlamaV2Sampler.Settings, "token_presence_penalty" + ): + logger.warning( + "Presence penalty is not supported by the currently " + "installed ExLlamaV2 version." + ) + def generate(self, prompt: str, **kwargs): """Generate a response to a prompt""" generation = list(self.generate_gen(prompt, **kwargs)) @@ -492,60 +550,7 @@ class ModelContainer: # Sampler settings gen_settings = ExLlamaV2Sampler.Settings() - # Warn of unsupported settings if the setting is enabled - if (unwrap(kwargs.get("mirostat"), False)) and not hasattr( - gen_settings, "mirostat" - ): - logger.warning( - "Mirostat sampling is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr( - gen_settings, "min_p" - ): - logger.warning( - "Min-P sampling is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr( - gen_settings, "tfs" - ): - logger.warning( - "Tail-free sampling (TFS) is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr( - gen_settings, "temperature_last" - ): - logger.warning( - "Temperature last is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("top_a"), False)) and not hasattr(gen_settings, "top_a"): - logger.warning( - "Top-A is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("frequency_penalty"), 0.0)) != 0.0 and not hasattr( - gen_settings, "token_frequency_penalty" - ): - logger.warning( - "Frequency penalty is not supported by the currently " - "installed ExLlamaV2 version." - ) - - if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr( - gen_settings, "token_presence_penalty" - ): - logger.warning( - "Presence penalty is not supported by the currently " - "installed ExLlamaV2 version." - ) + self.check_unsupported_settings(**kwargs) # Apply settings gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)