diff --git a/examples/chat.py b/examples/chat.py index 7d8506c..8a878f4 100644 --- a/examples/chat.py +++ b/examples/chat.py @@ -24,6 +24,8 @@ def main(args): max_response_tokens = args.max_response_tokens multiline = args.multiline show_tps = args.show_tps + think = args.think + assert not (args.think and args.no_think), "Cannot enable think and no_think modes at the same time" if args.basic_console: read_input_fn = read_input_ptk @@ -38,7 +40,7 @@ def main(args): if args.think_budget is not None: spc["thinking_budget"] = args.think_budget prompt_format.set_special(spc) - system_prompt = prompt_format.default_system_prompt(args.think) if not args.system_prompt else args.system_prompt + system_prompt = prompt_format.default_system_prompt(think) if not args.system_prompt else args.system_prompt add_bos = prompt_format.add_bos() # Load model @@ -90,10 +92,34 @@ def main(args): user_prompt = "/x" # Intercept commands + en_dis = { True: "Enabled", False: "Disabled" } if user_prompt.startswith("/"): c = user_prompt.strip().split(" ") match c[0]: + # List commands + case "/h": + print_info("\n".join([ + "/ban Edit banned strings list", + "/cat Run Catbench 1.0", + "/cc Copy last code block to clipboard", + "/cc Copy nth-last code block to clipboard", + "/clear Clear context", + "/e Edit and resume last model response", + "/load Load stored session from ~/chat_py_session.json", + "/load Load stored session from file", + "/mli Toggle multiline input", + "/r Rewind and repeat last prompt", + "/save Save current session to ~/chat_py_session.json", + "/save Save current session to file", + "/sp Edit system prompt", + "/t Tokenize context", + "/think Toggle reasoning mode", + "/tps Toggle tokens/second output", + "/x Exit", + ])) + continue + # Exit app case "/x": print_info("Exiting") @@ -116,10 +142,7 @@ def main(args): # Toggle multiline mode case "/mli": multiline = not multiline - if multiline: - print_info("Enabled multiline mode") - else: - print_info("Disabled multiline mode") + print_info(f"{en_dis[multiline]} multiline mode") continue # Clear context @@ -130,11 +153,14 @@ def main(args): # Toggle TPS case "/tps": - multiline = not multiline - if multiline: - print_info("Enabled tokens/second output") - else: - print_info("Disabled tokens/second output") + show_tps = not show_tps + print_info(f"{en_dis[show_tps]} tokens/second output") + continue + + # Toggle reasoning + case "/think": + think = not think + print_info(f"{en_dis[think]} reasoning mode") continue # Retry last response @@ -217,10 +243,10 @@ def main(args): # Tokenize context and trim from head if too long def get_input_ids(_prefix): - frm_context = prompt_format.format(system_prompt, context, args.think) + frm_context = prompt_format.format(system_prompt, context, think) if _prefix: frm_context += prefix - elif args.think and prompt_format.thinktag()[0] is not None: + elif think and prompt_format.thinktag()[0] is not None: frm_context += prompt_format.thinktag()[0] ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True) exp_len_ = ids_.shape[-1] + max_response_tokens + 1 @@ -246,7 +272,7 @@ def main(args): ctx_exceeded = False with ( KeyReader() as keyreader, - streamer_cm(args, bot_name, tt[0], tt[1], args.updates_per_second) as s + streamer_cm(args, bot_name, tt[0], tt[1], args.updates_per_second, think) as s ): if prefix: s.stream(prefix) diff --git a/examples/chat_console.py b/examples/chat_console.py index c88cae8..752dc67 100644 --- a/examples/chat_console.py +++ b/examples/chat_console.py @@ -19,10 +19,12 @@ col_info = "\u001b[32;1m" # Green col_sysprompt = "\u001b[37;1m" # Grey def print_error(text): - print(col_error + "\nError: " + col_default + text) + ftext = text.replace("\n", "\n ") + print(col_error + "\nError: " + col_default + ftext) def print_info(text): - print(col_info + "\nInfo: " + col_default + text) + ftext = text.replace("\n", "\n ") + print(col_info + "\nInfo: " + col_default + ftext) def read_input_console(args, user_name, multiline: bool): print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True) @@ -47,12 +49,10 @@ def read_input_ptk(args, user_name, multiline: bool, prefix: str = None): class Streamer_basic: - def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second): + def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think): self.all_text = "" self.args = args self.bot_name = bot_name - self.updates_per_second = updates_per_second - def __enter__(self): print() @@ -127,7 +127,7 @@ class MarkdownConsoleStream: return i class Streamer_rich: - def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second): + def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think): self.all_text = "" self.think_text = "" self.bot_name = bot_name @@ -139,6 +139,7 @@ class Streamer_rich: self.end_think_tag = end_think_tag self.updates_per_second = updates_per_second self.last_update = time.time() + self.think = think def begin(self): self.live = MarkdownConsoleStream() @@ -147,7 +148,7 @@ class Streamer_rich: self.is_live = True def __enter__(self): - if self.args.think and self.think_tag is not None: + if self.think and self.think_tag is not None: print() print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "") else: @@ -161,7 +162,7 @@ class Streamer_rich: self.live.__exit__(exc_type, exc_value, traceback) def stream(self, text: str, force: bool = False): - if self.args.think and self.think_tag is not None and not self.is_live: + if self.think and self.think_tag is not None and not self.is_live: print_text = text if not self.think_text: print_text = print_text.lstrip()