diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 62f50a3..47c5ed5 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,12 +1,14 @@
 import traceback
 import typing
+from functools import lru_cache
 from typing import List
 
+import torch
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter
 from formatron.extractor import NonterminalExtractor
 from formatron.formatter import FormatterBuilder
-from formatron.integrations.exllamav2 import create_formatter_filter
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
 from formatron.schemas import json_schema
 from loguru import logger
 
@@ -48,7 +50,7 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -75,7 +77,7 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -104,7 +106,7 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -124,3 +126,33 @@ class CFGExtractor(NonterminalExtractor):
     @property
     def kbnf_definition(self) -> str:
         return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache the engine vocabulary on the first grammar run."""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for the engine vocabulary.
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush the cached engine vocabulary to avoid holding references to
+    tokenizers after unloading a model."""
+
+    _create_cached_engine_vocabulary.cache_clear()
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 7a5324f..50cef42 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -39,6 +39,7 @@ from common.health import HealthManager
 
 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
+    clear_grammar_func_cache,
 )
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
@@ -832,6 +833,9 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
         # Clear the image embedding cache
         clear_image_embedding_cache()
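
For context, a minimal self-contained sketch of the caching pattern this patch introduces: `functools.lru_cache(1)` memoizes an expensive per-tokenizer build so repeated grammar runs reuse it, and `cache_clear()` drops the cached entry (and the strong reference `lru_cache` keeps to its argument) at unload time. `StubTokenizer` and `build_vocabulary` below are hypothetical stand-ins for `ExLlamaV2Tokenizer` and `create_engine_vocabulary`, not part of the patch.

```python
from functools import lru_cache


class StubTokenizer:
    """Hypothetical stand-in for ExLlamaV2Tokenizer."""

    def __init__(self, vocab):
        self.vocab = vocab


@lru_cache(1)
def build_vocabulary(tokenizer: StubTokenizer) -> list:
    # Expensive per-tokenizer work runs only once; later calls with the
    # same tokenizer object hit the cache. Note that lru_cache keys on the
    # tokenizer argument, so it also holds a strong reference to it.
    print("building vocabulary...")
    return [tokenizer.vocab[i] for i in sorted(tokenizer.vocab)]


tok = StubTokenizer({0: "<s>", 1: "hello", 2: "world"})
build_vocabulary(tok)  # prints "building vocabulary..."
build_vocabulary(tok)  # cache hit, no rebuild

# On model unload, clearing the cache releases the tokenizer reference
# so it can be garbage collected -- the role clear_grammar_func_cache()
# plays in the patch above.
build_vocabulary.cache_clear()
```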