mirror of https://github.com/theroyallab/tabbyAPI.git
Merge pull request #252 from DocShotgun/main
Switch grammar backend to Formatron
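This switch replaces both previous grammar backends (lm-format-enforcer for JSON schema and regex, outlines for EBNF) with Formatron, which expresses each constraint as a FormatterBuilder program and attaches it to generation as a single ExLlamaV2 filter. As a rough usage sketch of the reworked handler (the import path and the bare constructor are assumptions inferred from the call sites in this diff; model and tokenizer stand in for a loaded ExLlamaV2 stack):

    # Sketch only: the module path and constructor are assumed, not shown in this diff
    from common.grammar import ExLlamaV2Grammar

    grammar_handler = ExLlamaV2Grammar()

    # Constrain generation to a JSON object matching a request-supplied schema
    grammar_handler.add_json_schema_filter(
        {"type": "object", "properties": {"answer": {"type": "string"}}},
        model,
        tokenizer,
    )

    # The handler accumulates ExLlamaV2Filter objects for the generator job
    filters = grammar_handler.filters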
@@ -1,110 +1,16 @@
 import traceback
-from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import (
-    JsonSchemaParser,
-    RegexParser,
-    TokenEnforcer,
-    CharacterLevelParser,
-)
-from lmformatenforcer.integrations.exllamav2 import (
-    build_token_enforcer_tokenizer_data,
-)
-from loguru import logger
-from typing import List
+import typing
 from functools import lru_cache
+from typing import List
+
+import torch
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
+from exllamav2.generator.filters import ExLlamaV2Filter
+from formatron.extractor import NonterminalExtractor
+from formatron.formatter import FormatterBuilder
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
+from formatron.schemas import json_schema
+from loguru import logger


-class OutlinesTokenizerWrapper:
-    """Wrapper for Outlines tokenizer"""
-
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
-        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
-
-    def convert_token_to_string(self, token):
-        return token
-
-    def decode(self, tokens):
-        s = ""
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        for t in tokens:
-            s += id_to_piece[t]
-        return s
-
-
-class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
-    """Filter class for context-free grammar via outlines"""
-
-    def __init__(self, model, tokenizer, grammar):
-        from outlines.fsm.fsm import CFGFSM
-
-        super().__init__(model, tokenizer)
-
-        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
-        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
-        self.state = self.fsm.first_state
-
-    def begin(self, prefix_str=""):
-        self.state = self.fsm.first_state
-
-    def feed(self, token):
-        self.state = self.fsm.next_state(self.state, token.item())
-
-    def next(self):
-        return self.fsm.allowed_token_ids(self.state), set()
-
-    def use_background_worker(self):
-        return True
-
-
-@lru_cache(10)
-def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
-    return build_token_enforcer_tokenizer_data(tokenizer)
-
-
-class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
-    """Filter class for LMFE"""
-
-    token_sequence: List[int]
-
-    def __init__(
-        self,
-        model: ExLlamaV2,
-        tokenizer: ExLlamaV2Tokenizer,
-        character_level_parser: CharacterLevelParser,
-    ):
-        super().__init__(model, tokenizer)
-        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
-        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
-        self.token_sequence = []
-
-    def begin(self, prefix_str: str):
-        self.token_sequence = []
-
-    def feed(self, token):
-        self.token_sequence.append(int(token[0][0]))
-
-    def next(self):
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
-        if not hasattr(self, "allow_return_type_list"):
-            return set(allowed_tokens), set()
-        else:
-            return sorted(allowed_tokens), []
-
-    def use_background_worker(self):
-        return True
-
-
-def clear_grammar_func_cache():
-    """Flush tokenizer_data cache to avoid holding references to
-    tokenizers after unloading a model"""
-
-    _get_lmfe_tokenizer_data.cache_clear()
-
-
 class ExLlamaV2Grammar:
@@ -117,7 +23,7 @@ class ExLlamaV2Grammar:

     def add_json_schema_filter(
         self,
-        json_schema: dict,
+        schema: dict,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -125,7 +31,16 @@ class ExLlamaV2Grammar:

         # Create the parser
         try:
-            schema_parser = JsonSchemaParser(json_schema)
+            # Add fields required by formatron if not present
+            if "$id" not in schema:
+                schema["$id"] = "https://example.com/example.json"
+            if "$schema" not in schema:
+                schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+            # Validate schema and create formatter
+            schema = json_schema.create_schema(schema)
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(schema)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -135,14 +50,10 @@ class ExLlamaV2Grammar:

            return

-        # Allow JSON objects or JSON arrays at the top level
-        json_prefixes = ["[", "{"]
-
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
-        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
-        self.filters.extend([lmfilter, prefix_filter])
+        self.filters.append(lmfilter)

     def add_regex_filter(
         self,
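Formatron validates schemas as JSON Schema draft-07 and expects the $id and $schema metadata keys, which request-supplied schemas usually omit; the hunk above injects defaults before validation. The same preprocessing in isolation, as a sketch (the example schema is illustrative):

    from formatron.formatter import FormatterBuilder
    from formatron.schemas import json_schema

    schema = {"type": "object", "properties": {"name": {"type": "string"}}}

    # Inject the draft-07 metadata keys formatron expects, as the diff does
    schema.setdefault("$id", "https://example.com/example.json")
    schema.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

    # Validate the schema and compile it into a single formatter line
    compiled = json_schema.create_schema(schema)
    f = FormatterBuilder()
    f.append_line(f"{f.json(compiled)}")

Note that the ExLlamaV2PrefixFilter for ["[", "{"] is dropped: presumably the compiled JSON formatter already constrains the opening token, so a separate prefix filter is redundant.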
@@ -154,7 +65,9 @@ class ExLlamaV2Grammar:

         # Create the parser
         try:
-            pattern_parser = RegexParser(pattern)
+            # Validate regex and create formatter
+            f = FormatterBuilder()
+            f.append_line(f"{f.regex(pattern)}")
         except Exception:
             traceback.print_exc()
             logger.error(
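The regex path mirrors the JSON path: the pattern is compiled into a FormatterBuilder line up front, so an invalid pattern is caught and logged instead of failing mid-generation. For instance (pattern chosen purely for illustration):

    from formatron.formatter import FormatterBuilder

    # Constrain output to an ISO-8601 date such as 2024-01-31
    pattern = r"\d{4}-\d{2}-\d{2}"
    f = FormatterBuilder()
    f.append_line(f"{f.regex(pattern)}")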
@@ -164,32 +77,82 @@ class ExLlamaV2Grammar:

             return

-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)

-    def add_ebnf_filter(
+    def add_kbnf_filter(
         self,
-        ebnf_string: str,
+        kbnf_string: str,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
-        """
-        Add an EBNF grammar filter.
-        Possibly replace outlines with an in-house solution in the future.
-        """
+        """Adds an ExllamaV2 filter based on KBNF grammar."""

+        # Create the parser
         try:
-            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
-        except ImportError:
+            # Validate KBNF and create formatter
+            f = FormatterBuilder()
+            f.append_line(
+                f"""{f.extractor(lambda nonterminal:
+                CFGExtractor(nonterminal, kbnf_string))}"""
+            )
+        except Exception:
             logger.error(
-                "Skipping EBNF parsing because Outlines is not installed.\n"
-                "Please run the following command in your environment "
-                "to install extra packages:\n"
-                "pip install -U .[extras]"
+                "Skipping because the KBNF string couldn't be parsed. "
+                "Please read the above error for more information."
             )

            return

-        self.filters.append(ebnf_filter)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
+
+        # Append the filters
+        self.filters.append(lmfilter)
+
+
+class CFGExtractor(NonterminalExtractor):
+    """Extractor class for KBNF context-free grammar"""
+
+    def __init__(self, nonterminal: str, kbnf_string: str):
+        super().__init__(nonterminal)
+        self.kbnf_string = kbnf_string
+
+    # Return the entire input string as the extracted string
+    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
+        return "", input_str
+
+    @property
+    def kbnf_definition(self) -> str:
+        return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache engine vocabulary on first grammar run"""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for engine vocabulary
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush tokenizer_data cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _create_cached_engine_vocabulary.cache_clear()
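CFGExtractor.kbnf_definition renames the grammar's entry rule to the nonterminal that formatron generates, so grammar strings passed to add_kbnf_filter are expected to define a rule named start. A toy example (the grammar itself is illustrative; syntax per the kbnf package):

    # A minimal KBNF grammar: the entry rule must be named "start"
    kbnf_string = 'start ::= "yes" | "no";'

    grammar_handler.add_kbnf_filter(kbnf_string, model, tokenizer)

On the caching side, building the engine vocabulary from the tokenizer is the expensive step, so _create_cached_engine_vocabulary memoizes it with lru_cache(1); a single slot suffices because one model is loaded at a time, and clear_grammar_func_cache empties it on unload so the tokenizer can be released.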
@@ -1194,7 +1194,7 @@ class ExllamaV2Container:
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

         # Set banned strings
         banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
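Callers are unaffected by the rename: a request supplying grammar_string flows through the same kwarg, now routed to the KBNF filter. A sketch of the flow from the container's point of view (the kwargs shape is inferred from this hunk):

    # Hypothetical kwargs as the generation call receives them
    kwargs = {"grammar_string": 'start ::= "yes" | "no";'}

    grammar_string = unwrap(kwargs.get("grammar_string"))
    if grammar_string:
        grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)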
@@ -26,7 +26,8 @@ dependencies = [
     "sse-starlette",
     "packaging",
     "tokenizers",
-    "lm-format-enforcer >= 0.9.6",
+    "formatron",
+    "kbnf>=0.4.1",
     "aiofiles",
     "aiohttp",
     "async_lru",
@@ -53,7 +54,6 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines",
     "infinity-emb",
     "sentence-transformers",
 ]