import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
from exllamav3 import Generator, Job, model_init
from chat_templates import *
from chat_util import *
from chat_io import *
import torch
from chat_console import *
from safetensors.torch import save_file

@torch.inference_mode()
def main(args):

    # Prompt format
    if args.modes or args.mode is None:
        print("Available modes:")
        for k, v in prompt_formats.items():
            print(f" - {k:16} {v.description}")
        return

    user_name = args.user_name
    bot_name = args.bot_name
    max_response_tokens = args.max_response_tokens
    multiline = args.multiline
    show_tps = args.show_tps
    think = args.think
    last_input_ids = None
    assert not (args.think and args.no_think), "Cannot enable think and no_think modes at the same time"

    save_probs = args.probs

    if args.basic_console:
        read_input_fn = read_input_ptk
        streamer_cm = Streamer_basic
    else:
        read_input_fn = read_input_ptk
        streamer_cm = Streamer_rich

    # Prompt format
    prompt_format = prompt_formats[args.mode](user_name, bot_name)
    spc = {}
    if args.think_budget is not None:
        spc["thinking_budget"] = args.think_budget
    prompt_format.set_special(spc)
    system_prompt = prompt_format.default_system_prompt(think) if not args.system_prompt else args.system_prompt
    add_bos = prompt_format.add_bos()

    # Load model
    model, config, cache, tokenizer = model_init.init(args)
    context_length = cache.max_num_tokens

    # Generator
    generator = Generator(
        model = model,
        cache = cache,
        tokenizer = tokenizer,
    )
    stop_conditions = [sc for sc in prompt_format.stop_conditions(tokenizer) if sc]
    if config.eos_token_id_list and all(config.eos_token_id_list):
        stop_conditions += config.eos_token_id_list

    # Sampler
    sampler = model_init.get_arg_sampler(args)

    # Single prompt mode
    single_prompt = args.prompt

    # Main loop
    print("\n" + col_sysprompt + system_prompt.strip() + col_default)
    context = []
    tt = prompt_format.thinktag()
    banned_strings = [tt[0], tt[1]] if args.no_think else []
    response = ""
    last_tokens = None

    while True:

        # Amnesia mode
        if args.amnesia:
            context = []

        # Get user prompt
        enable_healing = False
        if single_prompt is not None:
            # This round, use provided prompt from cmdline
            user_prompt = single_prompt
            prefix = ""
            # Next round, exit
            single_prompt = "/x"
        else:
            try:
                user_prompt = read_input_fn(args, user_name, multiline)
                prefix = ""
            except KeyboardInterrupt:
                user_prompt = "/x"

        # Intercept commands
        en_dis = { True: "Enabled", False: "Disabled" }
        if user_prompt.startswith("/"):
            c = user_prompt.strip().split(" ")
            match c[0]:

                # List commands
                case "/h":
                    print_info("\n".join([
                        "/b                Random benchmark question",
                        "/ban              Edit banned strings list",
                        "/cat              Run Catbench 1.0",
                        "/cc               Copy last code block to clipboard",
                        "/cc <n>           Copy nth-last code block to clipboard",
                        "/clear            Clear context",
                        "/e                Edit and resume last model response",
                        "/load             Load stored session from ~/chat_py_session.json",
                        "/load <filename>  Load stored session from file",
                        "/mli              Toggle multiline input",
                        "/probs <n>        Set number of probs recorded (0 to disable), adds overhead",
                        "/r                Rewind and repeat last prompt",
                        "/save             Save current session to ~/chat_py_session.json",
                        "/save <filename>  Save current session to file",
                        "/save_ids         Save last input IDs to last_ids.safetensors",
                        "/sp               Edit system prompt",
                        "/t                Tokenize context",
                        "/think            Toggle reasoning mode",
                        "/tps              Toggle tokens/second output",
                        "/x                Exit",
                    ]))
                    continue

                # Exit app
                case "/x":
                    print_info("Exiting")
                    break

                # Random benchmark question
                case "/b":
                    source = c[1] if len(c) > 1 else None
                    sources = get_sample_sources()
                    if source and source.isnumeric():
                        i = int(source)
                        source = list(sources)[i - 1] if 1 <= i <= len(sources) else None
                    if source not in sources:
                        print_info(
                            "Available sample sources:\n\n" +
                            "\n".join([f"{i + 1}. {t}" for i, t in enumerate(get_sample_sources())])
                        )
                        continue
                    question = sample_question(source)
                    print_info(f"Question from {source}; multi-line mode, press Alt-Enter to submit or Ctrl-C to abort")
                    try:
                        user_prompt = read_input_fn(args, "Prompt", True, question)
                    except KeyboardInterrupt:
                        print_info("Aborted")
                        continue

                # Copy codeblock to clipboard
                case "/cc":
                    try:
                        b = int(c[1])
                    except:
                        b = 1
                    snippet = copy_last_codeblock(response, b)
                    if not snippet:
                        print_error("No code block found in last response")
                    else:
                        num_lines = len(snippet.split("\n"))
                        print_info(f"Copied {num_lines} line{'s' if num_lines > 1 else ''} to the clipboard")
                    continue

                # Toggle multiline mode
                case "/mli":
                    multiline = not multiline
                    print_info(f"{en_dis[multiline]} multiline mode")
                    continue

                # Clear context
                case "/clear":
                    context = []
                    print_info("Cleared context")
                    continue

                # Toggle TPS
                case "/tps":
                    show_tps = not show_tps
                    print_info(f"{en_dis[show_tps]} tokens/second output")
                    continue

                # Toggle reasoning
                case "/think":
                    think = not think
                    print_info(f"{en_dis[think]} reasoning mode")
                    continue

                # Retry last response
                case "/r":
                    if len(context) == 0:
                        print_error(f"No last prompt to replay")
                        continue
                    else:
                        user_prompt = context[-1][0]
                        context = context[:-1]

                # Edit last response
                case "/e":
                    print_info("Press Alt-Enter to submit")
                    user_prompt = context[-1][0]
                    last_reply = context[-1][-1]
                    try:
                        prefix = read_input_fn(args, bot_name, True, last_reply)
                        context = context[:-1]
                        enable_healing = True
                    except KeyboardInterrupt:
                        print_info("Exiting")
                        break

                # Edit system prompt
                case "/sp":
                    print_info("Press Alt-Enter to submit")
                    try:
                        system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
                        continue
                    except KeyboardInterrupt:
                        print_info("Exiting")
                        break

                # Edit banned strings
                case "/ban":
                    print_info("Write each string on a new line and enclose in \"double quotes\", press Alt-Enter to submit")
                    bans = "\n".join(f"\"{b}\"" for b in banned_strings)
                    try:
                        bans = read_input_fn(args, "Banned strings", True, bans)
                        bans = [b.strip() for b in bans.split("\n")]
                        bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
                        d = len(bans) - len(banned_strings)
                        banned_strings = bans
                        if d < 0:
                            print_info(f"{-d} string(s) removed")
                        elif d > 0:
                            print_info(f"{d} string(s) added")
                        else:
                            print_info("Strings updated")
                        continue
                    except KeyboardInterrupt:
                        print_info("Exiting")
                        break

                # Save conversation
                case "/save":
                    if len(c) == 1:
                        c.append("~/chat_py_session.json")
                    save_session(c[1], system_prompt, banned_strings, context)
                    print_info(f"Saved session to: {c[1]}")
                    continue

                # Save IDs
                case "/save_ids":
                    if last_input_ids is None:
                        print_error(f"No IDs to save")
                    else:
                        d = {"ids": last_input_ids}
                        save_file(d, "last_ids.safetensors")
                        print_info(f"Saved IDs to last_ids.safetensors")
                    continue

                # Load conversation
                case "/load":
                    if len(c) == 1:
                        c.append("~/chat_py_session.json")
                    try:
                        (
                            system_prompt,
                            banned_strings,
                            context
                        ) = load_session(c[1])
                        print_info(f"Loaded session from: {c[1]}")
                    except:
                        print_error(f"Error loading {c[1]}")
                    continue

                # Print token IDs for last response
                case "/t":
                    if last_tokens is None:
                        print_error(f"No previous response to tokenize")
                        continue
                    print_tokens(last_tokens, tokenizer.get_id_to_piece_list())
                    continue

                # Catbench 1.0
                case "/cat":
                    user_prompt = "Write a python script that draws a cute kitten using matplotlib."
                # Enable/disable saving probs
                case "/probs":
                    n = c[1] if len(c) > 1 else "0"
                    if n.isnumeric():
                        save_probs = int(n)
                        if save_probs:
                            print_info(f"Saving top-{save_probs} probs per token.")
                        else:
                            print_info(f"Disabled probs")
                    else:
                        print_error("Invalid argument")
                    continue

                case _:
                    print_error(f"Unknown command: {c[0]}")
                    continue

        # Add to context
        context.append((user_prompt, None))

        # Tokenize context and trim from head if too long
        def get_input_ids(_prefix):
            frm_context = prompt_format.format(system_prompt, context, think)
            if _prefix:
                frm_context += _prefix
            elif think and prompt_format.thinktag()[0] is not None:
                frm_context += prompt_format.thinktag()[0]
            ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
            exp_len_ = ids_.shape[-1] + max_response_tokens + 1
            return ids_, exp_len_

        ids, exp_len = get_input_ids(prefix)
        if exp_len > context_length:
            while exp_len > context_length - 2 * max_response_tokens:
                context = context[1:]
                ids, exp_len = get_input_ids(prefix)

        last_input_ids = ids.clone()

        # Inference
        job = Job(
            input_ids = ids,
            max_new_tokens = max_response_tokens,
            stop_conditions = stop_conditions,
            sampler = sampler,
            banned_strings = banned_strings,
            token_healing = enable_healing,
            return_logits = save_probs > 0,
        )
        generator.enqueue(job)

        saved_topk = []
        saved_probs = []
        saved_samples = []

        # Stream response
        ctx_exceeded = False
        with (
            KeyReader() as keyreader,
            streamer_cm(args, bot_name, tt[0], tt[1], args.updates_per_second, think) as s
        ):
            if prefix:
                s.stream(prefix)

            while generator.num_remaining_jobs():
                for r in generator.iterate():
                    chunk = r.get("text", "")
                    s.stream(chunk)
                    token_ids = r.get("token_ids")
                    if save_probs and "logits" in r:
                        logits = r["logits"]
                        probs = logits.softmax(dim = -1)
                        topk = probs.topk(k = save_probs)
                        saved_probs += list(x.flatten().tolist() for x in topk[0].split(1, 1))
                        saved_topk += list(x.flatten().tolist() for x in topk[1].split(1, 1))
                    if save_probs and token_ids is not None:
                        saved_samples += list(x.flatten().tolist() for x in token_ids.split(1, 1))
                    if token_ids is not None:
                        ids = torch.cat((ids, token_ids), dim = -1)
                    if r["eos"] and r["eos_reason"] == "max_new_tokens":
                        ctx_exceeded = True

                # Check for keypress while streaming
                keypress = keyreader.getkey()
                match keypress:
                    case "\x1b":
                        print(f"\n\n{col_error} !! Aborted.{col_default}")
                        generator.cancel(job)
                        r = None
                        break

        last_tokens = ids[0].tolist()
        if ctx_exceeded:
            print(f"\n{col_error} !! Response exceeded {max_response_tokens} tokens and was cut short.{col_default}")

        if show_tps and r:
            prompt_tokens = r["prompt_tokens"]
            cached_tokens = r["cached_tokens"]
            new_ctx_tokens = prompt_tokens - cached_tokens
            prompt_tps = new_ctx_tokens / r["time_prefill"]
            new_tokens = r["new_tokens"]
            tps = new_tokens / r["time_generate"]
            print(
                "\n"
                f"Context: {col_info}{new_ctx_tokens:,}{col_default} new tokens at {col_info}{prompt_tps:.3f}{col_default} t/s - "
                f"{col_info}{cached_tokens:,}{col_default} tokens cached - "
                f"Generate: {col_info}{new_tokens:,}{col_default} tokens at {col_info}{tps:.3f}{col_default} t/s"
            )

        if save_probs:
            print_probs(saved_topk, saved_probs, saved_samples, tokenizer.get_id_to_piece_list())

        if args.debug:
            from pprint import pprint
            print()
            pprint(r, compact = True, indent = 4)
            print()

        # Add response to context
        response = s.all_text.strip()

        # Optionally save output
        if args.save:
            sr = response
            if sr and args.save_svg:
                sr = extract_svg(sr)
                if sr:
                    print_info(f"Found SVG: {len(sr)} characters")
                else:
                    print_error(f"No SVG block found")
            if sr:
                print_info(f"Writing response to: {args.save}")
                with open(args.save, "w") as f:
                    f.write(sr)
            else:
                print_info(f"Nothing to write")

        context[-1] = (user_prompt, response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    model_init.add_args(parser, cache = True, add_sampling_args = True)
    parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", default = None)
    parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
    parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
    parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
    parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi-line input (use Alt-Enter to submit input)")
    parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
    parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
    parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown or fancy prompt input)")
    parser.add_argument("-think", "--think", action = "store_true", help = "Use (very simplistic) reasoning template and formatting")
    parser.add_argument("-no_think", "--no_think", action = "store_true", help = "Suppress think tags (won't necessarily stop a reasoning model from reasoning)")
    parser.add_argument("-think_budget", "--think_budget", type = int, help = "Thinking budget for supported models", default = None)
    parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context with every new prompt")
    parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
    parser.add_argument("-probs", "--probs", type = int, help = "Sample top-K raw probabilities per token, adds overhead", default = 0)
    parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
    parser.add_argument("-save", "--save", type = str, help = "Save output to file (use with --prompt)")
    parser.add_argument("-save_svg", "--save_svg", action = "store_true", help = "Extract SVG from response (use with --save)")
    parser.add_argument("-dbg", "--debug", action = "store_true", help = "Print extra debug stuff")
help = "Max number of console updates per second (markdown console), default: 30", default = 30) _args = parser.parse_args() main(_args)