Files
exllamav3/examples/chat_console.py
2026-03-03 23:15:07 +01:00

289 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import sys, shutil
from rich.prompt import Prompt
from rich.markdown import Markdown
from rich.console import Console
from prompt_toolkit import prompt as ptk_prompt
from prompt_toolkit.formatted_text import ANSI
import time, os
# ANSI escape sequences (SGR colour codes) used to colour console output.
# Every colour below is the ESC character followed by "[<codes>m".
ESC = "\x1b"
col_default = f"{ESC}[0m"  # Reset all attributes
col_user = f"{ESC}[33;1m"  # Bold yellow
col_bot = f"{ESC}[34;1m"  # Bold blue
col_think1 = f"{ESC}[35;1m"  # Bold magenta
col_think2 = f"{ESC}[35m"  # Magenta
col_error = f"{ESC}[31;1m"  # Bold red
col_info = f"{ESC}[32;1m"  # Bold green
col_sysprompt = f"{ESC}[37;1m"  # Bold white
def print_error(text):
    """Print `text` to stdout behind a coloured "Error: " label.

    The label is preceded by a newline, and every newline inside `text` is
    followed by one extra space so continuation lines are offset from the
    left margin.
    """
    offset_text = "\n ".join(text.split("\n"))
    print(f"{col_error}\nError: {col_default}{offset_text}")
def print_info(text):
    """Print `text` to stdout behind a coloured "Info: " label.

    The label is preceded by a newline, and every newline inside `text` is
    followed by one extra space so continuation lines are offset from the
    left margin.
    """
    offset_text = "\n ".join(text.split("\n"))
    print(f"{col_info}\nInfo: {col_default}{offset_text}")
def read_input_console(args, user_name, multiline: bool):
    """Show a coloured "<user_name>: " prompt and read the reply from stdin.

    Args:
        args: Not used by this function.
        user_name: Label printed in front of the prompt.
        multiline: When true, read stdin until EOF and strip trailing
            whitespace only; otherwise read a single line and strip both ends.

    Returns:
        The text entered by the user.
    """
    print(f"\n{col_user}{user_name}: {col_default}", end="", flush=True)
    if not multiline:
        return input().strip()
    return sys.stdin.read().rstrip()
def read_input_rich(args, user_name, multiline: bool):
    """Ask for user input through Rich's `Prompt.ask`.

    `args` and `multiline` are accepted but not used here.

    Returns:
        Whatever `Prompt.ask` returns for the coloured `user_name` prompt.
    """
    label = f"\n{col_user}{user_name}{col_default}"
    return Prompt.ask(label)
def read_input_ptk(args, user_name, multiline: bool, prefix: str = None):
    """Ask for user input through prompt_toolkit.

    Args:
        args: Not used by this function.
        user_name: Label shown (coloured) in front of the input field.
        multiline: Forwarded to prompt_toolkit's `multiline` option.
        prefix: Optional text pre-filled into the input field; a falsy value
            means the field starts empty.

    Returns:
        The value returned by prompt_toolkit's `prompt`.
    """
    print()
    initial_text = prefix if prefix else ""
    message = ANSI(f"{col_user}{user_name}{col_default}: ")
    return ptk_prompt(message, multiline=multiline, default=initial_text)
class Streamer_basic:
    """Plain streamer that writes generated text straight to stdout.

    Used as a context manager: entering prints the coloured bot label,
    `stream()` prints each chunk as it arrives, and leaving makes sure the
    output ends with a line break.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        # think_tag, end_think_tag, updates_per_second and think are accepted
        # but not used by this class.
        self.args = args
        self.bot_name = bot_name
        # Everything passed to stream() so far, unmodified
        self.all_text = ""

    def __enter__(self):
        print()
        print(f"{col_bot}{self.bot_name}: {col_default}", end="")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Finish the line unless the streamed text already ended with one
        if self.all_text.endswith("\n"):
            return
        print()

    def stream(self, text: str):
        """Record `text` and print it immediately.

        A single leading space is dropped from the printed output (not from
        `all_text`) while nothing has been accumulated yet.
        """
        nothing_streamed_yet = not self.all_text
        shown = text[1:] if nothing_streamed_yet and text.startswith(" ") else text
        self.all_text += text
        print(shown, end="", flush=True)
class MarkdownConsoleStream:
    """Incrementally redraw a growing Markdown document in the terminal.

    Each call to `update()` renders the full Markdown text with Rich, compares
    the rendered lines with the previous render, moves the cursor up to the
    first line that differs and rewrites from there. At most `height` trailing
    lines are ever rewritten, since the cursor cannot reach lines that have
    scrolled out of view.
    """

    def __init__(self, console: Console = None):
        # Make the Rich console a little narrower to prevent overflows from extra-wide emojis
        columns, rows = shutil.get_terminal_size(fallback = (80, 24))
        self.console = console or Console(emoji_variant = "text", width = columns - 2)
        # Maximum number of trailing lines update() will move back over
        self.height = rows - 2
        # Rendered lines from the previous update() call
        self._last_lines = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Never suppress exceptions
        return False

    def update(self, markdown_text) -> None:
        """Render `markdown_text` and repaint only the lines that changed."""
        new_lines = self._render_to_lines(markdown_text)
        old_lines = self._last_lines

        prefix_length = self._common_prefix_length(old_lines, new_lines)
        # Clamp so that no more than `height` lines are rewound or rewritten
        prefix_length = max(prefix_length, len(old_lines) - self.height, len(new_lines) - self.height)
        old_suffix_len = len(old_lines) - prefix_length
        new_suffix_len = len(new_lines) - prefix_length

        # Move the cursor up to the first line that needs repainting
        if old_suffix_len > 0:
            print(f"{ESC}[{old_suffix_len}A", end = "")

        for line in new_lines[prefix_length:]:
            print(f"{ESC}[2K", end = "")  # Clear entire line
            print(line.rstrip())

        # The new render is shorter than the old one: erase everything from
        # the cursor to the end of the screen so leftover lines of the
        # previous render do not stay visible. The cursor is not moved, so it
        # remains directly below the last new line for the next update.
        if old_suffix_len > max(new_suffix_len, 0):
            print(f"{ESC}[J", end = "", flush = True)

        self._last_lines = new_lines

    def _render_to_lines(self, markdown_text: str):
        """Return Rich's rendering of `markdown_text` as a list of lines.

        Trailing whitespace is stripped from each line and runs of consecutive
        blank lines are collapsed into a single blank line.
        """
        # Capture Rich's output to a string, then split by lines.
        with self.console.capture() as cap:
            self.console.print(Markdown(markdown_text))
        rendered = cap.get()

        lines = []
        for raw_line in rendered.rstrip("\n").split("\n"):
            line = raw_line.rstrip()
            # Keep a blank line only when the previous kept line is not blank
            if line or not lines or lines[-1]:
                lines.append(line)
        return lines

    @staticmethod
    def _common_prefix_length(a, b) -> int:
        """Return how many leading items of `a` and `b` are equal."""
        count = 0
        for x, y in zip(a, b):
            if x != y:
                break
            count += 1
        return count
class Streamer_rich:
    """Streamer that shows the response as live-rendered Markdown.

    When `think` is true and a `think_tag` is given, output starts in a
    "thinking" phase: chunks are printed as plain coloured text until
    `end_think_tag` shows up in the accumulated thinking text, after which the
    live Markdown view is started. Otherwise the live view starts immediately
    on entering the context.

    In the live phase the whole accumulated text is re-rendered through
    `MarkdownConsoleStream`, at most `updates_per_second` times per second
    unless an update is forced.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        self.all_text = ""      # Raw text received during the live phase
        self.think_text = ""    # Text received during the thinking phase
        self.bot_name = bot_name
        # Text handed to the Markdown renderer, starting with the bot label
        self.all_print_text = col_bot + self.bot_name + col_default + ": "
        self.args = args
        self.live = None
        self.is_live = False
        self.think_tag = think_tag
        self.end_think_tag = end_think_tag
        self.updates_per_second = updates_per_second
        self.last_update = time.time()
        self.think = think

    def begin(self):
        """Start the live Markdown view and draw the bot label."""
        self.live = MarkdownConsoleStream()
        self.live.__enter__()
        self.live.update(self.all_print_text)
        self.is_live = True

    def __enter__(self):
        if self.think and self.think_tag is not None:
            print()
            print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
        else:
            print()
            self.begin()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Force a final redraw so throttled text is not left unrendered
        self.stream("", True)
        if self.is_live:
            self.live.__exit__(exc_type, exc_value, traceback)
        else:
            # Still in the thinking phase: do not leave the terminal in the
            # thinking colour that __enter__ switched on.
            print(col_default, end = "", flush = True)

    def stream(self, text: str, force: bool = False):
        """Consume one chunk of generated text.

        Args:
            text: The new chunk.
            force: Redraw the live view even if the update interval has not
                elapsed yet. Has no effect during the thinking phase.
        """
        if self.think and self.think_tag is not None and not self.is_live:
            print_text = text
            if not self.think_text:
                print_text = print_text.lstrip()
            self.think_text += print_text
            if self.end_think_tag is not None and self.end_think_tag in self.think_text:
                # Thinking is over: finish the line, reset the thinking colour
                # and switch to the live Markdown view.
                print(print_text.rstrip() + col_default, flush = True)
                print()
                self.begin()
            else:
                print(print_text, end = "", flush = True)
            return

        print_text = text
        if not self.all_text.strip():
            # Nothing visible has been received yet: drop leading whitespace
            print_text = print_text.lstrip()
            # A code fence directly after the bot label needs its own line
            if print_text.startswith("```"):
                print_text = "\n" + print_text
        self.all_text += text
        self.all_print_text += print_text

        # Wrap the tags in backticks so the Markdown renderer shows them
        # literally. end_think_tag may be None even when think_tag is set, so
        # each tag is checked on its own.
        formatted_text = self.all_print_text
        if self.think_tag is not None:
            formatted_text = formatted_text.replace(self.think_tag, f"`{self.think_tag}`")
        if self.end_think_tag is not None:
            formatted_text = formatted_text.replace(self.end_think_tag, f"`{self.end_think_tag}`")

        now = time.time()
        if now - self.last_update > 1.0 / self.updates_per_second or force:
            self.last_update = now
            self.live.update(formatted_text)
class KeyReader:
    """
    Cross-platform, non-blocking key reader.
    Usage:
        with KeyReader() as keys:
            k = keys.getkey()         # non-blocking (returns None if no key)
            k = keys.getkey(timeout)  # wait up to `timeout` seconds

    On POSIX the terminal is switched to cbreak mode for the duration of the
    `with` block and restored afterwards. If stdin cannot be switched (it is
    not a terminal, or has no usable file descriptor), the reader disables
    itself and `getkey()` always returns None.
    """

    def __init__(self):
        self._platform = 'nt' if os.name == 'nt' else 'posix'
        self._entered = False
        # POSIX fields
        self._old_termios = None
        # Set when the terminal could not be put into cbreak mode
        self.disabled = False

    def __enter__(self):
        if self._platform == 'posix':
            import termios, tty
            self._termios = termios
            self._tty = tty
            try:
                self._fd = sys.stdin.fileno()
                self._old_termios = termios.tcgetattr(self._fd)
                # cbreak mode lets us read single characters without Enter
                tty.setcbreak(self._fd)
            except (termios.error, OSError, ValueError, AttributeError):
                # termios.error: stdin is not a terminal (e.g. a pipe).
                # OSError/ValueError: stdin has no usable file descriptor
                # (replaced or closed stream); AttributeError: stdin is None.
                self.disabled = True
                return self
        self._entered = True
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the terminal settings saved in __enter__
        if not self.disabled and self._platform == 'posix' and self._old_termios:
            self._termios.tcsetattr(self._fd, self._termios.TCSADRAIN, self._old_termios)
        self._entered = False

    def getkey(self, timeout=0.0):
        """
        Returns a single-character lowercase string if a key is pressed, or None.
        `timeout` is seconds to wait (float). 0.0 => non-blocking poll;
        None => wait until a key is available.

        Raises RuntimeError when called outside the `with` block (unless the
        reader is disabled, in which case None is returned).
        """
        if self.disabled:
            return None
        if not self._entered:
            raise RuntimeError("Use KeyReader as a context manager")

        if self._platform == 'nt':
            import msvcrt
            # busy-wait with tiny sleeps if a timeout is requested
            deadline = None if timeout is None else (time.monotonic() + timeout)
            while True:
                if msvcrt.kbhit():
                    b = msvcrt.getch()
                    # Handle special keys (arrows, function keys) which come as a prefix + code
                    if b in (b'\x00', b'\xe0'):  # special prefix
                        _ = msvcrt.getch()  # consume the second byte
                        return None  # ignore special keys for simplicity
                    # Undecodable bytes are dropped; report them as "no key",
                    # matching the POSIX branch
                    return b.decode('utf-8', errors='ignore').lower() or None
                if timeout == 0.0:
                    return None
                if deadline is not None and time.monotonic() >= deadline:
                    return None
                time.sleep(0.005)

        import select
        readable, _, _ = select.select([sys.stdin], [], [], timeout)
        if not readable:
            return None
        try:
            ch = sys.stdin.read(1)
        except (IOError, OSError):
            return None
        return (ch or "").lower() or None
def print_tokens(
    ids: list,
    vocab: list,
    ids_per_line = 10,
):
    """Print token IDs next to their vocabulary strings in a coloured table.

    Each entry shows the ID (right-aligned to 6 columns) followed by the
    `repr` of `vocab[id]` without its surrounding quotes, padded or truncated
    to 10 columns. Spaces inside a token are shown as a visible marker and
    truncated tokens end in an ellipsis; both are highlighted. Every output
    row is prefixed with the position of its first token.

    Args:
        ids: Token IDs to print; each must be a valid index into `vocab`.
        vocab: Sequence mapping a token ID to its text.
        ids_per_line: Number of tokens per output row.
    """
    # NOTE(review): the marker characters were missing from the reviewed text
    # (the code replaced "" and truncated without a separator). They are
    # reconstructed here as U+2423 OPEN BOX and U+2026 HORIZONTAL ELLIPSIS and
    # written as escapes so they survive copy/paste -- confirm against the
    # intended glyphs.
    space_mark = "\u2423"
    ellipsis_mark = "\u2026"

    print()
    line = ""
    last_pos = len(ids) - 1
    for pos, t in enumerate(ids):
        # Spaces become a marker so they stay distinguishable from the
        # padding spaces of the table itself
        p = repr(vocab[t])[1:-1].replace(" ", space_mark)
        line += f"{col_user}{t:6}{col_default} "
        # Pad to 10 columns, or cut to 9 + ellipsis, always followed by a space
        line += f"{p:10} " if len(p) <= 10 else f"{p[:9]}{ellipsis_mark} "
        if (pos + 1) % ids_per_line == 0 or pos == last_pos:
            # Colour the markers only after padding, so the escape codes do
            # not count towards the column widths
            line = line.replace(space_mark, f"{col_bot}{space_mark}{col_default}")
            line = line.replace(ellipsis_mark, f"{col_error}{ellipsis_mark}{col_default}")
            ppos = pos // ids_per_line * ids_per_line
            print(f"{col_info}{ppos:6} {col_default}: {line}")
            line = ""