# mirror of https://github.com/turboderp-org/exllamav3.git
# synced 2026-03-15 00:07:24 +00:00
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
from exllamav3 import Generator, Job, model_init
from chat_templates import *
from chat_util import *
from chat_io import *
import torch
from chat_console import *
|
|
@torch.inference_mode()
def main(args):
    """
    Interactive console chat loop for an ExLlamaV3 model.

    Selects a prompt format, loads the model/cache/tokenizer via model_init,
    then enters a REPL: read user input (or a slash-command), build and
    tokenize the running context, stream a generated response, and append
    the exchange to the context. Runs until "/x" (or Ctrl+C at the prompt).

    :param args:
        argparse.Namespace produced by the parser in __main__ (model_init
        arguments plus the chat-specific options defined there).
    """

    # Prompt format: with no mode selected (or -modes given), just list the
    # available formats and exit.
    if args.modes or args.mode is None:
        print("Available modes:")
        for k, v in prompt_formats.items():
            print(f" - {k:16} {v.description}")
        return

    # Local, mutable copies of CLI options that slash-commands toggle at runtime
    user_name = args.user_name
    bot_name = args.bot_name
    max_response_tokens = args.max_response_tokens
    multiline = args.multiline
    show_tps = args.show_tps
    think = args.think
    assert not (args.think and args.no_think), "Cannot enable think and no_think modes at the same time"

    if args.basic_console:
        # NOTE(review): both branches bind the same reader function; the basic
        # branch presumably should use a plain (non-prompt-toolkit) reader —
        # confirm against chat_io.
        read_input_fn = read_input_ptk
        streamer_cm = Streamer_basic
    else:
        read_input_fn = read_input_ptk
        streamer_cm = Streamer_rich

    # Prompt format
    prompt_format = prompt_formats[args.mode](user_name, bot_name)
    spc = {}
    if args.think_budget is not None:
        spc["thinking_budget"] = args.think_budget
    prompt_format.set_special(spc)
    system_prompt = prompt_format.default_system_prompt(think) if not args.system_prompt else args.system_prompt
    add_bos = prompt_format.add_bos()

    # Load model
    model, config, cache, tokenizer = model_init.init(args)
    context_length = cache.max_num_tokens

    # Generator
    generator = Generator(
        model = model,
        cache = cache,
        tokenizer = tokenizer,
    )
    # Stop on the format's stop strings/tokens plus any model EOS token IDs
    stop_conditions = [sc for sc in prompt_format.stop_conditions(tokenizer) if sc]
    if config.eos_token_id_list and all(config.eos_token_id_list):
        stop_conditions += config.eos_token_id_list

    # Sampler
    sampler = model_init.get_arg_sampler(args)

    # Single prompt mode
    single_prompt = args.prompt

    # Main loop
    print("\n" + col_sysprompt + system_prompt.strip() + col_default)
    context = []  # list of (user_prompt, response) tuples
    tt = prompt_format.thinktag()  # (open_tag, close_tag) for reasoning output
    banned_strings = [tt[0], tt[1]] if args.no_think else []
    response = ""
    last_tokens = None

    while True:

        # Amnesia mode: forget the whole conversation before every prompt
        if args.amnesia:
            context = []

        # Get user prompt
        if single_prompt is not None:
            # This round, use provided prompt from cmdline
            user_prompt = single_prompt
            prefix = ""
            # Next round, exit
            single_prompt = "/x"
        else:
            try:
                user_prompt = read_input_fn(args, user_name, multiline)
                prefix = ""
            except KeyboardInterrupt:
                # Ctrl+C at the input prompt exits cleanly
                user_prompt = "/x"

        # Intercept commands
        en_dis = { True: "Enabled", False: "Disabled" }
        if user_prompt.startswith("/"):
            c = user_prompt.strip().split(" ")
            match c[0]:

                # List commands
                case "/h":
                    print_info("\n".join([
                        "/ban Edit banned strings list",
                        "/cat Run Catbench 1.0",
                        "/cc Copy last code block to clipboard",
                        "/cc <n> Copy nth-last code block to clipboard",
                        "/clear Clear context",
                        "/e Edit and resume last model response",
                        "/load Load stored session from ~/chat_py_session.json",
                        "/load <filename> Load stored session from file",
                        "/mli Toggle multiline input",
                        "/r Rewind and repeat last prompt",
                        "/save Save current session to ~/chat_py_session.json",
                        "/save <filename> Save current session to file",
                        "/sp Edit system prompt",
                        "/t Tokenize context",
                        "/think Toggle reasoning mode",
                        "/tps Toggle tokens/second output",
                        "/x Exit",
                    ]))
                    continue

                # Exit app
                case "/x":
                    print_info("Exiting")
                    break

                # Copy codeblock to clipboard
                case "/cc":
                    try:
                        b = int(c[1])
                    except:
                        # No/invalid index argument: copy the most recent block
                        b = 1
                    snippet = copy_last_codeblock(response, b)
                    if not snippet:
                        print_error("No code block found in last response")
                    else:
                        num_lines = len(snippet.split("\n"))
                        print_info(f"Copied {num_lines} line{'s' if num_lines > 1 else ''} to the clipboard")
                    continue

                # Toggle multiline mode
                case "/mli":
                    multiline = not multiline
                    print_info(f"{en_dis[multiline]} multiline mode")
                    continue

                # Clear context
                case "/clear":
                    context = []
                    print_info("Cleared context")
                    continue

                # Toggle TPS
                case "/tps":
                    show_tps = not show_tps
                    print_info(f"{en_dis[show_tps]} tokens/second output")
                    continue

                # Toggle reasoning
                case "/think":
                    think = not think
                    print_info(f"{en_dis[think]} reasoning mode")
                    continue

                # Retry last response: re-submit the previous prompt
                # (no `continue` — deliberately falls through to generation)
                case "/r":
                    user_prompt = context[-1][0]
                    context = context[:-1]

                # Edit last response: resume generation from an edited reply
                # (no `continue` — deliberately falls through to generation)
                case "/e":
                    print_info("Press Alt+Enter to submit")
                    user_prompt = context[-1][0]
                    last_reply = context[-1][-1]
                    prefix = read_input_fn(args, bot_name, True, last_reply)
                    context = context[:-1]

                # Edit system prompt
                case "/sp":
                    print_info("Press Alt+Enter to submit")
                    system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
                    continue

                # Edit banned strings
                case "/ban":
                    print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
                    bans = "\n".join(f"\"{b}\"" for b in banned_strings)
                    bans = read_input_fn(args, "Banned strings", True, bans)
                    bans = [b.strip() for b in bans.split("\n")]
                    # Keep only properly double-quoted lines, with quotes stripped
                    bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
                    d = len(bans) - len(banned_strings)
                    banned_strings = bans
                    if d < 0:
                        print_info(f"{-d} string(s) removed")
                    elif d > 0:
                        print_info(f"{d} string(s) added")
                    else:
                        print_info("Strings updated")
                    continue

                # Save conversation
                case "/save":
                    if len(c) == 1:
                        c.append("~/chat_py_session.json")
                    save_session(c[1], system_prompt, banned_strings, context)
                    print_info(f"Saved session to: {c[1]}")
                    continue

                # Load conversation
                case "/load":
                    if len(c) == 1:
                        c.append("~/chat_py_session.json")
                    try:
                        (
                            system_prompt,
                            banned_strings,
                            context
                        ) = load_session(c[1])
                        print_info(f"Loaded session from: {c[1]}")
                    except:
                        print_error(f"Error loading {c[1]}")
                    continue

                # Print token IDs for last response
                case "/t":
                    if last_tokens is None:
                        print_error(f"No previous response to tokenize")
                        continue
                    print_tokens(last_tokens, tokenizer.get_id_to_piece_list())
                    continue

                # Catbench 1.0 (deliberately falls through to generation)
                case "/cat":
                    user_prompt = "Write a python script that draws a cute kitten using matplotlib."

                case _:
                    print_error(f"Unknown command: {c[0]}")
                    continue

        # Add to context
        context.append((user_prompt, None))

        # Tokenize context and trim from head if too long
        def get_input_ids(_prefix):
            frm_context = prompt_format.format(system_prompt, context, think)
            if _prefix:
                # NOTE(review): appends the outer `prefix`, not `_prefix` —
                # equivalent at the current call sites, but confirm intent.
                frm_context += prefix
            elif think and prompt_format.thinktag()[0] is not None:
                # Open a think tag so the model starts out in reasoning mode
                frm_context += prompt_format.thinktag()[0]
            ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
            # Expected total sequence length: prompt + max response (+1)
            exp_len_ = ids_.shape[-1] + max_response_tokens + 1
            return ids_, exp_len_

        ids, exp_len = get_input_ids(prefix)
        if exp_len > context_length:
            # Drop oldest exchanges until there is headroom for two responses
            while exp_len > context_length - 2 * max_response_tokens:
                context = context[1:]
                ids, exp_len = get_input_ids(prefix)

        # Inference
        job = Job(
            input_ids = ids,
            max_new_tokens = max_response_tokens,
            stop_conditions = stop_conditions,
            sampler = sampler,
            banned_strings = banned_strings
        )
        generator.enqueue(job)

        # Stream response
        ctx_exceeded = False
        with (
            KeyReader() as keyreader,
            streamer_cm(args, bot_name, tt[0], tt[1], args.updates_per_second, think) as s
        ):
            if prefix:
                # Seed the display with the edited partial reply (/e)
                s.stream(prefix)
            while generator.num_remaining_jobs():
                for r in generator.iterate():
                    chunk = r.get("text", "")
                    s.stream(chunk)
                    token_ids = r.get("token_ids")
                    if token_ids is not None:
                        ids = torch.cat((ids, token_ids), dim = -1)
                    if r["eos"] and r["eos_reason"] == "max_new_tokens":
                        ctx_exceeded = True

                # Check for keypress while streaming
                keypress = keyreader.getkey()
                match keypress:
                    case "\x1b":
                        # Escape pressed: cancel the job and stop streaming
                        print(f"\n\n{col_error} !! Aborted.{col_default}")
                        generator.cancel(job)
                        r = None
                        break

        last_tokens = ids[0].tolist()

        if ctx_exceeded:
            print(f"\n{col_error} !! Response exceeded {max_response_tokens} tokens and was cut short.{col_default}")

        # Tokens/second report from the last result dict (skipped after abort)
        if show_tps and r:
            prompt_tokens = r["prompt_tokens"]
            cached_tokens = r["cached_tokens"]
            new_ctx_tokens = prompt_tokens - cached_tokens
            prompt_tps = new_ctx_tokens / r["time_prefill"]
            new_tokens = r["new_tokens"]
            tps = new_tokens / r["time_generate"]
            print(
                "\n"
                f"Context: {col_info}{new_ctx_tokens:,}{col_default} new tokens at {col_info}{prompt_tps:.3f}{col_default} t/s - "
                f"{col_info}{cached_tokens:,}{col_default} tokens cached - "
                f"Generate: {col_info}{new_tokens:,}{col_default} tokens at {col_info}{tps:.3f}{col_default} t/s"
            )

        if args.debug:
            from pprint import pprint
            print()
            pprint(r, compact = True, indent = 4)
            print()

        # Add response to context
        response = s.all_text.strip()

        # Optionally save output
        if args.save:
            sr = response
            if sr and args.save_svg:
                sr = extract_svg(sr)
                if sr: print_info(f"Found SVG: {len(sr)} characters")
                else: print_error(f"No SVG block found")
            if sr:
                print_info(f"Writing response to: {args.save}")
                with open(args.save, "w") as f:
                    f.write(sr)
            else:
                print_info(f"Nothing to write")

        context[-1] = (user_prompt, response)
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the argument parser: model/cache/sampling options come from
    # model_init, chat-specific options are added below.
    parser = argparse.ArgumentParser()
    model_init.add_args(parser, cache = True, add_sampling_args = True)
    parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", default = None)
    parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
    parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
    parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
    parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
    parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
    parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
    # Fixed: help string previously had an unbalanced parenthesis
    parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input)")
    parser.add_argument("-think", "--think", action = "store_true", help = "Use (very simplistic) reasoning template and formatting")
    parser.add_argument("-no_think", "--no_think", action = "store_true", help = "Suppress think tags (won't necessarily stop reasoning model from reasoning anyway)")
    parser.add_argument("-think_budget", "--think_budget", type = int, help = "Thinking budget for supported models", default = None)
    parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context with every new prompt")
    parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
    parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
    parser.add_argument("-save", "--save", type = str, help = "Save output to file (use with --prompt)")
    parser.add_argument("-save_svg", "--save_svg", action = "store_true", help = "Extract SVG from response (use with --save)")
    parser.add_argument("-dbg", "--debug", action = "store_true", help = "Print extra debug stuff")
    # Note: dest becomes args.updates_per_second (hyphens -> underscores)
    parser.add_argument("-ups", "--updates-per-second", type = int, help = "Max number of console updates per second (markdown console), default: 30", default = 30)
    _args = parser.parse_args()
    main(_args)
|