diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index fa3306a..d137fd8 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -2,9 +2,10 @@ import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
-from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
+from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
 from loguru import logger
 from typing import List
+from functools import lru_cache
 
 
 class OutlinesTokenizerWrapper:
@@ -59,6 +60,11 @@ class ExLlamaV2Grammar:
     def __init__(self):
         self.filters = []
 
+    @staticmethod
+    @lru_cache(10)
+    def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
+        return build_token_enforcer_tokenizer_data(tokenizer)
+
     def add_json_schema_filter(
         self,
         json_schema: dict,
@@ -82,7 +88,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, self._get_lmfe_tokenizer_data(tokenizer))
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -107,7 +113,7 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, self._get_lmfe_tokenizer_data(tokenizer))
 
         # Append the filters
         self.filters.append(lmfilter)