Files
exllamav3/examples/constrained_generation.py
2025-11-16 14:25:46 +01:00

153 lines
5.3 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav3 import Config, Model, Cache, Tokenizer, Generator, Job, Filter, FormatronFilter
from formatron.schemas.pydantic import ClassSchema
from formatron.formatter import FormatterBuilder
from pydantic import conlist
from typing import Literal, Optional
from formatron.extractor import NonterminalExtractor
def get_superhero_filter(tokenizer) -> list[Filter]:
class SuperheroAppearance(ClassSchema):
title: str
issue_number: int
year: int
class Superhero(ClassSchema):
name: str
secret_identity: str
gender: Literal["male", "female"]
superpowers: conlist(str, max_length = 5)
first_appearance: SuperheroAppearance
# Create JSON formatter and ExLlama filter
f = FormatterBuilder()
f.append_line(f"{f.json(Superhero, capture_name = 'json')}")
filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
# Additional constraint to force leading {
f = FormatterBuilder()
f.append_line("{")
filters += [FormatronFilter(tokenizer, formatter_builder = f)]
# Test triggered filter, triggers on "Bruce" (ID 79579 in Llama3.1 vocab)
f = FormatterBuilder()
f.append_str(" Thomas")
filters += [FormatronFilter(tokenizer, trigger_token = 79579, formatter_builder = f)]
return filters
def get_arithmetic_filter(tokenizer) -> list[Filter]:
class ArithmeticExpressionExtractor(NonterminalExtractor):
def __init__(self, nonterminal: str, capture_name: Optional[str] = None):
super().__init__(nonterminal, capture_name)
def extract(self, input_str: str) -> Optional[tuple[str, any]]:
i = 0
left_bracket = 0
while i < len(input_str):
if input_str[i].isdigit() or input_str[i] in "+-*/.":
i += 1
continue
if input_str[i] == "(":
i += 1
left_bracket += 1
continue
if input_str[i] == ")":
i += 1
left_bracket -= 1
continue
else:
break
if left_bracket != 0:
return None
return input_str[i:], input_str[:i]
@property
def kbnf_definition(self) -> str:
return (
"""expression ::= term { ("+" | "-") term };"""
"""term ::= factor { ("*" | "/") factor };"""
"""factor ::= number | "(" expression ")";"""
"""number ::= #"[0-9]+(\\\\.[0-9]+)?([eE][+-]?[0-9]+)?";"""
).replace("expression", self.nonterminal)
# Create arithmetic formatter and ExLlama filter
f = FormatterBuilder()
extractor1 = f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'ex1'))
extractor2 = f.extractor(lambda nonterminal: ArithmeticExpressionExtractor(nonterminal, 'ex2'))
f.append_line(f"{extractor1} = {extractor2}")
filters = [FormatronFilter(tokenizer, eos_after_completed = True, formatter_builder = f)]
return filters
def stream_gen(generator, tokenizer, prompt, filters):
# Create job
job = Job(
input_ids = tokenizer.encode(prompt, add_bos = True),
filters = filters,
max_new_tokens = 400,
)
generator.enqueue(job)
print("----------------------")
print(prompt, end = "")
while generator.num_remaining_jobs():
results = generator.iterate()
for result in results:
text = result.get("text", "")
print(text, end = "", flush = True)
print()
def main():
# Load model etc.
model_dir = "/mnt/str/models/llama3.1-8b-instruct/exl3/4.0bpw/"
config = Config.from_directory(model_dir)
model = Model.from_config(config)
cache = Cache(model, max_num_tokens = 8192)
model.load()
tokenizer = Tokenizer.from_config(config)
generator = Generator(model, cache, tokenizer)
# Single gens can reuse filters
sh_filter = get_superhero_filter(tokenizer)
stream_gen(generator, tokenizer, "Here is vital information about Superman, in JSON format:\n\n", sh_filter)
stream_gen(generator, tokenizer, "Here is vital information about Batman, in JSON format:\n\n", sh_filter)
ar_filter = get_arithmetic_filter(tokenizer)
stream_gen(generator, tokenizer, "Number of seconds in a century: 100*", ar_filter) # (Llama3.1 can't math)
stream_gen(generator, tokenizer, "Two plus two: 2+", ar_filter)
# But filters are stateful, so we need multiple instances for batched gen
sh_filter2 = get_superhero_filter(tokenizer)
ar_filter2 = get_arithmetic_filter(tokenizer)
batched_gens = generator.generate(
prompt = [
"Here is vital information about Superman, in JSON format:\n\n",
"Here is vital information about Batman, in JSON format:\n\n",
"Number of seconds in a century: 100*",
"Two plus two: 2+",
],
filters = [
sh_filter,
sh_filter2,
ar_filter,
ar_filter2,
],
max_new_tokens = 400,
add_bos = True,
)
for g in batched_gens:
print("----------------------")
print(g)
if __name__ == "__main__":
main()