Files
exllamav3/examples/chat_console.py
2025-06-06 03:43:41 +02:00

176 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys, shutil
from rich.prompt import Prompt
from rich.live import Live
from rich.markdown import Markdown
from rich.console import Console
from prompt_toolkit import prompt as ptk_prompt
from prompt_toolkit.formatted_text import ANSI
# ANSI escape sequences (SGR colour/attribute codes) used to colour console output.
# In the codes below, "1" is the bold attribute, which most terminals render as a
# brighter shade of the selected colour.
ESC = "\u001b"  # The escape character that introduces every control sequence
col_default = "\u001b[0m"  # Reset all attributes back to the terminal default
col_user = "\u001b[33;1m" # Bold/bright yellow
col_bot = "\u001b[34;1m" # Bold/bright blue
col_think1 = "\u001b[35;1m" # Bold/bright magenta
col_think2 = "\u001b[35m" # Magenta
col_error = "\u001b[31;1m" # Bold/bright red
col_info = "\u001b[32;1m" # Bold/bright green
col_sysprompt = "\u001b[37;1m" # Bold white (may look light grey, depending on the terminal palette)
def print_error(text):
    """Print ``text`` to stdout behind a coloured "Error: " label.

    The label is drawn in ``col_error`` and starts on a new line; the colour
    is reset with ``col_default`` before the message itself is written.
    """
    label = col_error + "\nError: " + col_default
    print(label + text)
def print_info(text):
    """Print ``text`` to stdout behind a coloured "Info: " label.

    The label is drawn in ``col_info`` and starts on a new line; the colour
    is reset with ``col_default`` before the message itself is written.
    """
    label = col_info + "\nInfo: " + col_default
    print(label + text)
def read_input_console(args, user_name, multiline: bool):
    """Show a coloured prompt on stdout and read the reply from standard input.

    Args:
        args: Not used by this function; the other ``read_input_*`` functions
            take the same three parameters.
        user_name: Label printed (in ``col_user``) before the colon.
        multiline: When true, everything up to end-of-file is read from
            ``sys.stdin`` and only trailing whitespace is removed. Otherwise a
            single line is read with ``input()`` and stripped on both sides.

    Returns:
        The text that was entered.
    """
    label = "\n" + col_user + user_name + ": " + col_default
    print(label, end="", flush=True)

    if not multiline:
        return input().strip()
    return sys.stdin.read().rstrip()
def read_input_rich(args, user_name, multiline: bool):
    """Ask for one user message through Rich's ``Prompt.ask``.

    ``args`` and ``multiline`` are accepted but not used here. The prompt
    label is ``user_name`` in ``col_user`` on a new line.

    Returns:
        Whatever ``Prompt.ask`` returns for the entered text.
    """
    label = "\n" + col_user + user_name + col_default
    return Prompt.ask(label)
def read_input_ptk(args, user_name, multiline: bool):
    """Ask for one user message through prompt_toolkit's ``prompt``.

    A blank line is printed first, then the prompt is shown with ``user_name``
    in ``col_user`` followed by ": ". ``multiline`` is forwarded to
    prompt_toolkit unchanged; ``args`` is not used here.

    Returns:
        Whatever prompt_toolkit's ``prompt`` returns for the entered text.
    """
    print()
    label = ANSI(col_user + user_name + col_default + ": ")
    return ptk_prompt(label, multiline=multiline)
class Streamer_basic:
    """Context manager that echoes streamed response text to stdout verbatim.

    Entering prints the bot's name as a coloured label; every ``stream`` call
    then writes its chunk straight to the terminal. ``all_text`` accumulates
    the unmodified chunks.
    """

    def __init__(self, args, bot_name):
        self.args = args
        self.bot_name = bot_name
        self.all_text = ""

    def __enter__(self):
        # Blank line, then "<bot name>: " with the cursor left on that line.
        print()
        header = col_bot + self.bot_name + ": " + col_default
        print(header, end="")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Terminate the last line unless the response already did so.
        if self.all_text.endswith("\n"):
            return
        print()

    def stream(self, text: str, think_tag, end_think_tag):
        """Write one chunk of response text to stdout.

        A single leading space is dropped from the display of the first
        non-empty chunk only; ``all_text`` always receives ``text`` unchanged.
        ``think_tag`` and ``end_think_tag`` are accepted but not used here.
        """
        nothing_shown_yet = not self.all_text
        if nothing_shown_yet and text.startswith(" "):
            visible = text[1:]
        else:
            visible = text
        self.all_text += text
        print(visible, end="", flush=True)
class MarkdownConsoleStream:
    """Redraw a growing Markdown document in place on the terminal.

    Each ``update`` renders the full Markdown text with Rich, compares the
    rendered lines with those of the previous call and rewrites only the
    lines after the common prefix, using ANSI cursor-up / erase sequences.
    """

    def __init__(self, console: "Console | None" = None):
        """Create the stream.

        Args:
            console: Rich console to render with. When omitted, a console is
                created that is two columns narrower than the terminal.
        """
        # Make the Rich console a little narrower to prevent overflows from extra-wide emojis
        columns, rows = shutil.get_terminal_size(fallback=(80, 24))
        if console is None:
            console = Console(emoji_variant="text", width=columns - 2)
        self.console = console
        # Maximum number of trailing lines a single update will rewrite.
        self.height = rows - 2
        self._last_lines = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def update(self, markdown_text) -> None:
        """Render ``markdown_text`` and bring the terminal in line with it.

        The cursor is expected to sit at the start of the line directly below
        the previously drawn output, which is where every call leaves it.
        """
        new_lines = self._render_to_lines(markdown_text)
        old_lines = self._last_lines

        keep = self._common_prefix_length(old_lines, new_lines)
        # Never rewrite more than `height` lines of either rendering
        # (presumably because anything further up has scrolled out of reach
        # of the cursor-up sequence).
        keep = max(keep, len(old_lines) - self.height, len(new_lines) - self.height)

        stale_count = len(old_lines) - keep
        # Slicing yields an empty list when `keep` exceeds len(new_lines).
        fresh_lines = new_lines[keep:]

        if stale_count > 0:
            # Move the cursor up to the first line that has to change.
            print(f"{ESC}[{stale_count}A", end = "")

        for line in fresh_lines:
            print(f"{ESC}[2K", end = "")  # Clear entire line
            print(line.rstrip())

        if stale_count > len(fresh_lines):
            # The new rendering is shorter than the old one: wipe everything
            # from the cursor to the end of the display so that no leftover
            # rows of the previous rendering remain. The cursor does not move,
            # so it stays directly below the new output. Flushed explicitly
            # because the sequence carries no newline.
            print(f"{ESC}[0J", end = "", flush = True)

        self._last_lines = new_lines

    def _render_to_lines(self, markdown_text: str):
        """Return the Rich rendering of ``markdown_text`` as a list of lines.

        Trailing whitespace is removed from every line and runs of blank
        lines are collapsed into a single blank line.
        """
        # Capture Rich's output to a string, then split by lines.
        with self.console.capture() as cap:
            self.console.print(Markdown(markdown_text))
        rendered = cap.get()

        lines = []
        for raw in rendered.rstrip("\n").split("\n"):
            line = raw.rstrip()
            # Skip a blank line that directly follows another blank line.
            if line or not lines or lines[-1]:
                lines.append(line)
        return lines

    @staticmethod
    def _common_prefix_length(a, b) -> int:
        """Return how many leading items of ``a`` and ``b`` are equal."""
        count = 0
        for x, y in zip(a, b):
            if x != y:
                break
            count += 1
        return count
class Streamer_rich:
    """Context manager that shows a streamed response as live Markdown.

    When ``args.think`` is set, output starts in a plain-text "Thinking"
    phase that is printed directly to stdout; once ``end_think_tag`` has been
    seen, a ``MarkdownConsoleStream`` is started and the rest of the response
    is rendered through it. Without ``args.think`` the Markdown stream is
    started immediately.

    Attributes set by ``__init__``:
        all_text: Raw answer text received so far (thinking text excluded).
        think_text: Thinking text received so far, up to and including the
            end tag once it has arrived.
        all_print_text: Bot-name label plus the answer text as it is handed
            to the Markdown renderer.
        live: The active ``MarkdownConsoleStream``, or ``None``.
        is_live: Whether the Markdown stream has been started.
    """

    def __init__(self, args, bot_name):
        self.all_text = ""
        self.think_text = ""
        self.bot_name = bot_name
        self.all_print_text = col_bot + self.bot_name + col_default + ": "
        self.args = args
        self.live = None
        self.is_live = False

    def begin(self):
        """Start the Markdown stream and draw the bot-name label."""
        self.live = MarkdownConsoleStream()
        self.live.__enter__()
        self.live.update(self.all_print_text)
        self.is_live = True

    def __enter__(self):
        print()
        if self.args.think:
            # Thinking text follows on the same line, in the thinking colour.
            print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
        else:
            self.begin()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.is_live:
            self.live.__exit__(exc_type, exc_value, traceback)
        elif self.args.think:
            # The end-of-thinking tag never arrived (e.g. generation was cut
            # short): the thinking colour is still active and the cursor is
            # mid-line, so reset the colour and finish the line.
            print(col_default, flush = True)

    def stream(self, text: str, think_tag: str, end_think_tag: str):
        """Display one chunk of generated text.

        Args:
            text: The new chunk.
            think_tag: Opening tag of a thinking section; in the answer it is
                wrapped in backticks for display.
            end_think_tag: Closing tag of a thinking section; its arrival
                ends the thinking phase.
        """
        if self.args.think and not self.is_live:
            self._stream_thinking(text, think_tag, end_think_tag)
        else:
            self._stream_answer(text, think_tag, end_think_tag)

    def _stream_thinking(self, text: str, think_tag: str, end_think_tag: str):
        """Print a chunk of thinking text; switch to Markdown at the end tag."""
        # Leading whitespace is dropped until some thinking text has arrived.
        chunk = text if self.think_text else text.lstrip()
        self.think_text += chunk

        if end_think_tag not in self.think_text:
            print(chunk, end = "", flush = True)
            return

        # The tag was completed by this chunk (it may have started in an
        # earlier one). Anything after it already belongs to the answer and
        # must not be printed as thinking text.
        cut = self.think_text.index(end_think_tag) + len(end_think_tag)
        tail = self.think_text[cut:]
        head = chunk[:len(chunk) - len(tail)]
        self.think_text = self.think_text[:cut]

        print(head.rstrip(), flush = True)
        print()
        self.begin()
        if tail:
            self._stream_answer(tail, think_tag, end_think_tag)

    def _stream_answer(self, text: str, think_tag: str, end_think_tag: str):
        """Append a chunk of answer text and re-render the Markdown output."""
        chunk = text
        if not self.all_text.strip():
            # Nothing visible yet: drop leading whitespace, and put an opening
            # code fence on its own line, below the bot-name label.
            chunk = chunk.lstrip()
            if chunk.startswith("```"):
                chunk = "\n" + chunk
        self.all_text += text
        self.all_print_text += chunk

        # Show literal think tags as inline code. Empty/None tags are skipped:
        # replacing "" would insert backticks between all characters.
        shown = self.all_print_text
        for tag in (think_tag, end_think_tag):
            if tag:
                shown = shown.replace(tag, f"`{tag}`")
        self.live.update(shown)