mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Add lm-format-enforcer JSON example
This commit is contained in:
137
examples/lm_format_enforcer.py
Normal file
137
examples/lm_format_enforcer.py
Normal file
@@ -0,0 +1,137 @@
|
||||
|
||||
import sys, os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from pydantic import BaseModel, conlist
|
||||
from typing import Literal
|
||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||
from lmformatenforcer import JsonSchemaParser
|
||||
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Tokenizer,
|
||||
)
|
||||
|
||||
from exllamav2.generator import (
|
||||
ExLlamaV2StreamingGenerator,
|
||||
ExLlamaV2Sampler,
|
||||
)
|
||||
|
||||
from exllamav2.generator.filters import (
|
||||
ExLlamaV2PrefixFilter
|
||||
)
|
||||
|
||||
import time, json
|
||||
|
||||
# Load the model

model_directory = "/mnt/str/models/llama2-13b-exl2/4.0bpw/"

# Point a fresh config at the model directory and let it prepare itself
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

model = ExLlamaV2(config)
print(f"Loading model: {model_directory}")

# The cache is created with lazy = True and handed to load_autosplit() together with the model
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)

# Create the streaming generator used by completion() below

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
generator.warmup()  # for more accurate timing
|
||||
|
||||
# Generate with or without filter
|
||||
|
||||
def completion(prompt, filters = None, max_new_tokens = 200, eos_bias = False):
    """
    Stream a completion for `prompt`, echo it to stdout and return the generated text.

    Relies on the module-level `tokenizer` and `generator`. After the stream ends,
    timing statistics for prompt processing and generation are printed.

    :param prompt: Text that is encoded and sent to the generator.
    :param filters: Optional list of sampler filters. None means "no filters".
    :param max_new_tokens: Maximum number of streamed chunks (counted as tokens) before stopping.
    :param eos_bias: Accepted for backward compatibility with existing callers; not read here.
    :return: Concatenation of every chunk returned by the generator.
    """

    # The default is None, but the checks below need a sized container
    if filters is None:
        filters = []
    filtered = len(filters) > 0

    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = 0.75
    settings.top_k = 0
    settings.top_p = 0.5
    settings.token_repetition_penalty = 1.0
    settings.filters = filters

    # If using a filter, sample the EOS token as soon as filter allows it
    if filtered:
        settings.filter_prefer_eos = True

    # Send prompt to generator to begin stream

    input_ids = tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]

    time_begin_prompt = time.time()

    generator.set_stop_conditions([tokenizer.eos_token_id])
    generator.begin_stream(input_ids, settings)

    # Streaming loop

    time_begin_stream = time.time()
    generated_tokens = 0

    print("--------------------------------------------------")
    print(prompt)
    print(" ------>" + (" (filtered)" if filtered else ""))

    chunks = []
    while True:
        chunk, eos, _ = generator.stream()
        chunks.append(chunk)
        generated_tokens += 1
        print(chunk, end = "")
        sys.stdout.flush()
        # >= rather than == so a non-positive limit still terminates the loop
        if eos or generated_tokens >= max_new_tokens:
            break

    time_end = time.time()

    time_prompt = time_begin_stream - time_begin_prompt
    time_tokens = time_end - time_begin_stream

    print("\n")
    print(f"Prompt processed in {time_prompt:.2f} seconds, {prompt_tokens} tokens, {_tokens_per_second(prompt_tokens, time_prompt):.2f} tokens/second")
    print(f"Response generated in {time_tokens:.2f} seconds, {generated_tokens} tokens, {_tokens_per_second(generated_tokens, time_tokens):.2f} tokens/second")
    print("\n\n")

    return "".join(chunks)


def _tokens_per_second(count, seconds):
    """Return count / seconds, or infinity when the measured interval is not positive."""
    return count / seconds if seconds > 0 else float("inf")
|
||||
|
||||
# Configure filter
|
||||
|
||||
# Nested record used as the type of Superhero.first_appearance.
# Documented with comments rather than a docstring: pydantic copies a model's
# docstring into the generated JSON schema, which would change the schema.
class SuperheroAppearance(BaseModel):
    title: str
    issue_number: int
    year: int
|
||||
|
||||
# Top-level model whose JSON schema constrains the filtered generation.
# Documented with comments rather than a docstring: pydantic copies a model's
# docstring into the generated JSON schema, which would change the schema.
class Superhero(BaseModel):
    name: str
    secret_identity: str
    superpowers: conlist(str, max_length = 5)  # list of strings, at most five entries
    first_appearance: SuperheroAppearance
    gender: Literal["male", "female"]  # restricted to these two literal values
|
||||
|
||||
# Build the JSON-schema enforcing filter from the Superhero model.
# model_json_schema() is the pydantic v2 name for the deprecated .schema(); the
# model already requires v2 because it uses conlist(..., max_length = ...).
schema_parser = JsonSchemaParser(Superhero.model_json_schema())
lmfe_filter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")  # Make sure we start JSONing right away
|
||||
|
||||
# Run each prompt twice: once unfiltered for comparison, then constrained by both
# filters, and parse the constrained output as JSON

for prompt in (
    "Here is some information about Superman:\n",
    "Here is some information about Batman:\n",
):
    completion(prompt, [])
    result = completion(prompt, [lmfe_filter, prefix_filter], eos_bias = True)

    j = json.loads(result)
    print("Parsed JSON:" , j)
|
||||
Reference in New Issue
Block a user