Mirror of https://github.com/theroyallab/tabbyAPI.git (synced 2026-04-20 06:19:15 +00:00)
Merge pull request #146 from theroyallab/tokenizer_data_fix
Tokenizer data fix
@@ -2,9 +2,13 @@ import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
-from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
+from lmformatenforcer.integrations.exllamav2 import (
+    ExLlamaV2TokenEnforcerFilter,
+    build_token_enforcer_tokenizer_data,
+)
 from loguru import logger
 from typing import List
+from functools import lru_cache
 
 
 class OutlinesTokenizerWrapper:
@@ -51,6 +55,18 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
         return self.fsm.allowed_token_ids(self.state), set()
 
 
+@lru_cache(10)
+def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
+    return build_token_enforcer_tokenizer_data(tokenizer)
+
+
+def clear_grammar_func_cache():
+    """Flush the tokenizer_data cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _get_lmfe_tokenizer_data.cache_clear()
+
+
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
 
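The two helpers added above are plain functools.lru_cache usage: build the expensive lmformatenforcer tokenizer data once per tokenizer, reuse it for every grammar filter, and wipe the cache when the model is unloaded so the tokenizer can be freed. A minimal, self-contained sketch of the same pattern with a stand-in tokenizer class and builder (the names below are illustrative, not part of the PR):

from functools import lru_cache


class StubTokenizer:
    """Stand-in for ExLlamaV2Tokenizer, used only for illustration."""


@lru_cache(10)
def get_tokenizer_data(tokenizer: StubTokenizer) -> dict:
    # Pretend this is expensive: compute it once per tokenizer instance
    # and reuse it for every filter that needs it.
    print("building tokenizer data")
    return {"tokenizer": tokenizer}


def clear_tokenizer_data_cache() -> None:
    # Drop cached entries so the tokenizer can be garbage collected
    # after the model that owns it is unloaded.
    get_tokenizer_data.cache_clear()


if __name__ == "__main__":
    tok = StubTokenizer()
    get_tokenizer_data(tok)  # prints "building tokenizer data"
    get_tokenizer_data(tok)  # cache hit, nothing printed
    clear_tokenizer_data_cache()
    get_tokenizer_data(tok)  # rebuilt after the cache was cleared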
@@ -82,7 +98,9 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -107,7 +125,9 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
 
         # Append the filters
         self.filters.append(lmfilter)
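Both call sites change in the same way: instead of handing the raw tokenizer to lmformatenforcer's filter, they pass the (now cached) tokenizer data. A hedged sketch of how a caller ends up wiring this together, assuming the lmformatenforcer/exllamav2 versions this PR targets and an already-loaded tokenizer object; the schema is a made-up example:

from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import (
    ExLlamaV2TokenEnforcerFilter,
    build_token_enforcer_tokenizer_data,
)

# `tokenizer` is assumed to be an ExLlamaV2Tokenizer loaded elsewhere.
schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
schema_parser = JsonSchemaParser(schema)

# In the PR this call is wrapped by the lru_cache helper shown above.
tokenizer_data = build_token_enforcer_tokenizer_data(tokenizer)

lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer_data)
# The filter is then appended to the generator's filter list, as in the diff.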
@@ -26,7 +26,10 @@ from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
 
-from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.grammar import (
+    ExLlamaV2Grammar,
+    clear_grammar_func_cache,
+)
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
     hardware_supports_flash_attn,
@@ -704,6 +707,10 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
+
+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
                 lora.unload()
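For context on why the unload hook is needed: a module-level lru_cache keeps strong references to its arguments, so a cached tokenizer would stay alive after the model is unloaded until its cache entry is evicted. A small, self-contained demonstration of that behavior with a stand-in class (illustrative names, not from the PR):

import gc
import weakref
from functools import lru_cache


class StubTokenizer:
    """Stand-in for the real tokenizer object."""


@lru_cache(10)
def get_tokenizer_data(tokenizer: StubTokenizer) -> str:
    return "tokenizer data"


tok = StubTokenizer()
ref = weakref.ref(tok)
get_tokenizer_data(tok)

del tok
gc.collect()
print(ref() is not None)  # True: the cache still pins the tokenizer

get_tokenizer_data.cache_clear()  # what clear_grammar_func_cache() does on unload
gc.collect()
print(ref() is None)  # True: the tokenizer can now be collected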