Files
exllamav3/examples/chat.py
2025-04-06 14:42:49 +02:00

113 lines
4.2 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
from exllamav3 import Generator, Job, model_init
from chat_templates import *
import torch
# ANSI color codes
col_default = "\u001b[0m"
col_user = "\u001b[33;1m" # Yellow
col_bot = "\u001b[34;1m" # Blue
col_error = "\u001b[31;1m" # Magenta
col_sysprompt = "\u001b[37;1m" # Grey
@torch.inference_mode()
def main(args):
# Prompt format
if args.modes:
print("Available modes:")
for k, v in prompt_formats.items():
print(f" - {k:16} {v.description}")
return
user_name = args.user_name
bot_name = args.bot_name
prompt_format = prompt_formats[args.mode](user_name, bot_name)
system_prompt = prompt_format.default_system_prompt() if not args.system_prompt else args.system_prompt
add_bos = prompt_format.add_bos()
max_response_tokens = args.max_response_tokens
# Load model
model, config, cache, tokenizer = model_init.init(args)
context_length = cache.max_num_tokens
# Generator
generator = Generator(
model = model,
cache = cache,
tokenizer = tokenizer,
)
stop_conditions = prompt_format.stop_conditions(tokenizer)
# Main loop
print("\n" + col_sysprompt + system_prompt.strip() + col_default)
context = []
while True:
# Get user prompt and add to context
print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
if args.mli:
user_prompt = sys.stdin.read().rstrip()
else:
user_prompt = input().strip()
context.append((user_prompt, None))
# Tokenize context and trim from head if too long
def get_input_ids():
frm_context = prompt_format.format(system_prompt, context)
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
return ids_, exp_len_
ids, exp_len = get_input_ids()
if exp_len > context_length:
while exp_len > context_length - 2 * max_response_tokens:
context = context[1:]
ids, exp_len = get_input_ids()
# Inference
print("\n" + col_bot + bot_name + ": " + col_default, end = "")
job = Job(
input_ids = ids,
max_new_tokens = max_response_tokens,
stop_conditions = stop_conditions
)
generator.enqueue(job)
# Stream response
response = ""
while generator.num_remaining_jobs():
for r in generator.iterate():
chunk = r.get("text", "")
if not response and chunk.startswith(" "):
print(chunk[1:], end = "", flush = True)
else:
print(chunk, end = "", flush = True)
response += chunk
if r["eos"] and r["eos_reason"] == "max_new_tokens":
print("\n" + col_error + f" !! Response exceeded {max_response_tokens} tokens and was cut short." + col_default)
if not response.endswith("\n"):
print()
# Add response to context
context[-1] = (user_prompt, response.strip())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
model_init.add_args(parser, cache = True)
parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", required = True)
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
parser.add_argument("-mli", "--mli", action = "store_true", help = "Enable multi line input")
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
# TODO: Sampling options
_args = parser.parse_args()
main(_args)