mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-04-25 16:59:09 +00:00
Model: Add filter support to dynamic gen
Dynamic gen takes in filters differently. Adjust to set the filter list per class rather than in the generation function. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||||
from exllamav2.generator import ExLlamaV2Sampler
|
|
||||||
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
||||||
from lmformatenforcer import JsonSchemaParser, RegexParser
|
from lmformatenforcer import JsonSchemaParser, RegexParser
|
||||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class OutlinesTokenizerWrapper:
|
class OutlinesTokenizerWrapper:
|
||||||
@@ -54,10 +54,14 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
|
|||||||
class ExLlamaV2Grammar:
|
class ExLlamaV2Grammar:
|
||||||
"""ExLlamaV2 class for various grammar filters/parsers."""
|
"""ExLlamaV2 class for various grammar filters/parsers."""
|
||||||
|
|
||||||
|
filters: List[ExLlamaV2Filter]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.filters = []
|
||||||
|
|
||||||
def add_json_schema_filter(
|
def add_json_schema_filter(
|
||||||
self,
|
self,
|
||||||
json_schema: dict,
|
json_schema: dict,
|
||||||
gen_settings: ExLlamaV2Sampler.Settings,
|
|
||||||
model: ExLlamaV2,
|
model: ExLlamaV2,
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
@@ -79,13 +83,11 @@ class ExLlamaV2Grammar:
|
|||||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
gen_settings.filters.extend([lmfilter, prefix_filter])
|
self.filters.extend([lmfilter, prefix_filter])
|
||||||
gen_settings.filter_prefer_eos = True
|
|
||||||
|
|
||||||
def add_regex_filter(
|
def add_regex_filter(
|
||||||
self,
|
self,
|
||||||
pattern: str,
|
pattern: str,
|
||||||
gen_settings: ExLlamaV2Sampler.Settings,
|
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
"""Adds an ExllamaV2 filter based on regular expressions."""
|
"""Adds an ExllamaV2 filter based on regular expressions."""
|
||||||
@@ -105,13 +107,11 @@ class ExLlamaV2Grammar:
|
|||||||
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
|
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
|
||||||
|
|
||||||
# Append the filters
|
# Append the filters
|
||||||
gen_settings.filters.extend([lmfilter])
|
self.filters.extend([lmfilter])
|
||||||
gen_settings.filter_prefer_eos = True
|
|
||||||
|
|
||||||
def add_ebnf_filter(
|
def add_ebnf_filter(
|
||||||
self,
|
self,
|
||||||
ebnf_string: str,
|
ebnf_string: str,
|
||||||
gen_settings: ExLlamaV2Sampler.Settings,
|
|
||||||
model: ExLlamaV2,
|
model: ExLlamaV2,
|
||||||
tokenizer: ExLlamaV2Tokenizer,
|
tokenizer: ExLlamaV2Tokenizer,
|
||||||
):
|
):
|
||||||
@@ -132,5 +132,4 @@ class ExLlamaV2Grammar:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
gen_settings.filters.append(ebnf_filter)
|
self.filters.append(ebnf_filter)
|
||||||
gen_settings.filter_prefer_eos = True
|
|
||||||
|
|||||||
@@ -856,28 +856,23 @@ class ExllamaV2Container:
|
|||||||
|
|
||||||
# Initialize grammar handler
|
# Initialize grammar handler
|
||||||
grammar_handler = ExLlamaV2Grammar()
|
grammar_handler = ExLlamaV2Grammar()
|
||||||
gen_settings.filters = []
|
|
||||||
|
|
||||||
# Add JSON schema filter if it exists
|
# Add JSON schema filter if it exists
|
||||||
json_schema = unwrap(kwargs.get("json_schema"))
|
json_schema = unwrap(kwargs.get("json_schema"))
|
||||||
if json_schema:
|
if json_schema:
|
||||||
grammar_handler.add_json_schema_filter(
|
grammar_handler.add_json_schema_filter(
|
||||||
json_schema, gen_settings, self.model, self.tokenizer
|
json_schema, self.model, self.tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add regex filter if it exists
|
# Add regex filter if it exists
|
||||||
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
||||||
if regex_pattern:
|
if regex_pattern:
|
||||||
grammar_handler.add_regex_filter(
|
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
|
||||||
regex_pattern, gen_settings, self.tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add EBNF filter if it exists
|
# Add EBNF filter if it exists
|
||||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||||
if grammar_string:
|
if grammar_string:
|
||||||
grammar_handler.add_ebnf_filter(
|
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
|
||||||
grammar_string, gen_settings, self.model, self.tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fetch EOS tokens from generation_config if they exist
|
# Fetch EOS tokens from generation_config if they exist
|
||||||
eos_tokens = (
|
eos_tokens = (
|
||||||
@@ -971,6 +966,7 @@ class ExllamaV2Container:
|
|||||||
banned_tokens=banned_tokens,
|
banned_tokens=banned_tokens,
|
||||||
banned_strings=banned_strings,
|
banned_strings=banned_strings,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
|
filters=grammar_handler.filters,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Log prompt to console
|
# Log prompt to console
|
||||||
@@ -994,6 +990,8 @@ class ExllamaV2Container:
|
|||||||
gen_settings=gen_settings,
|
gen_settings=gen_settings,
|
||||||
stop_conditions=stop_conditions,
|
stop_conditions=stop_conditions,
|
||||||
decode_special_tokens=decode_special_tokens,
|
decode_special_tokens=decode_special_tokens,
|
||||||
|
filters=grammar_handler.filters,
|
||||||
|
filter_prefer_eos=bool(grammar_handler.filters),
|
||||||
return_probs=request_logprobs > 0,
|
return_probs=request_logprobs > 0,
|
||||||
return_top_tokens=request_logprobs,
|
return_top_tokens=request_logprobs,
|
||||||
return_logits=request_logprobs > 0,
|
return_logits=request_logprobs > 0,
|
||||||
|
|||||||
Reference in New Issue
Block a user