mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-27 17:51:36 +00:00
Grammar: Add custom ExLlamaV2TokenEnforcerFilter class
This commit is contained in:
@@ -1,9 +1,13 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||||
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
||||||
from lmformatenforcer import JsonSchemaParser, RegexParser
|
from lmformatenforcer import (
|
||||||
|
JsonSchemaParser,
|
||||||
|
RegexParser,
|
||||||
|
TokenEnforcer,
|
||||||
|
CharacterLevelParser,
|
||||||
|
)
|
||||||
from lmformatenforcer.integrations.exllamav2 import (
|
from lmformatenforcer.integrations.exllamav2 import (
|
||||||
ExLlamaV2TokenEnforcerFilter,
|
|
||||||
build_token_enforcer_tokenizer_data,
|
build_token_enforcer_tokenizer_data,
|
||||||
)
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -54,12 +58,48 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
|
|||||||
def next(self):
|
def next(self):
|
||||||
return self.fsm.allowed_token_ids(self.state), set()
|
return self.fsm.allowed_token_ids(self.state), set()
|
||||||
|
|
||||||
|
def use_background_worker(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(10)
|
@lru_cache(10)
|
||||||
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
|
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
|
||||||
return build_token_enforcer_tokenizer_data(tokenizer)
|
return build_token_enforcer_tokenizer_data(tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
|
||||||
|
"""Filter class for LMFE"""
|
||||||
|
|
||||||
|
token_sequence: List[int]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ExLlamaV2,
|
||||||
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
|
character_level_parser: CharacterLevelParser,
|
||||||
|
):
|
||||||
|
super().__init__(model, tokenizer)
|
||||||
|
tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
|
||||||
|
self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
|
||||||
|
self.token_sequence = []
|
||||||
|
|
||||||
|
def begin(self, prefix_str: str):
|
||||||
|
self.token_sequence = []
|
||||||
|
|
||||||
|
def feed(self, token):
|
||||||
|
self.token_sequence.append(int(token[0][0]))
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
|
||||||
|
if not hasattr(self, "allow_return_type_list"):
|
||||||
|
return set(allowed_tokens), set()
|
||||||
|
else:
|
||||||
|
return sorted(allowed_tokens), []
|
||||||
|
|
||||||
|
def use_background_worker(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def clear_grammar_func_cache():
|
def clear_grammar_func_cache():
|
||||||
"""Flush tokenizer_data cache to avoid holding references to
|
"""Flush tokenizer_data cache to avoid holding references to
|
||||||
tokenizers after unloading a model"""
|
tokenizers after unloading a model"""
|
||||||
@@ -98,9 +138,7 @@ class ExLlamaV2Grammar:
|
|||||||
# Allow JSON objects or JSON arrays at the top level
|
# Allow JSON objects or JSON arrays at the top level
|
||||||
json_prefixes = ["[", "{"]
|
json_prefixes = ["[", "{"]
|
||||||
|
|
||||||
lmfilter = ExLlamaV2TokenEnforcerFilter(
|
lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
|
||||||
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
|
|
||||||
)
|
|
||||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
|
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
@@ -109,6 +147,7 @@ class ExLlamaV2Grammar:
|
|||||||
def add_regex_filter(
|
def add_regex_filter(
|
||||||
self,
|
self,
|
||||||
pattern: str,
|
pattern: str,
|
||||||
|
model: ExLlamaV2,
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
"""Adds an ExllamaV2 filter based on regular expressions."""
|
"""Adds an ExllamaV2 filter based on regular expressions."""
|
||||||
@@ -125,9 +164,7 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
lmfilter = ExLlamaV2TokenEnforcerFilter(
|
lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
|
||||||
pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
self.filters.append(lmfilter)
|
self.filters.append(lmfilter)
|
||||||
|
|||||||
@@ -1141,7 +1141,7 @@ class ExllamaV2Container:
|
|||||||
# Add regex filter if it exists
|
# Add regex filter if it exists
|
||||||
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
||||||
if regex_pattern:
|
if regex_pattern:
|
||||||
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
|
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
|
||||||
|
|
||||||
# Add EBNF filter if it exists
|
# Add EBNF filter if it exists
|
||||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||||
|
|||||||
Reference in New Issue
Block a user