mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Support multiline input and simple <think></think> reasoning mode in chat.py
This commit is contained in:
@@ -7,14 +7,20 @@ from chat_templates import *
|
||||
from rich.prompt import Prompt
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from prompt_toolkit import prompt as ptk_prompt
|
||||
from prompt_toolkit.formatted_text import ANSI
|
||||
import torch
|
||||
import re
|
||||
|
||||
# ANSI color codes
|
||||
col_default = "\u001b[0m"
|
||||
col_user = "\u001b[33;1m" # Yellow
|
||||
col_bot = "\u001b[34;1m" # Blue
|
||||
col_error = "\u001b[31;1m" # Magenta
|
||||
col_think1 = "\u001b[35;1m" # Bright magenta
|
||||
col_think2 = "\u001b[35m" # Magenta
|
||||
col_error = "\u001b[31;1m" # Bright red
|
||||
col_sysprompt = "\u001b[37;1m" # Grey
|
||||
thinktag = ("<think>", "</think>")
|
||||
|
||||
def read_input_console(args, user_name):
|
||||
print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
|
||||
@@ -28,13 +34,20 @@ def read_input_rich(args, user_name):
|
||||
user_prompt = Prompt.ask("\n" + col_user + user_name + col_default)
|
||||
return user_prompt
|
||||
|
||||
def read_input_ptk(args, user_name):
|
||||
print()
|
||||
user_prompt = ptk_prompt(ANSI(col_user + user_name + col_default + ": "), multiline = args.multiline)
|
||||
return user_prompt
|
||||
|
||||
class Streamer_basic:
|
||||
def __init__(self, args, bot_name):
|
||||
self.all_text = ""
|
||||
self.args = args
|
||||
self.bot_name = bot_name
|
||||
|
||||
def __enter__(self):
|
||||
print("\n" + col_bot + self.bot_name + ": " + col_default, end = "")
|
||||
print()
|
||||
print(col_bot + self.bot_name + ": " + col_default, end = "")
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
@@ -52,27 +65,51 @@ class Streamer_basic:
|
||||
class Streamer_rich:
|
||||
def __init__(self, args, bot_name):
|
||||
self.all_text = ""
|
||||
self.think_text = ""
|
||||
self.bot_name = bot_name
|
||||
self.all_print_text = col_bot + self.bot_name + col_default + ": "
|
||||
self.live = Live(refresh_per_second = args.refresh_per_second, vertical_overflow = "visible")
|
||||
self.args = args
|
||||
self.live = None
|
||||
self.is_live = False
|
||||
|
||||
def __enter__(self):
|
||||
print()
|
||||
def begin(self):
|
||||
self.live = Live(refresh_per_second = self.args.refresh_per_second, vertical_overflow = "visible")
|
||||
self.live.__enter__()
|
||||
self.live.update(Markdown(self.all_print_text))
|
||||
self.is_live = True
|
||||
|
||||
def __enter__(self):
|
||||
if self.args.think:
|
||||
print()
|
||||
print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
|
||||
else:
|
||||
print()
|
||||
self.begin()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.live.__exit__(exc_type, exc_value, traceback)
|
||||
if self.is_live:
|
||||
self.live.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def stream(self, text: str):
|
||||
if self.all_text or not text.startswith(" "):
|
||||
print_text = text
|
||||
if self.args.think and not self.is_live:
|
||||
if self.think_text or not text.startswith(" "):
|
||||
print_text = text
|
||||
else:
|
||||
print_text = text[1:]
|
||||
self.think_text += print_text
|
||||
print(print_text, end = "", flush = True)
|
||||
if thinktag[1] in self.think_text:
|
||||
self.begin()
|
||||
else:
|
||||
print_text = text[1:]
|
||||
self.all_text += text
|
||||
self.all_print_text += print_text
|
||||
self.live.update(Markdown(self.all_print_text))
|
||||
if self.all_text or not text.startswith(" "):
|
||||
print_text = text
|
||||
else:
|
||||
print_text = text[1:]
|
||||
self.all_text += text
|
||||
self.all_print_text += print_text
|
||||
formatted_text = self.all_print_text
|
||||
self.live.update(Markdown(formatted_text))
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(args):
|
||||
@@ -92,10 +129,10 @@ def main(args):
|
||||
max_response_tokens = args.max_response_tokens
|
||||
|
||||
if args.basic_console:
|
||||
read_input_fn = read_input_console
|
||||
read_input_fn = read_input_ptk
|
||||
streamer_cm = Streamer_basic
|
||||
else:
|
||||
read_input_fn = read_input_rich
|
||||
read_input_fn = read_input_ptk
|
||||
streamer_cm = Streamer_rich
|
||||
|
||||
# Load model
|
||||
@@ -123,6 +160,8 @@ def main(args):
|
||||
# Tokenize context and trim from head if too long
|
||||
def get_input_ids():
|
||||
frm_context = prompt_format.format(system_prompt, context)
|
||||
if args.think:
|
||||
frm_context += thinktag[0]
|
||||
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
|
||||
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
|
||||
return ids_, exp_len_
|
||||
@@ -158,7 +197,9 @@ def main(args):
|
||||
)
|
||||
|
||||
# Add response to context
|
||||
context[-1] = (user_prompt, s.all_text.strip())
|
||||
response = s.all_text.strip()
|
||||
|
||||
context[-1] = (user_prompt, response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -168,11 +209,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
|
||||
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
|
||||
parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
|
||||
parser.add_argument("-mli", "--mli", action = "store_true", help = "Enable multi line input")
|
||||
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
|
||||
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
|
||||
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
|
||||
parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input")
|
||||
parser.add_argument("-rps", "--refresh_per_second", type = int, help = "Max updates per second in Markdown mode, default = 25", default = 25)
|
||||
parser.add_argument("-think", "--think", action = "store_true", help = "Use (very simplistic) reasoning template and formatting")
|
||||
# TODO: Sampling options
|
||||
_args = parser.parse_args()
|
||||
main(_args)
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
blessed
|
||||
prompt_toolkit
|
||||
Reference in New Issue
Block a user