from typing import List
import traceback

from exllamav3 import (
    Tokenizer,
    Filter,
    FormatronFilter,
)
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from loguru import logger


class ExLlamaV3Grammar:
    """ExLlamaV3 class for various grammar filters/parsers."""

    # Accumulated sampling filters; consumed by AsyncJob(filters=...)
    filters: List[Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        tokenizer: Tokenizer,
    ):
        """Adds an ExllamaV3 filter based on a JSON schema.

        On success, appends two FormatronFilters to self.filters:
        one enforcing the full JSON schema (with EOS once complete),
        and one forcing the correct leading character ("[" for arrays,
        "{" for everything else). On a schema parse failure, logs the
        error and returns without adding any filters, so generation
        proceeds unconstrained rather than crashing the request.
        """

        # Arrays must open with "["; any other schema type opens with "{"
        leading_character = "[" if schema.get("type") == "array" else "{"

        # Work on a shallow copy so the caller's dict is not mutated
        # when we inject the fields required by formatron.
        schema = dict(schema)
        schema.setdefault("$id", "https://example.com/example.json")
        schema.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

        try:
            # Validate schema and create formatter
            schema = json_schema.create_schema(schema)
        except Exception:
            traceback.print_exc()
            logger.error(
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        # Main constraint: output must match the JSON schema.
        # The f-string embedding of f.json(...) is the formatron idiom
        # for capturing the schema nonterminal in the formatter line.
        f = FormatterBuilder()
        f.append_line(f"{f.json(schema)}")
        self.filters.append(
            FormatronFilter(tokenizer, eos_after_completed=True, formatter_builder=f)
        )

        # Additional constraint to force leading character
        f = FormatterBuilder()
        f.append_line(leading_character)
        self.filters.append(FormatronFilter(tokenizer, formatter_builder=f))