mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Merge pull request #397 from beep39/json-schema-for-exllamav3
Constrained generation with json schema for ExllamaV3
This commit is contained in:
57
backends/exllamav3/grammar.py
Normal file
57
backends/exllamav3/grammar.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import List
|
||||
import traceback
|
||||
|
||||
from exllamav3 import (
|
||||
Tokenizer,
|
||||
Filter,
|
||||
FormatronFilter,
|
||||
)
|
||||
from formatron.formatter import FormatterBuilder
|
||||
from formatron.schemas import json_schema
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ExLlamaV3Grammar:
    """ExLlamaV3 class for various grammar filters/parsers."""

    # Sampling filters accumulated by the add_* methods; passed to an
    # ExllamaV3 generation job as its `filters` argument.
    filters: List[Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        tokenizer: Tokenizer,
    ):
        """Adds an ExllamaV3 filter based on a JSON schema.

        Appends two FormatronFilters to self.filters: one that constrains
        output to the given schema, and one that forces the correct leading
        character. If the schema cannot be parsed/validated by formatron,
        logs an error and leaves self.filters unchanged.

        Args:
            schema: JSON schema as a plain dict. Not mutated by this call.
            tokenizer: ExllamaV3 tokenizer used to build the filters.
        """

        # Top-level arrays must open with "[", all other schemas with "{"
        leading_character = "[" if schema.get("type") == "array" else "{"

        try:
            # Shallow-copy so the caller's dict is never mutated when we
            # inject the fields formatron requires.
            schema = {**schema}

            # Add fields required by formatron if not present
            schema.setdefault("$id", "https://example.com/example.json")
            schema.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

            # Validate schema and create formatter
            schema = json_schema.create_schema(schema)
        except Exception:
            # Best-effort: an unparseable schema skips constrained generation
            # rather than aborting the request.
            traceback.print_exc()
            logger.error(
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        f = FormatterBuilder()
        f.append_line(f"{f.json(schema)}")
        self.filters.append(
            FormatronFilter(tokenizer, eos_after_completed=True, formatter_builder=f)
        )

        # Additional constraint to force leading character
        f = FormatterBuilder()
        f.append_line(leading_character)
        self.filters.append(FormatronFilter(tokenizer, formatter_builder=f))
|
||||
@@ -21,6 +21,7 @@ from exllamav3 import (
|
||||
Tokenizer,
|
||||
)
|
||||
from exllamav3.cache import CacheLayer_quant
|
||||
from backends.exllamav3.grammar import ExLlamaV3Grammar
|
||||
from loguru import logger
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
@@ -933,6 +934,7 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
prompts = [prompt]
|
||||
stop_conditions = params.stop
|
||||
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
|
||||
grammar_handler = ExLlamaV3Grammar()
|
||||
|
||||
# Get multimodal embeddings if present
|
||||
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
|
||||
@@ -974,6 +976,9 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
request_id,
|
||||
)
|
||||
|
||||
if params.json_schema:
|
||||
grammar_handler.add_json_schema_filter(params.json_schema, self.tokenizer)
|
||||
|
||||
generation = {}
|
||||
job = AsyncJob(
|
||||
self.generator,
|
||||
@@ -985,6 +990,7 @@ class ExllamaV3Container(BaseModelContainer):
|
||||
embeddings=mm_embeddings_content,
|
||||
return_top_tokens=params.logprobs,
|
||||
max_rq_tokens=self.max_rq_tokens,
|
||||
filters=grammar_handler.filters,
|
||||
)
|
||||
|
||||
generated_tokens = 0
|
||||
|
||||
Reference in New Issue
Block a user