From 63650d2c3c46e44f6690e0689ebf3aba331313be Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 5 Aug 2024 11:08:58 -0400 Subject: [PATCH] Model: Disable banned strings if grammar is used ExllamaV2 filters don't allow for rewinding which is what banned strings uses. Therefore, constrained generation via LMFE or outlines is not compatible for now. Signed-off-by: kingbri --- backends/exllamav2/model.py | 51 ++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3ded71b..98b5636 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1018,8 +1018,37 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) - stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + # Initialize grammar handler + grammar_handler = ExLlamaV2Grammar() + + # Add JSON schema filter if it exists + json_schema = unwrap(kwargs.get("json_schema")) + if json_schema: + grammar_handler.add_json_schema_filter( + json_schema, self.model, self.tokenizer + ) + + # Add regex filter if it exists + regex_pattern = unwrap(kwargs.get("regex_pattern")) + if regex_pattern: + grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) + + # Add EBNF filter if it exists + grammar_string = unwrap(kwargs.get("grammar_string")) + if grammar_string: + grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) + + # Set banned strings banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), []) + if banned_strings and len(grammar_handler.filters) > 0: + logger.warning( + "Disabling banned_strings because " + "they cannot be used with grammar filters." + ) + + banned_strings = [] + + stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) logit_bias = kwargs.get("logit_bias") @@ -1067,26 +1096,6 @@ class ExllamaV2Container: "in the model's vocab. Skipping." ) - # Initialize grammar handler - grammar_handler = ExLlamaV2Grammar() - - # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) - if json_schema: - grammar_handler.add_json_schema_filter( - json_schema, self.model, self.tokenizer - ) - - # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) - if regex_pattern: - grammar_handler.add_regex_filter(regex_pattern, self.tokenizer) - - # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) - if grammar_string: - grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer) - # Fetch EOS tokens from generation_config if they exist eos_tokens = ( self.generation_config.eos_tokens()