Grammar: Add custom ExLlamaV2TokenEnforcerFilter class

commit c66fe8e947
parent a2b4e3f21f
Author: turboderp
Date:   2024-09-14 21:42:53 +02:00

2 changed files with 46 additions and 9 deletions

File 1 of 2:

@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -54,12 +58,48 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""
@@ -98,9 +138,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -109,6 +147,7 @@ class ExLlamaV2Grammar:
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -125,9 +164,7 @@ class ExLlamaV2Grammar:
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
 
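For context, a minimal sketch of how the new in-repo filter would be driven, assuming the begin/feed/next protocol shown in the class above; the RegexParser pattern and the model/tokenizer variables are illustrative, not part of the commit:

# Hypothetical usage sketch, not code from this commit.
from lmformatenforcer import RegexParser

# model and tokenizer are an already-loaded ExLlamaV2 / ExLlamaV2Tokenizer pair
parser = RegexParser(r"\d{4}-\d{2}-\d{2}")  # constrain output to a date-like string
lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, parser)

lmfilter.begin("")                      # reset the token history for a new generation
allowed, end_tokens = lmfilter.next()   # token IDs that are legal at this step
# after the sampler picks a token, the generator calls:
# lmfilter.feed(token)                  # token is a tensor; feed() reads token[0][0]

Since the filter now subclasses ExLlamaV2Filter directly and overrides use_background_worker(), the allowed-token sets can be computed on exllamav2's background worker thread instead of blocking the sampling loop.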

File 2 of 2:

@@ -1141,7 +1141,7 @@ class ExllamaV2Container:
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)

         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
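A sketch of the updated call site from a consumer's perspective; ExLlamaV2Grammar and the new add_regex_filter signature come from the first file, while the no-argument constructor and the hand-off to sampler settings are assumptions:

# Hypothetical usage sketch, not code from this commit.
grammar_handler = ExLlamaV2Grammar()  # assumed no-arg constructor
grammar_handler.add_regex_filter(r"[A-Z]{3}-\d{4}", model, tokenizer)

# the accumulated filters would then be handed to generation, e.g.:
# gen_settings.filters = grammar_handler.filters  # assumed hand-off point

The signature change is the reason for this hunk: the filter's __init__ now calls super().__init__(model, tokenizer), so every caller must pass the model alongside the tokenizer.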