Grammar: Add custom ExLlamaV2TokenEnforcerFilter class

Author: turboderp
Date: 2024-09-14 21:42:53 +02:00
Parent: a2b4e3f21f
Commit: c66fe8e947

2 changed files with 46 additions and 9 deletions


@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -54,12 +58,48 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""
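
For context, exllamav2 drives a filter through a begin → next → feed cycle on each generation step. Below is a minimal sketch of that lifecycle against the new class; the schema is made up, model and tokenizer are assumed to be already loaded, and the sampler is faked by picking the first allowed id:

    from lmformatenforcer import JsonSchemaParser

    # Hypothetical schema; model/tokenizer assumed loaded via ExLlamaV2 elsewhere.
    schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
    lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, JsonSchemaParser(schema))

    lmfilter.begin("")                     # reset the per-generation token history
    allowed, end_tokens = lmfilter.next()  # ids the enforcer permits at this step
    picked = next(iter(allowed))           # stand-in for the real sampler
    lmfilter.feed([[picked]])              # feed() reads int(token[0][0]); the real
                                           # generator passes a (1, 1) tensor here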
@@ -98,9 +138,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -109,6 +147,7 @@ class ExLlamaV2Grammar:
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -125,9 +164,7 @@ class ExLlamaV2Grammar:
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
 
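
With the grammar handler populated, the filters ride along on a generation job. A sketch of that hand-off, assuming exllamav2's dynamic generator API (prompt text, generator setup, and model loading are elided):

    from exllamav2.generator import ExLlamaV2DynamicJob

    job = ExLlamaV2DynamicJob(
        input_ids=tokenizer.encode("Return a JSON object:", add_bos=True),
        max_new_tokens=256,
        filters=grammar_handler.filters,  # e.g. [lmfilter, prefix_filter]
    )
    generator.enqueue(job)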


@@ -1141,7 +1141,7 @@ class ExllamaV2Container:
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
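
One caveat the @lru_cache(10) on _get_lmfe_tokenizer_data carries over: the cache holds a reference to every tokenizer it has seen, which is what the clear_grammar_func_cache hook shown above exists to flush. A hypothetical teardown sequence; the container's unload method name is an assumption, not taken from this diff:

    container.unload()          # assumed ExllamaV2Container teardown method
    clear_grammar_func_cache()  # drop cached tokenizer_data so the old
                                # tokenizer can be garbage-collected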