Files
tabbyAPI/backends/exllamav2/grammar.py
2024-11-24 11:35:45 -08:00

127 lines
3.7 KiB
Python

import traceback
import typing
from typing import List
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from formatron.extractor import NonterminalExtractor
from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.schemas import json_schema
from loguru import logger
class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""
filters: List[ExLlamaV2Filter]
def __init__(self):
self.filters = []
def add_json_schema_filter(
self,
schema: dict,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on a JSON schema."""
# Create the parser
try:
# Add fields required by formatron if not present
if "$id" not in schema:
schema["$id"] = "https://example.com/example.json"
if "$schema" not in schema:
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
# Validate schema and create formatter
schema = json_schema.create_schema(schema)
f = FormatterBuilder()
f.append_line(f"{f.json(schema)}")
except Exception:
traceback.print_exc()
logger.error(
"Skipping because the JSON schema couldn't be parsed. "
"Please read the above error for more information."
)
return
lmfilter = create_formatter_filter(model, tokenizer, f)
# Append the filters
self.filters.append(lmfilter)
def add_regex_filter(
self,
pattern: str,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on regular expressions."""
# Create the parser
try:
# Validate regex and create formatter
f = FormatterBuilder()
f.append_line(f"{f.regex(pattern)}")
except Exception:
traceback.print_exc()
logger.error(
"Skipping because the regex pattern couldn't be parsed. "
"Please read the above error for more information."
)
return
lmfilter = create_formatter_filter(model, tokenizer, f)
# Append the filters
self.filters.append(lmfilter)
def add_kbnf_filter(
self,
kbnf_string: str,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on KBNF grammar."""
# Create the parser
try:
# Validate KBNF and create formatter
f = FormatterBuilder()
f.append_line(
f"""{f.extractor(lambda nonterminal:
CFGExtractor(nonterminal, kbnf_string))}"""
)
except Exception:
logger.error(
"Skipping because the KBNF string couldn't be parsed. "
"Please read the above error for more information."
)
return
lmfilter = create_formatter_filter(model, tokenizer, f)
# Append the filters
self.filters.append(lmfilter)
class CFGExtractor(NonterminalExtractor):
"""Extractor class for KBNF context-free grammar"""
def __init__(self, nonterminal: str, kbnf_string: str):
super().__init__(nonterminal)
self.kbnf_string = kbnf_string
# Return the entire input string as the extracted string
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
return "", input_str
@property
def kbnf_definition(self) -> str:
return self.kbnf_string.replace("start", self.nonterminal)