Merge pull request #146 from theroyallab/tokenizer_data_fix

Tokenizer data fix
This commit is contained in:
Brian Dashore
2024-07-08 15:08:29 -04:00
committed by GitHub
2 changed files with 31 additions and 4 deletions

View File

@@ -2,9 +2,13 @@ import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser, RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer.integrations.exllamav2 import (
ExLlamaV2TokenEnforcerFilter,
build_token_enforcer_tokenizer_data,
)
from loguru import logger
from typing import List
from functools import lru_cache
class OutlinesTokenizerWrapper:
@@ -51,6 +55,18 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
return self.fsm.allowed_token_ids(self.state), set()
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    """Build and memoize lm-format-enforcer tokenizer data for *tokenizer*.

    Preparing the token-enforcer data is expensive, so the result is kept
    in a small LRU cache and reused for repeated grammar setups against
    the same tokenizer instance.
    """

    tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)
    return tokenizer_data
def clear_grammar_func_cache():
    """Flush the cached lm-format-enforcer tokenizer data.

    Called when a model is unloaded so the cache does not keep references
    to the old tokenizer (and the data built from it) alive.
    """

    # BUG FIX: functools.lru_cache exposes `cache_clear()`, not
    # `clear_cache()` — the original call raised AttributeError at
    # unload time instead of emptying the cache.
    _get_lmfe_tokenizer_data.cache_clear()
class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""
@@ -82,7 +98,9 @@ class ExLlamaV2Grammar:
# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
lmfilter = ExLlamaV2TokenEnforcerFilter(
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
# Append the filters
@@ -107,7 +125,9 @@ class ExLlamaV2Grammar:
return
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
lmfilter = ExLlamaV2TokenEnforcerFilter(
pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
)
# Append the filters
self.filters.append(lmfilter)

View File

@@ -26,7 +26,10 @@ from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
from backends.exllamav2.grammar import ExLlamaV2Grammar
from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
)
from backends.exllamav2.utils import (
exllama_disabled_flash_attn,
hardware_supports_flash_attn,
@@ -704,6 +707,10 @@ class ExllamaV2Container:
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
# Delete references held in the grammar module
clear_grammar_func_cache()
# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
lora.unload()