Files
exllamav3/examples/chat.py
2026-03-03 23:15:07 +01:00

366 lines
15 KiB
Python

import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
from exllamav3 import Generator, Job, model_init
from chat_templates import *
from chat_util import *
from chat_io import *
import torch
from chat_console import *
@torch.inference_mode()
def main(args):
# Prompt format
if args.modes or args.mode is None:
print("Available modes:")
for k, v in prompt_formats.items():
print(f" - {k:16} {v.description}")
return
user_name = args.user_name
bot_name = args.bot_name
max_response_tokens = args.max_response_tokens
multiline = args.multiline
show_tps = args.show_tps
think = args.think
assert not (args.think and args.no_think), "Cannot enable think and no_think modes at the same time"
if args.basic_console:
read_input_fn = read_input_ptk
streamer_cm = Streamer_basic
else:
read_input_fn = read_input_ptk
streamer_cm = Streamer_rich
# Prompt format
prompt_format = prompt_formats[args.mode](user_name, bot_name)
spc = {}
if args.think_budget is not None:
spc["thinking_budget"] = args.think_budget
prompt_format.set_special(spc)
system_prompt = prompt_format.default_system_prompt(think) if not args.system_prompt else args.system_prompt
add_bos = prompt_format.add_bos()
# Load model
model, config, cache, tokenizer = model_init.init(args)
context_length = cache.max_num_tokens
# Generator
generator = Generator(
model = model,
cache = cache,
tokenizer = tokenizer,
)
stop_conditions = [sc for sc in prompt_format.stop_conditions(tokenizer) if sc]
if config.eos_token_id_list and all(config.eos_token_id_list):
stop_conditions += config.eos_token_id_list
# Sampler
sampler = model_init.get_arg_sampler(args)
# Single prompt mode
single_prompt = args.prompt
# Main loop
print("\n" + col_sysprompt + system_prompt.strip() + col_default)
context = []
tt = prompt_format.thinktag()
banned_strings = [tt[0], tt[1]] if args.no_think else []
response = ""
last_tokens = None
while True:
# Amnesia mode
if args.amnesia:
context = []
# Get user prompt
if single_prompt is not None:
# This round, use provided prompt from cmdline
user_prompt = single_prompt
prefix = ""
# Next round, exit
single_prompt = "/x"
else:
try:
user_prompt = read_input_fn(args, user_name, multiline)
prefix = ""
except KeyboardInterrupt:
user_prompt = "/x"
# Intercept commands
en_dis = { True: "Enabled", False: "Disabled" }
if user_prompt.startswith("/"):
c = user_prompt.strip().split(" ")
match c[0]:
# List commands
case "/h":
print_info("\n".join([
"/ban Edit banned strings list",
"/cat Run Catbench 1.0",
"/cc Copy last code block to clipboard",
"/cc <n> Copy nth-last code block to clipboard",
"/clear Clear context",
"/e Edit and resume last model response",
"/load Load stored session from ~/chat_py_session.json",
"/load <filename> Load stored session from file",
"/mli Toggle multiline input",
"/r Rewind and repeat last prompt",
"/save Save current session to ~/chat_py_session.json",
"/save <filename> Save current session to file",
"/sp Edit system prompt",
"/t Tokenize context",
"/think Toggle reasoning mode",
"/tps Toggle tokens/second output",
"/x Exit",
]))
continue
# Exit app
case "/x":
print_info("Exiting")
break
# Copy codeblock to clipboard
case "/cc":
try:
b = int(c[1])
except:
b = 1
snippet = copy_last_codeblock(response, b)
if not snippet:
print_error("No code block found in last response")
else:
num_lines = len(snippet.split("\n"))
print_info(f"Copied {num_lines} line{'s' if num_lines > 1 else ''} to the clipboard")
continue
# Toggle multiline mode
case "/mli":
multiline = not multiline
print_info(f"{en_dis[multiline]} multiline mode")
continue
# Clear context
case "/clear":
context = []
print_info("Cleared context")
continue
# Toggle TPS
case "/tps":
show_tps = not show_tps
print_info(f"{en_dis[show_tps]} tokens/second output")
continue
# Toggle reasoning
case "/think":
think = not think
print_info(f"{en_dis[think]} reasoning mode")
continue
# Retry last response
case "/r":
user_prompt = context[-1][0]
context = context[:-1]
# Edit last response
case "/e":
print_info("Press Alt+Enter to submit")
user_prompt = context[-1][0]
last_reply = context[-1][-1]
prefix = read_input_fn(args, bot_name, True, last_reply)
context = context[:-1]
# Edit system prompt
case "/sp":
print_info("Press Alt+Enter to submit")
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
continue
# Edit banned strings
case "/ban":
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
bans = "\n".join(f"\"{b}\"" for b in banned_strings)
bans = read_input_fn(args, "Banned strings", True, bans)
bans = [b.strip() for b in bans.split("\n")]
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
d = len(bans) - len(banned_strings)
banned_strings = bans
if d < 0:
print_info(f"{-d} string(s) removed")
elif d > 0:
print_info(f"{d} string(s) added")
else:
print_info("Strings updated")
continue
# Save conversation
case "/save":
if len(c) == 1:
c.append("~/chat_py_session.json")
save_session(c[1], system_prompt, banned_strings, context)
print_info(f"Saved session to: {c[1]}")
continue
# Load conversation
case "/load":
if len(c) == 1:
c.append("~/chat_py_session.json")
try:
(
system_prompt,
banned_strings,
context
) = load_session(c[1])
print_info(f"Loaded session from: {c[1]}")
except:
print_error(f"Error loading {c[1]}")
continue
# Print token IDs for last response
case "/t":
if last_tokens is None:
print_error(f"No previous response to tokenize")
continue
print_tokens(last_tokens, tokenizer.get_id_to_piece_list())
continue
# Catbench 1.0
case "/cat":
user_prompt = "Write a python script that draws a cute kitten using matplotlib."
case _:
print_error(f"Unknown command: {c[0]}")
continue
# Add to context
context.append((user_prompt, None))
# Tokenize context and trim from head if too long
def get_input_ids(_prefix):
frm_context = prompt_format.format(system_prompt, context, think)
if _prefix:
frm_context += prefix
elif think and prompt_format.thinktag()[0] is not None:
frm_context += prompt_format.thinktag()[0]
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
return ids_, exp_len_
ids, exp_len = get_input_ids(prefix)
if exp_len > context_length:
while exp_len > context_length - 2 * max_response_tokens:
context = context[1:]
ids, exp_len = get_input_ids(prefix)
# Inference
job = Job(
input_ids = ids,
max_new_tokens = max_response_tokens,
stop_conditions = stop_conditions,
sampler = sampler,
banned_strings = banned_strings
)
generator.enqueue(job)
# Stream response
ctx_exceeded = False
with (
KeyReader() as keyreader,
streamer_cm(args, bot_name, tt[0], tt[1], args.updates_per_second, think) as s
):
if prefix:
s.stream(prefix)
while generator.num_remaining_jobs():
for r in generator.iterate():
chunk = r.get("text", "")
s.stream(chunk)
token_ids = r.get("token_ids")
if token_ids is not None:
ids = torch.cat((ids, token_ids), dim = -1)
if r["eos"] and r["eos_reason"] == "max_new_tokens":
ctx_exceeded = True
# Check for keypress while streaming
keypress = keyreader.getkey()
match keypress:
case "\x1b":
print(f"\n\n{col_error} !! Aborted.{col_default}")
generator.cancel(job)
r = None
break
last_tokens = ids[0].tolist()
if ctx_exceeded:
print(f"\n{col_error} !! Response exceeded {max_response_tokens} tokens and was cut short.{col_default}")
if show_tps and r:
prompt_tokens = r["prompt_tokens"]
cached_tokens = r["cached_tokens"]
new_ctx_tokens = prompt_tokens - cached_tokens
prompt_tps = new_ctx_tokens / r["time_prefill"]
new_tokens = r["new_tokens"]
tps = new_tokens / r["time_generate"]
print(
"\n"
f"Context: {col_info}{new_ctx_tokens:,}{col_default} new tokens at {col_info}{prompt_tps:.3f}{col_default} t/s - "
f"{col_info}{cached_tokens:,}{col_default} tokens cached - "
f"Generate: {col_info}{new_tokens:,}{col_default} tokens at {col_info}{tps:.3f}{col_default} t/s"
)
if args.debug:
from pprint import pprint
print()
pprint(r, compact = True, indent = 4)
print()
# Add response to context
response = s.all_text.strip()
# Optionally save output
if args.save:
sr = response
if sr and args.save_svg:
sr = extract_svg(sr)
if sr: print_info(f"Found SVG: {len(sr)} characters")
else: print_error(f"No SVG block found")
if sr:
print_info(f"Writing response to: {args.save}")
with open(args.save, "w") as f:
f.write(sr)
else:
print_info(f"Nothing to write")
context[-1] = (user_prompt, response)
if __name__ == "__main__":
    # Command-line entry point: model/cache/sampling options come from model_init, chat options are defined here
    parser = argparse.ArgumentParser()
    model_init.add_args(parser, cache = True, add_sampling_args = True)
    parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", default = None)
    parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
    parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
    parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
    parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
    parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
    parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
    parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input)")
    # think and no_think contradict each other; let argparse reject the combination with a usage message
    think_group = parser.add_mutually_exclusive_group()
    think_group.add_argument("-think", "--think", action = "store_true", help = "Use (very simplistic) reasoning template and formatting")
    think_group.add_argument("-no_think", "--no_think", action = "store_true", help = "Suppress think tags (won't necessarily stop reasoning model from reasoning anyway)")
    parser.add_argument("-think_budget", "--think_budget", type = int, help = "Thinking budget for supported models", default = None)
    parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context with every new prompt")
    parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
    parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
    parser.add_argument("-save", "--save", type = str, help = "Save output to file (use with --prompt)")
    parser.add_argument("-save_svg", "--save_svg", action = "store_true", help = "Extract SVG from response (use with --save)")
    parser.add_argument("-dbg", "--debug", action = "store_true", help = "Print extra debug stuff")
    parser.add_argument("-ups", "--updates-per-second", type = int, help = "Max number of console updates per second (markdown console), default: 30", default = 30)
    _args = parser.parse_args()
    main(_args)