Merge pull request #397 from beep39/json-schema-for-exllamav3

Constrained generation with JSON schema for ExllamaV3
This commit is contained in:
Brian
2025-11-24 22:34:31 -05:00
committed by GitHub
2 changed files with 63 additions and 0 deletions

View File

@@ -0,0 +1,57 @@
from typing import List
import traceback
from exllamav3 import (
Tokenizer,
Filter,
FormatronFilter,
)
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from loguru import logger
class ExLlamaV3Grammar:
    """ExLlamaV3 class for various grammar filters/parsers."""

    # Accumulated ExllamaV3 sampling filters, passed to AsyncJob(filters=...).
    filters: List[Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        tokenizer: Tokenizer,
    ):
        """Adds an ExllamaV3 filter based on a JSON schema.

        Appends two filters to self.filters on success:
        1. A Formatron filter constraining output to the schema
           (EOS forced once the JSON document is complete).
        2. A filter forcing the correct leading character so generation
           starts with "[" for arrays or "{" for objects.

        On schema-parse failure, logs the error and leaves self.filters
        unchanged (best-effort: generation proceeds unconstrained).
        """

        # Top-level arrays must open with "[", everything else with "{"
        leading_character = "[" if schema.get("type") == "array" else "{"

        try:
            # Shallow-copy so the caller's dict (e.g. params.json_schema)
            # is not mutated by the fields injected below
            schema = dict(schema)

            # Add fields required by formatron if not present
            schema.setdefault("$id", "https://example.com/example.json")
            schema.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

            # Validate schema and create formatter
            schema = json_schema.create_schema(schema)
        except Exception:
            traceback.print_exc()
            logger.error(
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        f = FormatterBuilder()
        f.append_line(f"{f.json(schema)}")
        self.filters.append(
            FormatronFilter(tokenizer, eos_after_completed=True, formatter_builder=f)
        )

        # Additional constraint to force leading character
        f = FormatterBuilder()
        f.append_line(leading_character)
        self.filters.append(FormatronFilter(tokenizer, formatter_builder=f))

View File

@@ -21,6 +21,7 @@ from exllamav3 import (
Tokenizer,
)
from exllamav3.cache import CacheLayer_quant
from backends.exllamav3.grammar import ExLlamaV3Grammar
from loguru import logger
from backends.base_model_container import BaseModelContainer
@@ -933,6 +934,7 @@ class ExllamaV3Container(BaseModelContainer):
prompts = [prompt]
stop_conditions = params.stop
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
grammar_handler = ExLlamaV3Grammar()
# Get multimodal embeddings if present
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
@@ -974,6 +976,9 @@ class ExllamaV3Container(BaseModelContainer):
request_id,
)
if params.json_schema:
grammar_handler.add_json_schema_filter(params.json_schema, self.tokenizer)
generation = {}
job = AsyncJob(
self.generator,
@@ -985,6 +990,7 @@ class ExllamaV3Container(BaseModelContainer):
embeddings=mm_embeddings_content,
return_top_tokens=params.logprobs,
max_rq_tokens=self.max_rq_tokens,
filters=grammar_handler.filters,
)
generated_tokens = 0