Support multiline input and simple <think></think> reasoning mode in chat.py

This commit is contained in:
turboderp
2025-04-13 20:59:28 +02:00
parent 1f7c1f709f
commit 3129b748ff
2 changed files with 59 additions and 16 deletions

View File

@@ -7,14 +7,20 @@ from chat_templates import *
from rich.prompt import Prompt
from rich.live import Live
from rich.markdown import Markdown
from prompt_toolkit import prompt as ptk_prompt
from prompt_toolkit.formatted_text import ANSI
import torch
import re
# ANSI color codes
col_default = "\u001b[0m"
col_user = "\u001b[33;1m" # Yellow
col_bot = "\u001b[34;1m" # Blue
col_error = "\u001b[31;1m" # Magenta
col_think1 = "\u001b[35;1m" # Bright magenta
col_think2 = "\u001b[35m" # Magenta
col_error = "\u001b[31;1m" # Bright red
col_sysprompt = "\u001b[37;1m" # Grey
thinktag = ("<think>", "</think>")
def read_input_console(args, user_name):
print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
@@ -28,13 +34,20 @@ def read_input_rich(args, user_name):
user_prompt = Prompt.ask("\n" + col_user + user_name + col_default)
return user_prompt
def read_input_ptk(args, user_name):
print()
user_prompt = ptk_prompt(ANSI(col_user + user_name + col_default + ": "), multiline = args.multiline)
return user_prompt
class Streamer_basic:
def __init__(self, args, bot_name):
self.all_text = ""
self.args = args
self.bot_name = bot_name
def __enter__(self):
print("\n" + col_bot + self.bot_name + ": " + col_default, end = "")
print()
print(col_bot + self.bot_name + ": " + col_default, end = "")
return self
def __exit__(self, exc_type, exc_value, traceback):
@@ -52,27 +65,51 @@ class Streamer_basic:
class Streamer_rich:
def __init__(self, args, bot_name):
self.all_text = ""
self.think_text = ""
self.bot_name = bot_name
self.all_print_text = col_bot + self.bot_name + col_default + ": "
self.live = Live(refresh_per_second = args.refresh_per_second, vertical_overflow = "visible")
self.args = args
self.live = None
self.is_live = False
def __enter__(self):
print()
def begin(self):
self.live = Live(refresh_per_second = self.args.refresh_per_second, vertical_overflow = "visible")
self.live.__enter__()
self.live.update(Markdown(self.all_print_text))
self.is_live = True
def __enter__(self):
if self.args.think:
print()
print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
else:
print()
self.begin()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.live.__exit__(exc_type, exc_value, traceback)
if self.is_live:
self.live.__exit__(exc_type, exc_value, traceback)
def stream(self, text: str):
if self.all_text or not text.startswith(" "):
print_text = text
if self.args.think and not self.is_live:
if self.think_text or not text.startswith(" "):
print_text = text
else:
print_text = text[1:]
self.think_text += print_text
print(print_text, end = "", flush = True)
if thinktag[1] in self.think_text:
self.begin()
else:
print_text = text[1:]
self.all_text += text
self.all_print_text += print_text
self.live.update(Markdown(self.all_print_text))
if self.all_text or not text.startswith(" "):
print_text = text
else:
print_text = text[1:]
self.all_text += text
self.all_print_text += print_text
formatted_text = self.all_print_text
self.live.update(Markdown(formatted_text))
@torch.inference_mode()
def main(args):
@@ -92,10 +129,10 @@ def main(args):
max_response_tokens = args.max_response_tokens
if args.basic_console:
read_input_fn = read_input_console
read_input_fn = read_input_ptk
streamer_cm = Streamer_basic
else:
read_input_fn = read_input_rich
read_input_fn = read_input_ptk
streamer_cm = Streamer_rich
# Load model
@@ -123,6 +160,8 @@ def main(args):
# Tokenize context and trim from head if too long
def get_input_ids():
frm_context = prompt_format.format(system_prompt, context)
if args.think:
frm_context += thinktag[0]
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
return ids_, exp_len_
@@ -158,7 +197,9 @@ def main(args):
)
# Add response to context
context[-1] = (user_prompt, s.all_text.strip())
response = s.all_text.strip()
context[-1] = (user_prompt, response)
if __name__ == "__main__":
@@ -168,11 +209,12 @@ if __name__ == "__main__":
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
parser.add_argument("-mli", "--mli", action = "store_true", help = "Enable multi line input")
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input")
parser.add_argument("-rps", "--refresh_per_second", type = int, help = "Max updates per second in Markdown mode, default = 25", default = 25)
parser.add_argument("-think", "--think", action = "store_true", help = "Use (very simplistic) reasoning template and formatting")
# TODO: Sampling options
_args = parser.parse_args()
main(_args)

View File

@@ -1 +1,2 @@
blessed
prompt_toolkit