diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py index 61c78b3..e6c7ab8 100644 --- a/backends/exllamav2/grammar.py +++ b/backends/exllamav2/grammar.py @@ -1,8 +1,10 @@ +import traceback from common.logger import init_logger from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer from exllamav2.generator import ExLlamaV2Sampler +from exllamav2.generator.filters import ExLlamaV2Filter -# Temporary, remove once the exllama version is bumped +# TODO: Remove after new exllama version is released try: from exllamav2.generator.filters import ExLlamaV2PrefixFilter @@ -10,18 +12,54 @@ try: except ImportError: _exllama_filter_available = False -try: - from lmformatenforcer import JsonSchemaParser - from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter - - _lmformatenforcer_available = True -except ImportError: - _lmformatenforcer_available = False - logger = init_logger(__name__) +class OutlinesTokenizerWrapper: + """Wrapper for Outlines tokenizer""" + + def __init__(self, tokenizer): + self.tokenizer = tokenizer + id_to_piece = self.tokenizer.get_id_to_piece_list() + self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)} + self.eos_token_id = self.tokenizer.eos_token_id + self.eos_token = id_to_piece[self.tokenizer.eos_token_id] + self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys()) + + def convert_token_to_string(self, token): + return token + + def decode(self, tokens): + s = "" + id_to_piece = self.tokenizer.get_id_to_piece_list() + for t in tokens: + s += id_to_piece[t] + return s + + +class ExLlamaV2EbnfFilter(ExLlamaV2Filter): + """Filter class for context-free grammar via outlines""" + + def __init__(self, model, tokenizer, grammar): + from outlines.fsm.fsm import CFGFSM + + super().__init__(model, tokenizer) + + self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer) + self.fsm = CFGFSM(grammar, self.wrapped_tokenizer) + self.state = self.fsm.first_state + + def begin(self, prefix_str=""): + self.state = self.fsm.first_state + + def feed(self, token): + self.state = self.fsm.next_state(self.state, token.item()) + + def next(self): + return self.fsm.allowed_token_ids(self.state), set() + + class ExLlamaV2Grammar: """ExLlamaV2 class for various grammar filters/parsers.""" @@ -34,28 +72,80 @@ class ExLlamaV2Grammar: ): """Adds an ExllamaV2 filter based on a JSON schema.""" - # Check if the required dependencies can be imported if not _exllama_filter_available: logger.warning( "ExllamaV2PrefixFilter is not available " - "in the currently installed ExllamaV2 version." + "in the currently installed ExllamaV2 version. " + "Skipping JSON schema parsing." ) return - if not _lmformatenforcer_available: + # Import optional dependencies + try: + from lmformatenforcer import JsonSchemaParser + from lmformatenforcer.integrations.exllamav2 import ( + ExLlamaV2TokenEnforcerFilter, + ) + except ImportError: logger.error( - "lmformatenforcer must be installed to parse a json schema.\n" - "Please run the following command: pip install lm-format-enforcer" + "Skipping JSON schema parsing because " + "lm-format-enforcer is not installed.\n" + "Please run the following command: " + "pip install lm-format-enforcer" ) return # Create the parser - schema_parser = JsonSchemaParser(json_schema) + try: + schema_parser = JsonSchemaParser(json_schema) + except Exception: + traceback.print_exc() + logger.error( + "Skipping because the JSON schema couldn't be parsed. " + "Please read the above error for more information." 
+ ) + + return + lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer) prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{") # Append the filters - gen_settings.filters += [lmfilter, prefix_filter] + gen_settings.filters.extend([lmfilter, prefix_filter]) + gen_settings.filter_prefer_eos = True + + def add_ebnf_filter( + self, + ebnf_string: str, + gen_settings: ExLlamaV2Sampler.Settings, + model: ExLlamaV2, + tokenizer: ExLlamaV2Tokenizer, + ): + """ + Add an EBNF grammar filter. + Possibly replace outlines with an in-house solution in the future. + """ + + if not _exllama_filter_available: + logger.warning( + "filter_prefer_eos is not available " + "in the currently installed ExllamaV2 version. " + "Skipping EBNF parsing." + ) + + return + + try: + ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string) + except ImportError: + logger.error( + "Skipping EBNF parsing because Outlines is not installed.\n" + "Please run the following command: pip install outlines" + ) + + return + + gen_settings.filters.append(ebnf_filter) gen_settings.filter_prefer_eos = True diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index e602149..b995a5f 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -761,6 +761,7 @@ class ExllamaV2Container: # Initialize grammar handler grammar_handler = ExLlamaV2Grammar() + gen_settings.filters = [] # Add JSON schema filter if it exists json_schema = unwrap(kwargs.get("json_schema")) @@ -769,6 +770,13 @@ class ExllamaV2Container: json_schema, gen_settings, self.model, self.tokenizer ) + # Add EBNF filter if it exists + grammar_string = unwrap(kwargs.get("grammar_string")) + if grammar_string: + grammar_handler.add_ebnf_filter( + grammar_string, gen_settings, self.model, self.tokenizer + ) + # Ban the EOS token if specified. If not, append to stop conditions # as well. # Set this below logging to avoid polluting the stop strings array diff --git a/common/sampling.py b/common/sampling.py index e0d20c7..d62f87a 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -122,6 +122,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("json_schema"), ) + grammar_string: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("grammar_string"), + ) + # Aliased variables typical: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("typical", 1.0), @@ -266,6 +270,7 @@ class BaseSamplerRequest(BaseModel): "cfg_scale": self.cfg_scale, "negative_prompt": self.negative_prompt, "json_schema": self.json_schema, + "grammar_string": self.grammar_string, } return {**gen_params, **kwargs}
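
A minimal usage sketch of the EBNF path this patch adds, assuming a loaded ExLlamaV2 model/tokenizer pair and that outlines is installed. The helper name attach_ebnf_filter, the ARITHMETIC_EBNF grammar text, and the backends.exllamav2.grammar import path are illustrative assumptions, not part of this diff; outlines' CFGFSM is fed a Lark-style grammar string here.

    from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
    from exllamav2.generator import ExLlamaV2Sampler

    # Assumed import path, based on the file location backends/exllamav2/grammar.py
    from backends.exllamav2.grammar import ExLlamaV2Grammar

    # Illustrative Lark-style grammar for simple arithmetic expressions
    ARITHMETIC_EBNF = r"""
    ?start: expression
    ?expression: term (("+" | "-") term)*
    ?term: factor (("*" | "/") factor)*
    ?factor: NUMBER | "(" expression ")"
    %import common.NUMBER
    """


    def attach_ebnf_filter(
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
        gen_settings: ExLlamaV2Sampler.Settings,
    ) -> ExLlamaV2Sampler.Settings:
        """Constrain generation to the grammar above via the new handler."""
        grammar_handler = ExLlamaV2Grammar()
        # Signature matches the add_ebnf_filter method introduced in this patch
        grammar_handler.add_ebnf_filter(
            ARITHMETIC_EBNF, gen_settings, model, tokenizer
        )
        return gen_settings

Over the sampler API, the same grammar text would instead be supplied as the new grammar_string request field added in common/sampling.py, which ExllamaV2Container forwards to add_ebnf_filter in the model.py hunk above.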