Grammar: Preliminary Formatron KBNF support

This commit is contained in:
DocShotgun
2024-11-23 12:05:41 -08:00
parent 0836a9317f
commit a9f39bcff3

View File

@@ -1,4 +1,5 @@
import traceback import traceback
import typing
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger from loguru import logger
@@ -7,6 +8,7 @@ from typing import List
from formatron.formatter import FormatterBuilder from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.extractor import NonterminalExtractor
def clear_grammar_func_cache(): def clear_grammar_func_cache():
@@ -98,7 +100,11 @@ class ExLlamaV2Grammar:
try: try:
# Validate KBNF and create formatter # Validate KBNF and create formatter
f = FormatterBuilder() f = FormatterBuilder()
# TODO: Implement this f.append_line(
f"{f.extractor(
lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
)}"
)
except Exception: except Exception:
logger.error( logger.error(
"Skipping because the KBNF string couldn't be parsed. " "Skipping because the KBNF string couldn't be parsed. "
@@ -111,3 +117,19 @@ class ExLlamaV2Grammar:
# Append the filters # Append the filters
self.filters.append(lmfilter) self.filters.append(lmfilter)
class CustomExtractor(NonterminalExtractor):
def __init__(self, nonterminal: str, kbnf_string: str):
super().__init__(nonterminal)
self.kbnf_string = kbnf_string
# Fails without an extract function defined
# No idea what it does or why it's needed, but this seems to work
# TODO: Figure out how to do this properly
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
return input_str[len(input_str) :], input_str[: len(input_str)]
@property
def kbnf_definition(self) -> str:
return self.kbnf_string.replace("start", self.nonterminal)