# exllamav2/examples/inference_json.py
# Example: JSON-schema-constrained text generation with lm-format-enforcer.
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGenerator
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel, conlist
from typing import Literal
import json
# --- Model setup -----------------------------------------------------------

model_dir = "/mnt/str/models/mistral-7b-exl2/4.0bpw"  # path to a local EXL2-quantized model
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)

# lazy = True defers cache allocation until the model has been loaded and split
# across devices below.
cache = ExLlamaV2Cache(model, max_seq_len = 32768, lazy = True)
model.load_autosplit(cache, progress = True)

print("Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)

# Initialize the generator with all default parameters
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)
# JSON schema
# Pydantic model describing one comic-book appearance. Intentionally no class
# docstring: pydantic copies a docstring into the generated JSON schema as a
# "description" field, which would alter the schema the generator enforces.
class SuperheroAppearance(BaseModel):
    title: str
    issue_number: int
    year: int
# Top-level schema the constrained generations must conform to.
class Superhero(BaseModel):
    name: str
    secret_identity: str
    superpowers: conlist(str, max_length = 5)  # list of strings, at most five entries
    first_appearance: SuperheroAppearance
    gender: Literal["male", "female"]
# Build the lm-format-enforcer parser from the pydantic model's JSON schema.
# Pydantic v2 renamed BaseModel.schema() to model_json_schema(); the old name
# still works in v2 but emits a deprecation warning, so prefer the new
# accessor when it exists and fall back to .schema() on pydantic v1.
_get_schema = getattr(Superhero, "model_json_schema", Superhero.schema)
schema_parser = JsonSchemaParser(_get_schema())
# Create prompts with and without filters. Each base prompt is queued twice:
# once unconstrained (filters entry is None) and once constrained to the JSON
# schema. Filter instances are constructed fresh for every prompt so that no
# filter object is shared between generation jobs.
i_prompts = [
    "Here is some information about Superman:\n\n",
    "Here is some information about Batman:\n\n",
    "Here is some information about Aquaman:\n\n",
]

prompts = []
filters = []
for p in i_prompts:
    prompts += [p, p]
    filters += [
        None,
        [
            ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer),
            # Require the output to start by opening a JSON object.
            ExLlamaV2PrefixFilter(model, tokenizer, ["{", " {"]),
        ],
    ]
# Generate all completions (two per base prompt) in one batched call.
print("Generating...")
outputs = generator.generate(
    prompt = prompts,
    filters = filters,
    filter_prefer_eos = True,  # allow the filter to end the sequence once the JSON is complete
    max_new_tokens = 300,
    add_bos = True,
    stop_conditions = [tokenizer.eos_token_id],
    completion_only = True,  # return only the generated text, not the prompt
)
# Print outputs: for each base prompt, show the unfiltered completion next to
# the schema-constrained one. Outputs are interleaved (even index = no filter,
# odd index = filtered) matching the order they were queued above.
separator = "---------------------------------------------------------------------------------"
for i, prompt in enumerate(i_prompts):
    print(separator)
    print(prompt.strip())
    print()
    print("Without filter:")
    print("---------------")
    print(outputs[i * 2])
    print()
    print("With filter:")
    print("------------")
    # Round-trip through json to pretty-print and confirm the output is valid JSON.
    print(json.dumps(json.loads(outputs[i * 2 + 1]), indent = 4).strip())
    print()