Files
exllamav2/examples/inference_formatron.py

251 lines
7.2 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator
from formatron.schemas.pydantic import ClassSchema
from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.formatter import FormatterBuilder
from pydantic import conlist
from typing import Literal, Optional
import json
import argparse
def create_superhero_json_formatter():
class SuperheroAppearance(ClassSchema):
title: str
issue_number: int
year: int
class Superhero(ClassSchema):
name: str
secret_identity: str
gender: Literal["male", "female"]
superpowers: conlist(str, max_length = 5)
first_appearance: SuperheroAppearance
f = FormatterBuilder()
f.append_line(f"{f.json(Superhero, capture_name='json')}")
return f
def generate_json(model, tokenizer, generator):
# Prepare prompts, alternating between raw and filtered
i_prompts = [
"Here is some information about Superman:\n\n",
"Here is some information about Batman:\n\n",
"Here is some information about Aquaman:\n\n",
]
formatter = create_superhero_json_formatter()
prompts = []
filters = []
for p in i_prompts:
prompts.append(p)
filters.append(None)
prompts.append(p)
filters.append([
create_formatter_filter(model, tokenizer, formatter),
])
# Generate
print("Generating JSON examples...")
outputs = generator.generate(
prompt = prompts,
filters = filters,
max_new_tokens = 300,
add_bos = True,
stop_conditions = [tokenizer.eos_token_id],
completion_only = True
)
# Print outputs:
for i in range(len(i_prompts)):
print("---------------------------------------------------------------------------------")
print(i_prompts[i].strip())
print()
print("Without filter:")
print("---------------")
print(outputs[i * 2])
print()
print("With filter:")
print("------------")
print(json.dumps(json.loads(outputs[i * 2 + 1]), indent=4).strip())
print()
def create_benchmark_json_formatter():
occupations = [
"Accountant",
"Actor",
"Acupuncturist",
"Acute Care Nurse",
"Adapted Physical Education Specialist",
"Adhesive Bonding Machine Operator",
"Administrative Law Judge",
"Administrative Services Manager",
"Advanced Practice Psychiatric Nurse",
"Advertising Sales Agent",
"Aerospace Engineer",
"Cat"
]
class Address(ClassSchema):
street: str
city: str
state: Optional[str] = None
postal_code: str
country: str
class Person(ClassSchema):
full_name: str
initials: str
occupation: Literal[*occupations]
previous_occupation: Literal[*occupations]
work_address: Address
home_address: Address
cat_owner: bool
dog_owner: bool
nationality: str
f = FormatterBuilder()
f.append_line("\n\nParsed information, JSON format:\n")
f.append_line(f"{f.json(Person, capture_name = 'json')}")
return f
def test_overhead(model, tokenizer, generator, bsz):
prompts = [
"""Emily Jane Parker is currently working as an Administrative Services Manager at Greenfield Solutions. Before """
"""this, she was an Advertising Sales Agent. Her office is located at 789 Corporate Drive, Los Angeles, CA 90210, """
"""USA. She resides at 321 Oak Street, Pasadena, CA 91103, USA. Emily owns several cats who often keep her company. """
"""She is a proud citizen of the United States."""
] * bsz
# Warm up at bsz to make sure kernels are tuned and graphs are built
print(f"bsz {bsz}, warmup...")
outputs = generator.generate(
prompt = prompts,
min_new_tokens = 499,
max_new_tokens = 500,
add_bos = True,
stop_conditions = [tokenizer.eos_token_id],
completion_only = True,
)
# Generate at bsz
print(f"bsz {bsz}, constrained generation...")
formatter = create_benchmark_json_formatter()
filters = [[create_formatter_filter(model, tokenizer, formatter)] for _ in range(bsz)]
outputs, res = generator.generate(
prompt = prompts,
filters = filters,
max_new_tokens = 500,
add_bos = True,
stop_conditions = [tokenizer.eos_token_id],
completion_only = True,
return_last_results = True
)
avg_tokens = sum(r["new_tokens"] for r in res) / bsz
avg_time = sum(r["time_generate"] for r in res) / bsz
tps_filtered = avg_tokens / avg_time
print(f"bsz {bsz}, constrained generation: {tps_filtered:.2f} t/s")
# Decode JSON and print to the console to verify that filter is working
j = outputs[0]
j = j[j.find('{'):]
j = json.dumps(json.loads(j), indent = 4).strip()
print(j)
# Generate at bsz for refernece, unfiltered
print(f"bsz {bsz}, unconstrained generation...")
outputs, res = generator.generate(
prompt = prompts,
max_new_tokens = int(avg_tokens),
add_bos = True,
stop_conditions = [tokenizer.eos_token_id],
completion_only = True,
return_last_results = True
)
avg_tokens = sum(r["new_tokens"] for r in res) / bsz
avg_time = sum(r["time_generate"] for r in res) / bsz
tps_unfiltered = avg_tokens / avg_time
print(f"bsz {bsz}, unconstrained generation: {tps_unfiltered:.2f} t/s")
return tps_filtered, tps_unfiltered
def main(benchmark: bool):
# Load model etc.
model_dir = "/mnt/str/models/llama3-8b-exl2/6.0bpw"
config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)
print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)
# Initialize the generator with all default parameters
generator = ExLlamaV2DynamicGenerator(
model = model,
cache = cache,
tokenizer = tokenizer,
)
# Benchmark
if benchmark:
batchsizes = [1, 2, 4]
results = []
for bsz in batchsizes:
tps_filtered, tps_unfiltered = test_overhead(model, tokenizer, generator, bsz)
overhead = 1 - tps_filtered / tps_unfiltered
latency = 1 / tps_filtered - 1 / tps_unfiltered
results.append((overhead, latency))
print("---------------------------------------------------------------------------------")
print("Results:")
print()
for bsz, (overhead, latency) in zip(batchsizes, results):
print(f"bsz {bsz}: overhead {overhead*100:8.2f}% latency: {latency*1000:8.4f} ms")
# Run example generation
else:
generate_json(model, tokenizer, generator)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-b", "--benchmark", action = "store_true", help = "Run benchmark")
args = parser.parse_args()
main(args.benchmark)