import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, Filter, FormatronFilter
from formatron.schemas.pydantic import ClassSchema
from formatron.formatter import FormatterBuilder
from pydantic import conlist
from typing import Literal, Optional, Any
from formatron.extractor import NonterminalExtractor


def get_superhero_filter(tokenizer) -> list[Filter]:

    class SuperheroAppearance(ClassSchema):
        title: str
        issue_number: int
        year: int

    class Superhero(ClassSchema):
        name: str
        secret_identity: str
        gender: Literal["male", "female"]
        superpowers: conlist(str, max_length = 5)
        first_appearance: SuperheroAppearance
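
    # For reference, the schema above constrains output to JSON of roughly this
    # shape (field values are illustrative only, not part of the schema):
    #
    #     {
    #         "name": "Superman",
    #         "secret_identity": "Clark Kent",
    #         "gender": "male",
    #         "superpowers": ["flight", "heat vision"],
    #         "first_appearance": {"title": "Action Comics", "issue_number": 1, "year": 1938}
    #     }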

    # Create JSON formatter and ExLlama filter
    f = FormatterBuilder()
    f.append_line(f"{f.json(Superhero, capture_name = 'json')}")
    filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
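    # (eos_after_completed should force the EOS token once the formatter has
    # matched a complete document, i.e. the closing brace of the JSON object)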

    # Additional constraint to force leading {
    f = FormatterBuilder()
    f.append_line("{")
    filters += [FormatronFilter(tokenizer, formatter_builder = f)]
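    # Filters in a list apply jointly: a sampled token has to satisfy every
    # active filter, so this one narrows the JSON filter's first step rather
    # than replacing it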

    # Test triggered filter, triggers on "Bruce" (ID 79579 in Llama3.1 vocab)
    f = FormatterBuilder()
    f.append_str(" Thomas")
    filters += [FormatronFilter(tokenizer, trigger_token = 79579, formatter_builder = f)]
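    # (A triggered filter presumably stays dormant until trigger_token is
    # sampled; once the model emits "Bruce" it is forced to continue with
    # " Thomas", derailing Batman's secret identity)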

    return filters


def get_arithmetic_filter(tokenizer) -> list[Filter]:

    class ArithmeticExpressionExtractor(NonterminalExtractor):

        def __init__(self, nonterminal: str, capture_name: Optional[str] = None):
            super().__init__(nonterminal, capture_name)
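
        # A NonterminalExtractor supplies two things to Formatron: a KBNF
        # grammar fragment (kbnf_definition below) and an extract() method that
        # recovers the captured text from the raw output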

        def extract(self, input_str: str) -> Optional[tuple[str, Any]]:
            # Scan the longest prefix consisting of digits, operators and
            # balanced parentheses
            i = 0
            left_bracket = 0
            while i < len(input_str):
                if input_str[i].isdigit() or input_str[i] in "+-*/.":
                    i += 1
                    continue
                if input_str[i] == "(":
                    i += 1
                    left_bracket += 1
                    continue
                if input_str[i] == ")":
                    i += 1
                    left_bracket -= 1
                    continue
                else:
                    break
            # Reject unbalanced parentheses
            if left_bracket != 0:
                return None
            # Return (remaining input, captured expression)
            return input_str[i:], input_str[:i]

        @property
        def kbnf_definition(self) -> str:
            return (
                """expression ::= term { ("+" | "-") term };"""
                """term ::= factor { ("*" | "/") factor };"""
                """factor ::= number | "(" expression ")";"""
                """number ::= #"[0-9]+(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?";"""
            ).replace("expression", self.nonterminal)
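
    # The grammar above accepts plain infix arithmetic, e.g. "2+2",
    # "100*60*60*24*365" or "(1.5e3+2)/7" (illustrative examples)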

    # Create arithmetic formatter and ExLlama filter
    f = FormatterBuilder()
    extractor1 = f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'ex1'))
    extractor2 = f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'ex2'))
    f.append_line(f"{extractor1} = {extractor2}")
    filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
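    # The template constrains output to "<expression> = <expression>", with the
    # two sides captured as 'ex1' and 'ex2'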

    return filters


def stream_gen(generator, tokenizer, prompt, filters):

    # Create job
    job = Job(
        input_ids = tokenizer.encode(prompt, add_bos = True),
        filters = filters,
        max_new_tokens = 400,
    )
    generator.enqueue(job)

    print("----------------------")
    print(prompt, end = "")

    # Stream output as it is produced, until the job completes
    while generator.num_remaining_jobs():
        results = generator.iterate()
        for result in results:
            text = result.get("text", "")
            print(text, end = "", flush = True)
    print()


def main():

    # Load model etc.
    model_dir = "/mnt/str/models/llama3.1-8b-instruct/exl3/4.0bpw/"
    config = Config.from_directory(model_dir)
    model = Model.from_config(config)
    cache = Cache(model, max_num_tokens = 8192)
    model.load()
    tokenizer = Tokenizer.from_config(config)
    generator = Generator(model, cache, tokenizer)

    # Single gens can reuse filters
    sh_filter = get_superhero_filter(tokenizer)
    stream_gen(generator, tokenizer, "Here is vital information about Superman, in JSON format:\n\n", sh_filter)
    stream_gen(generator, tokenizer, "Here is vital information about Batman, in JSON format:\n\n", sh_filter)
    ar_filter = get_arithmetic_filter(tokenizer)
    stream_gen(generator, tokenizer, "Number of seconds in a century: 100*", ar_filter)  # (Llama3.1 can't math)
    stream_gen(generator, tokenizer, "Two plus two: 2+", ar_filter)

    # But filters are stateful, so we need multiple instances for batched gen
    sh_filter2 = get_superhero_filter(tokenizer)
    ar_filter2 = get_arithmetic_filter(tokenizer)
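    # (each filter instance tracks its own parser state as tokens are sampled,
    # so sharing one instance between two in-flight jobs would corrupt it)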
    batched_gens = generator.generate(
        prompt = [
            "Here is vital information about Superman, in JSON format:\n\n",
            "Here is vital information about Batman, in JSON format:\n\n",
            "Number of seconds in a century: 100*",
            "Two plus two: 2+",
        ],
        filters = [
            sh_filter,
            sh_filter2,
            ar_filter,
            ar_filter2,
        ],
        max_new_tokens = 400,
        add_bos = True,
    )

    for g in batched_gens:
        print("----------------------")
        print(g)


if __name__ == "__main__":
    main()