Cache creation of tokenizer_data in LMFE

This commit is contained in:
turboderp
2024-07-08 00:51:59 +02:00
parent bb8b02a60a
commit 4d0bb1ffc3

View File

@@ -2,9 +2,10 @@ import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser, RegexParser from lmformatenforcer import JsonSchemaParser, RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
from loguru import logger from loguru import logger
from typing import List from typing import List
from functools import lru_cache
class OutlinesTokenizerWrapper: class OutlinesTokenizerWrapper:
@@ -59,6 +60,10 @@ class ExLlamaV2Grammar:
def __init__(self): def __init__(self):
self.filters = [] self.filters = []
@staticmethod
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    """Build and memoize LMFE tokenizer data for the given tokenizer.

    ``build_token_enforcer_tokenizer_data`` walks the whole vocabulary,
    so the result is cached (up to 10 tokenizers) and reused by every
    filter created for the same tokenizer instance.

    Declared as a ``staticmethod`` so the ``lru_cache`` key is the
    tokenizer alone: caching an instance method would also key on --
    and keep alive -- ``self`` (ruff B019), leaking one grammar object
    per cache entry for no benefit, since the body never reads ``self``.
    """
    return build_token_enforcer_tokenizer_data(tokenizer)
def add_json_schema_filter( def add_json_schema_filter(
self, self,
json_schema: dict, json_schema: dict,
@@ -82,7 +87,7 @@ class ExLlamaV2Grammar:
# Allow JSON objects or JSON arrays at the top level # Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"] json_prefixes = ["[", "{"]
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer) lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, self._get_lmfe_tokenizer_data(tokenizer))
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes) prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
# Append the filters # Append the filters
@@ -107,7 +112,7 @@ class ExLlamaV2Grammar:
return return
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer) lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, self._get_lmfe_tokenizer_data(tokenizer))
# Append the filters # Append the filters
self.filters.append(lmfilter) self.filters.append(lmfilter)
@@ -135,4 +140,4 @@ class ExLlamaV2Grammar:
return return
self.filters.append(ebnf_filter) self.filters.append(ebnf_filter)