This commit is contained in:
turboderp
2024-07-08 03:49:26 +02:00
parent 4cf79c5ae1
commit 8bbce3455c

View File

@@ -2,7 +2,10 @@ import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser, RegexParser from lmformatenforcer import JsonSchemaParser, RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data from lmformatenforcer.integrations.exllamav2 import (
ExLlamaV2TokenEnforcerFilter,
build_token_enforcer_tokenizer_data,
)
from loguru import logger from loguru import logger
from typing import List from typing import List
from functools import lru_cache from functools import lru_cache
@@ -56,6 +59,7 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer): def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer) return build_token_enforcer_tokenizer_data(tokenizer)
def clear_grammar_func_cache(): def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to tokenizers after unloading a model""" """Flush tokenizer_data cache to avoid holding references to tokenizers after unloading a model"""
@@ -93,7 +97,9 @@ class ExLlamaV2Grammar:
# Allow JSON objects or JSON arrays at the top level # Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"] json_prefixes = ["[", "{"]
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, _get_lmfe_tokenizer_data(tokenizer)) lmfilter = ExLlamaV2TokenEnforcerFilter(
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes) prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
# Append the filters # Append the filters
@@ -118,7 +124,9 @@ class ExLlamaV2Grammar:
return return
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, _get_lmfe_tokenizer_data(tokenizer)) lmfilter = ExLlamaV2TokenEnforcerFilter(
pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
)
# Append the filters # Append the filters
self.filters.append(lmfilter) self.filters.append(lmfilter)
@@ -146,4 +154,4 @@ class ExLlamaV2Grammar:
return return
self.filters.append(ebnf_filter) self.filters.append(ebnf_filter)