mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
289 lines
10 KiB
Python
289 lines
10 KiB
Python
import sys, shutil
|
||
|
||
from rich.prompt import Prompt
|
||
from rich.markdown import Markdown
|
||
from rich.console import Console
|
||
from prompt_toolkit import prompt as ptk_prompt
|
||
from prompt_toolkit.formatted_text import ANSI
|
||
import time, os
|
||
|
||
# ANSI escape sequences used to colour console output.
# A ";1" parameter adds the bold/bright attribute to the base colour.
ESC = "\u001b"
col_default = ESC + "[0m"  # Reset all attributes
col_user = ESC + "[33;1m"  # Bright yellow
col_bot = ESC + "[34;1m"  # Bright blue
col_think1 = ESC + "[35;1m"  # Bright magenta
col_think2 = ESC + "[35m"  # Magenta
col_error = ESC + "[31;1m"  # Bright red
col_info = ESC + "[32;1m"  # Bright green
col_sysprompt = ESC + "[37;1m"  # Bright white
|
||
|
||
def print_error(text):
    """
    Print `text` to stdout behind a coloured "Error: " label.

    The label is preceded by a blank line, and every newline inside `text` is followed by one
    extra space so continuation lines are offset from the left margin.
    """
    body = "\n ".join(text.split("\n"))
    print(f"{col_error}\nError: {col_default}{body}")
|
||
|
||
def print_info(text):
    """
    Print `text` to stdout behind a coloured "Info: " label.

    The label is preceded by a blank line, and every newline inside `text` is followed by one
    extra space so continuation lines are offset from the left margin.
    """
    body = "\n ".join(text.split("\n"))
    print(f"{col_info}\nInfo: {col_default}{body}")
|
||
|
||
def read_input_console(args, user_name, multiline: bool):
    """
    Show a coloured "<user_name>: " prompt and read the user's reply from standard input.

    :param args: Not used by this reader.
    :param user_name: Label shown in front of the prompt.
    :param multiline: If True, read stdin until EOF and strip trailing whitespace only.
        Otherwise read one line with input() and strip whitespace from both ends.
    :return: The text that was entered.
    """
    label = "".join(("\n", col_user, user_name, ": ", col_default))
    print(label, end = '', flush = True)
    if not multiline:
        return input().strip()
    return sys.stdin.read().rstrip()
|
||
|
||
def read_input_rich(args, user_name, multiline: bool):
    """
    Ask for user input through Rich's Prompt, labelled with the coloured user name.

    :param args: Not used by this reader.
    :param user_name: Label shown in front of the prompt.
    :param multiline: Not used by this reader.
    :return: Whatever Prompt.ask() returns for the entered text.
    """
    label = "".join(("\n", col_user, user_name, col_default))
    return Prompt.ask(label)
|
||
|
||
def read_input_ptk(args, user_name, multiline: bool, prefix: str = None):
    """
    Ask for user input through prompt_toolkit, labelled with the coloured user name.

    :param args: Not used by this reader.
    :param user_name: Label shown in front of the prompt.
    :param multiline: Forwarded to prompt_toolkit's `multiline` option.
    :param prefix: Text pre-filled into the input buffer. None or an empty string means no
        pre-filled text.
    :return: The text returned by prompt_toolkit.
    """
    print()
    label = ANSI("".join((col_user, user_name, col_default, ": ")))
    initial_text = prefix if prefix else ""
    return ptk_prompt(label, multiline = multiline, default = initial_text)
|
||
|
||
class Streamer_basic:
    """
    Minimal streamer that writes response text straight to stdout as it arrives.

    Used as a context manager: entering prints the coloured bot name, stream() prints each
    chunk, and leaving makes sure the output ends with a newline.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        # think_tag, end_think_tag, updates_per_second and think are accepted but not used here
        self.args = args
        self.bot_name = bot_name
        # Everything passed to stream() so far, unmodified
        self.all_text = ""

    def __enter__(self):
        print()
        print("".join((col_bot, self.bot_name, ": ", col_default)), end = "")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Finish the line unless the streamed text already ended with a newline
        if self.all_text.endswith("\n"):
            return
        print()

    def stream(self, text: str):
        """
        Print one chunk of response text without a trailing newline.

        A single leading space is dropped from the printed output while nothing has been
        accumulated yet; `all_text` always receives the chunk unchanged.
        """
        nothing_streamed_yet = not self.all_text
        self.all_text += text
        if nothing_streamed_yet and text.startswith(" "):
            text = text[1:]
        print(text, end = "", flush = True)
|
||
|
||
class MarkdownConsoleStream:
    """
    Redraws a growing Markdown document in the terminal by rewriting only the changed lines.

    Each update() renders the full Markdown text with Rich, compares the result with the lines
    printed by the previous update, moves the cursor up over the lines that differ and prints
    the new versions in their place.
    """

    def __init__(self, console: Console = None):
        # Make the Rich console a little narrower to prevent overflows from extra-wide emojis
        columns, rows = shutil.get_terminal_size(fallback = (80, 24))
        self.console = console or Console(emoji_variant = "text", width = columns - 2)
        # Largest number of trailing lines that update() will move back over and rewrite
        self.height = rows - 2
        # Lines produced by the most recent update()
        self._last_lines = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

    def update(self, markdown_text) -> None:
        """
        Render `markdown_text` and rewrite the part of the terminal output that changed.
        """
        current = self._render_to_lines(markdown_text)
        previous = self._last_lines

        # Lines before `keep` are left alone: either they are unchanged, or they are more than
        # `height` lines above the end of the old/new output
        keep = max(
            self._common_prefix_length(previous, current),
            len(previous) - self.height,
            len(current) - self.height,
        )
        stale_count = len(previous) - keep
        fresh_count = len(current) - keep

        # Move the cursor up to the first line that has to be rewritten
        if stale_count > 0:
            print(f"{ESC}[{stale_count}A", end = "")

        for offset in range(max(stale_count, fresh_count)):
            # Clear entire line
            print(f"{ESC}[2K", end = "")
            if offset < fresh_count:
                print(current[keep + offset].rstrip())

        self._last_lines = current

    def _render_to_lines(self, markdown_text: str):
        """
        Render Markdown through the Rich console and return the output as a list of lines.

        Trailing whitespace is removed from every line and runs of consecutive empty lines are
        collapsed to a single empty line.
        """
        # Capture Rich's output to a string, then split by lines.
        with self.console.capture() as capture:
            self.console.print(Markdown(markdown_text))
        raw_lines = capture.get().rstrip("\n").split("\n")

        lines = []
        for raw in raw_lines:
            stripped = raw.rstrip()
            # Skip an empty line that directly follows another empty line
            if not stripped and lines and not lines[-1]:
                continue
            lines.append(stripped)
        return lines

    @staticmethod
    def _common_prefix_length(a, b) -> int:
        """
        Return the number of leading positions at which `a` and `b` hold equal items.
        """
        count = 0
        for left, right in zip(a, b):
            if left != right:
                return count
            count += 1
        return count
|
||
|
||
class Streamer_rich:
    """
    Streamer that renders the response as Markdown, with optional plain-text "thinking" output.

    Used as a context manager. When `think` is set and a `think_tag` is given, streamed text is
    first printed as plain coloured text until `end_think_tag` shows up in it; from then on
    (or from the start otherwise) text is accumulated and re-rendered through
    MarkdownConsoleStream at most `updates_per_second` times per second.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        """
        :param args: Stored on the instance; not otherwise used by this class.
        :param bot_name: Name shown in front of the rendered response.
        :param think_tag: Opening tag of the thinking section, or None.
        :param end_think_tag: Closing tag of the thinking section, or None.
        :param updates_per_second: Upper bound on Markdown re-renders per second.
        :param think: Whether the response is expected to begin with a thinking section.
        """
        # Raw text received in Markdown mode
        self.all_text = ""
        # Text received while in thinking mode
        self.think_text = ""
        self.bot_name = bot_name
        # Markdown source that gets rendered: name prefix plus (lightly adjusted) response text
        self.all_print_text = col_bot + self.bot_name + col_default + ": "
        self.args = args
        self.live = None
        self.is_live = False
        self.think_tag = think_tag
        self.end_think_tag = end_think_tag
        self.updates_per_second = updates_per_second
        self.last_update = time.time()
        self.think = think

    def begin(self):
        """
        Switch to Markdown mode: create the live renderer and draw the initial name prefix.
        """
        self.live = MarkdownConsoleStream()
        self.live.__enter__()
        self.live.update(self.all_print_text)
        self.is_live = True

    def __enter__(self):
        print()
        if self.think and self.think_tag is not None:
            # Thinking text is printed directly; Markdown rendering starts once it ends
            print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
        else:
            self.begin()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Force a final render so text held back by the rate limit becomes visible
        self.stream("", True)
        if self.is_live:
            self.live.__exit__(exc_type, exc_value, traceback)

    def stream(self, text: str, force: bool = False):
        """
        Consume one chunk of response text.

        :param text: The newly generated text.
        :param force: In Markdown mode, re-render even if the rate limit has not elapsed.
        """
        if self.think and self.think_tag is not None and not self.is_live:
            # Thinking mode: plain output until the end tag has been seen
            print_text = text
            if not self.think_text:
                print_text = print_text.lstrip()
            self.think_text += print_text
            if self.end_think_tag is not None and self.end_think_tag in self.think_text:
                print(print_text.rstrip(), flush = True)
                print()
                self.begin()
            else:
                print(print_text, end = "", flush = True)
            return

        # Markdown mode
        print_text = text
        if not self.all_text.strip():
            print_text = print_text.lstrip()
            # NOTE(review): nesting reconstructed from context. The response starts on the same
            # line as the "name: " prefix, so a leading code fence is moved to its own line.
            if print_text.startswith("```"):
                print_text = "\n" + print_text
        self.all_text += text
        self.all_print_text += print_text

        # Wrap the tags in backticks so they are displayed literally by the Markdown renderer
        formatted_text = self.all_print_text
        if self.think_tag is not None:
            formatted_text = formatted_text.replace(self.think_tag, f"`{self.think_tag}`")
            # The end tag may be absent even when a start tag is configured; str.replace()
            # would raise TypeError on None
            if self.end_think_tag is not None:
                formatted_text = formatted_text.replace(self.end_think_tag, f"`{self.end_think_tag}`")

        now = time.time()
        if force or now - self.last_update > 1.0 / self.updates_per_second:
            self.last_update = now
            self.live.update(formatted_text)
|
||
|
||
class KeyReader:
    """
    Cross-platform, non-blocking key reader.

    Usage:
        with KeyReader() as keys:
            k = keys.getkey()           # non-blocking (returns None if no key)
            k = keys.getkey(timeout)    # wait up to `timeout` seconds

    If the terminal cannot be put into single-key mode (e.g. stdin is not a terminal or has no
    file descriptor), the reader sets `disabled` and getkey() always returns None.
    """

    def __init__(self):
        self._platform = 'nt' if os.name == 'nt' else 'posix'
        self._entered = False
        # POSIX fields
        self._old_termios = None
        self.disabled = False

    def __enter__(self):
        if self._platform == 'posix':
            # Imported outside the try block so the except clause can always resolve
            # termios.error
            import termios, tty
            self._termios = termios
            self._tty = tty
            try:
                self._fd = sys.stdin.fileno()
                self._old_termios = termios.tcgetattr(self._fd)
                # cbreak mode lets us read single characters without Enter
                tty.setcbreak(self._fd)
            except (termios.error, OSError, ValueError):
                # termios.error: stdin is not a terminal.
                # OSError / ValueError (incl. io.UnsupportedOperation): stdin has no usable
                # file descriptor, e.g. it was replaced by an in-memory stream or closed.
                self.disabled = True
                return self
        self._entered = True
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the terminal settings that were saved in __enter__
        if not self.disabled and self._platform == 'posix' and self._old_termios:
            self._termios.tcsetattr(self._fd, self._termios.TCSADRAIN, self._old_termios)
        self._entered = False

    def getkey(self, timeout=0.0):
        """
        Returns a single-character string if a key is pressed, or None.

        `timeout` is seconds to wait (float). 0.0 => non-blocking poll.
        The returned character is lower-cased.

        Raises RuntimeError if called outside the `with` block (unless the reader is disabled).
        """
        if self.disabled:
            return None

        if not self._entered:
            raise RuntimeError("Use KeyReader as a context manager")

        if self._platform == 'nt':
            import msvcrt
            # busy-wait with tiny sleeps if a timeout is requested
            end = None if timeout is None else (time.time() + timeout)
            while True:
                if msvcrt.kbhit():
                    b = msvcrt.getch()
                    # Handle special keys (arrows, function keys) which come as a prefix + code
                    if b in (b'\x00', b'\xe0'):  # special prefix
                        _ = msvcrt.getch()  # consume the second byte
                        return None  # ignore special keys for simplicity
                    try:
                        return b.decode('utf-8', errors='ignore').lower()
                    except Exception:
                        return None
                if timeout == 0.0:
                    return None
                if end is not None and time.time() >= end:
                    return None
                time.sleep(0.005)
        else:
            import select
            r, _, _ = select.select([sys.stdin], [], [], timeout)
            if r:
                try:
                    ch = sys.stdin.read(1)
                except (IOError, OSError):
                    return None
                # An empty read (EOF) is reported as "no key"
                return (ch or "").lower() or None
            return None
|
||
|
||
def print_tokens(
    ids: list,
    vocab: list,
    ids_per_line = 10,
):
    """
    Print a colour-coded table of token IDs next to their vocabulary entries.

    Each cell shows the token ID (width 6) followed by the repr() of `vocab[id]` with its first
    and last character removed (presumably the quotes of a string repr) and spaces shown as "␣".
    Entries longer than 10 characters are cut to 9 characters plus "…". Every row starts with
    the position of its first token.

    :param ids: Token IDs to display.
    :param vocab: Lookup from token ID to token text.
    :param ids_per_line: Number of tokens per printed row.
    """
    print()
    final_pos = len(ids) - 1
    cells = []
    for pos, token_id in enumerate(ids):
        piece = repr(vocab[token_id])[1:-1].replace(" ", "␣")
        cells.append(f"{col_user}{token_id:6}{col_default} ")
        cells.append(f"{piece:10} " if len(piece) <= 10 else f"{piece[:9]}… ")

        # Keep collecting until the row is full or the last token has been reached
        if (pos + 1) % ids_per_line != 0 and pos != final_pos:
            continue

        row = "".join(cells)
        row = row.replace("␣", f"{col_bot}␣{col_default}").replace("…", f"{col_error}…{col_default}")
        row_start = pos // ids_per_line * ids_per_line
        print(f"{col_info}{row_start:6} {col_default}: {row}")
        cells = []