mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
Model: Add filter support to dynamic gen
The dynamic generator accepts filters differently: set the filter list on the grammar class itself rather than inside the generation function. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import traceback
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.generator import ExLlamaV2Sampler
|
||||
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
||||
from lmformatenforcer import JsonSchemaParser, RegexParser
|
||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||
from loguru import logger
|
||||
from typing import List
|
||||
|
||||
|
||||
class OutlinesTokenizerWrapper:
|
||||
@@ -54,10 +54,14 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
|
||||
class ExLlamaV2Grammar:
|
||||
"""ExLlamaV2 class for various grammar filters/parsers."""
|
||||
|
||||
filters: List[ExLlamaV2Filter]
|
||||
|
||||
def __init__(self) -> None:
    """Initialize the grammar handler with an empty list of active filters.

    Filters are appended later by the add_*_filter methods and consumed
    by the generation call via ``grammar_handler.filters``.
    """
    # Start with no constraints; add_json_schema_filter / add_regex_filter /
    # add_ebnf_filter extend this list as requested by the caller.
    self.filters = []
|
||||
|
||||
def add_json_schema_filter(
|
||||
self,
|
||||
json_schema: dict,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
@@ -79,13 +83,11 @@ class ExLlamaV2Grammar:
|
||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters.extend([lmfilter, prefix_filter])
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.extend([lmfilter, prefix_filter])
|
||||
|
||||
def add_regex_filter(
|
||||
self,
|
||||
pattern: str,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
"""Adds an ExllamaV2 filter based on regular expressions."""
|
||||
@@ -105,13 +107,11 @@ class ExLlamaV2Grammar:
|
||||
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters.extend([lmfilter])
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.extend([lmfilter])
|
||||
|
||||
def add_ebnf_filter(
|
||||
self,
|
||||
ebnf_string: str,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
@@ -132,5 +132,4 @@ class ExLlamaV2Grammar:
|
||||
|
||||
return
|
||||
|
||||
gen_settings.filters.append(ebnf_filter)
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.append(ebnf_filter)
|
||||
|
||||
@@ -856,28 +856,23 @@ class ExllamaV2Container:
|
||||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
gen_settings.filters = []
|
||||
|
||||
# Add JSON schema filter if it exists
|
||||
json_schema = unwrap(kwargs.get("json_schema"))
|
||||
if json_schema:
|
||||
grammar_handler.add_json_schema_filter(
|
||||
json_schema, gen_settings, self.model, self.tokenizer
|
||||
json_schema, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Add regex filter if it exists
|
||||
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
||||
if regex_pattern:
|
||||
grammar_handler.add_regex_filter(
|
||||
regex_pattern, gen_settings, self.tokenizer
|
||||
)
|
||||
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
|
||||
|
||||
# Add EBNF filter if it exists
|
||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||
if grammar_string:
|
||||
grammar_handler.add_ebnf_filter(
|
||||
grammar_string, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
@@ -971,6 +966,7 @@ class ExllamaV2Container:
|
||||
banned_tokens=banned_tokens,
|
||||
banned_strings=banned_strings,
|
||||
logit_bias=logit_bias,
|
||||
filters=grammar_handler.filters,
|
||||
)
|
||||
|
||||
# Log prompt to console
|
||||
@@ -994,6 +990,8 @@ class ExllamaV2Container:
|
||||
gen_settings=gen_settings,
|
||||
stop_conditions=stop_conditions,
|
||||
decode_special_tokens=decode_special_tokens,
|
||||
filters=grammar_handler.filters,
|
||||
filter_prefer_eos=bool(grammar_handler.filters),
|
||||
return_probs=request_logprobs > 0,
|
||||
return_top_tokens=request_logprobs,
|
||||
return_logits=request_logprobs > 0,
|
||||
|
||||
Reference in New Issue
Block a user