mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-15 00:07:28 +00:00
153 lines
4.6 KiB
Python
153 lines
4.6 KiB
Python
import traceback
|
|
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
|
from exllamav2.generator import ExLlamaV2Sampler
|
|
from exllamav2.generator.filters import ExLlamaV2Filter
|
|
|
|
from common.logger import init_logger
|
|
|
|
# TODO: Remove after new exllama version is released
# Feature detection: the flag records whether the installed exllamav2 exposes
# ExLlamaV2PrefixFilter. The grammar helpers below check it before adding
# any filter.
try:
    from exllamav2.generator.filters import ExLlamaV2PrefixFilter
except ImportError:
    _exllama_filter_available = False
else:
    _exllama_filter_available = True

# Module-level logger shared by everything in this file.
logger = init_logger(__name__)
|
|
|
|
|
|
class OutlinesTokenizerWrapper:
    """Wrapper for Outlines tokenizer.

    Adapts a tokenizer object (anything providing ``get_id_to_piece_list()``,
    ``eos_token_id`` and ``extended_id_to_piece``) to the attribute/method
    surface that is handed to the Outlines grammar state machine.

    Attributes set at construction time:
        tokenizer: The wrapped tokenizer.
        vocabulary: Mapping of token piece -> token id.
        eos_token_id: EOS token id, copied from the wrapped tokenizer.
        eos_token: The piece at index ``eos_token_id``.
        special_tokens: List of the keys of ``extended_id_to_piece``.
    """

    def __init__(self, tokenizer):
        """Snapshot the vocabulary and EOS information from ``tokenizer``."""

        self.tokenizer = tokenizer
        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # Invert the id -> piece list. If the same piece occurs at several
        # ids, the highest id is the one kept.
        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

    def convert_token_to_string(self, token):
        """Return ``token`` unchanged (pieces are used as-is)."""

        return token

    def decode(self, tokens):
        """Concatenate the pieces for ``tokens`` into one string.

        Args:
            tokens: Iterable of token ids usable as indices into the
                tokenizer's id -> piece list.

        Returns:
            The joined pieces in order; ``""`` when ``tokens`` is empty.
        """

        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # str.join builds the result in a single pass instead of growing a
        # string with repeated `+=`.
        return "".join(id_to_piece[t] for t in tokens)
|
|
|
|
|
|
class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
    """Filter class for context-free grammar via outlines.

    Holds an Outlines ``CFGFSM`` built from the grammar plus the current
    state of that machine, and advances the state as tokens are fed in.
    """

    def __init__(self, model, tokenizer, grammar):
        """Build the grammar state machine for ``grammar``.

        Outlines is imported here rather than at module level, so an
        ``ImportError`` is raised from this constructor when it is missing.
        """

        from outlines.fsm.fsm import CFGFSM

        super().__init__(model, tokenizer)

        outlines_tokenizer = OutlinesTokenizerWrapper(tokenizer)
        self.wrapped_tokenizer = outlines_tokenizer
        self.fsm = CFGFSM(grammar, outlines_tokenizer)
        self.state = self.fsm.first_state

    def begin(self, prefix_str=""):
        """Reset to the machine's first state; ``prefix_str`` is not used."""

        self.state = self.fsm.first_state

    def feed(self, token):
        """Advance the state machine by one sampled token."""

        # `token` must provide `.item()` yielding the token id.
        token_id = token.item()
        self.state = self.fsm.next_state(self.state, token_id)

    def next(self):
        """Return the token ids allowed in the current state and an empty set."""

        allowed_ids = self.fsm.allowed_token_ids(self.state)
        return allowed_ids, set()
|
|
|
|
|
|
class ExLlamaV2Grammar:
    """ExLlamaV2 class for various grammar filters/parsers.

    Each ``add_*`` method appends filters to ``gen_settings.filters`` and
    enables ``gen_settings.filter_prefer_eos``. On any problem it logs a
    message and returns without modifying ``gen_settings``.
    """

    def add_json_schema_filter(
        self,
        json_schema: dict,
        gen_settings: ExLlamaV2Sampler.Settings,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """Adds an ExllamaV2 filter based on a JSON schema.

        Appends an lm-format-enforcer token filter for ``json_schema`` and a
        prefix filter for ``"{"``. Skips (with a log message) when the prefix
        filter is unavailable, lm-format-enforcer is not installed, or the
        schema cannot be parsed.
        """

        # The prefix filter only exists in newer ExllamaV2 builds.
        if not _exllama_filter_available:
            unavailable_msg = (
                "ExllamaV2PrefixFilter is not available "
                "in the currently installed ExllamaV2 version. "
                "Skipping JSON schema parsing."
            )
            logger.warning(unavailable_msg)
            return

        # Import optional dependencies
        try:
            from lmformatenforcer import JsonSchemaParser
            from lmformatenforcer.integrations.exllamav2 import (
                ExLlamaV2TokenEnforcerFilter,
            )
        except ImportError:
            missing_dep_msg = (
                "Skipping JSON schema parsing because "
                "lm-format-enforcer is not installed.\n"
                "Please run the following command: "
                "pip install lm-format-enforcer"
            )
            logger.error(missing_dep_msg)
            return

        # Create the parser
        try:
            schema_parser = JsonSchemaParser(json_schema)
        except Exception:
            # Print the full traceback first; the log line below refers to it.
            traceback.print_exc()
            bad_schema_msg = (
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            logger.error(bad_schema_msg)
            return

        schema_filter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
        brace_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")

        # Append the filters
        gen_settings.filters.extend((schema_filter, brace_filter))
        gen_settings.filter_prefer_eos = True

    def add_ebnf_filter(
        self,
        ebnf_string: str,
        gen_settings: ExLlamaV2Sampler.Settings,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """
        Add an EBNF grammar filter.
        Possibly replace outlines with an in-house solution in the future.

        Skips (with a log message) when the ExllamaV2 feature flag is off or
        when Outlines is not installed. Only ``ImportError`` is handled while
        building the filter; other exceptions propagate to the caller.
        """

        if not _exllama_filter_available:
            unavailable_msg = (
                "filter_prefer_eos is not available "
                "in the currently installed ExllamaV2 version. "
                "Skipping EBNF parsing."
            )
            logger.warning(unavailable_msg)
            return

        # The filter constructor imports Outlines lazily, so a missing
        # install surfaces here as ImportError.
        try:
            grammar_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
        except ImportError:
            missing_dep_msg = (
                "Skipping EBNF parsing because Outlines is not installed.\n"
                "Please run the following command: pip install outlines"
            )
            logger.error(missing_dep_msg)
            return

        gen_settings.filters.append(grammar_filter)
        gen_settings.filter_prefer_eos = True
|