"""Console helpers: coloured prompts, streamed output, key polling and token tables."""

import os
import shutil
import sys
import time
import unicodedata

from prompt_toolkit import prompt as ptk_prompt
from prompt_toolkit.formatted_text import ANSI
from rich.console import Console
from rich.markdown import Markdown
from rich.prompt import Prompt

# ANSI codes
ESC = "\u001b"
col_default = "\u001b[0m"
col_user = "\u001b[33;1m"  # Bright yellow
col_bot = "\u001b[34;1m"  # Bright blue
col_think1 = "\u001b[35;1m"  # Bright magenta
col_think2 = "\u001b[35m"  # Magenta
col_error = "\u001b[31;1m"  # Bright red
col_info = "\u001b[32;1m"  # Bright green
col_sysprompt = "\u001b[37;1m"  # Bright white
col_dark = "\u001b[0;90m"  # Dark grey (bright black)
col_b = "\u001b[100m"  # Dark grey background, used to highlight the sampled token

LRO = '\u202D'  # Left-to-Right Override
PDF = '\u202C'  # Pop Directional Formatting


def print_error(text):
    """Print `text` prefixed with a red "Error: " label, indenting continuation lines."""
    ftext = text.replace("\n", "\n ")
    print(col_error + "\nError: " + col_default + ftext)


def print_info(text):
    """Print `text` prefixed with a green "Info: " label, indenting continuation lines."""
    ftext = text.replace("\n", "\n ")
    print(col_info + "\nInfo: " + col_default + ftext)


def read_input_console(args, user_name, multiline: bool):
    """
    Prompt with the coloured user name and read input from stdin.

    In multiline mode everything up to EOF is read and trailing whitespace is
    stripped; otherwise a single line is read and stripped on both sides.
    `args` is accepted for signature parity with the other readers and is unused.
    """
    print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
    if multiline:
        return sys.stdin.read().rstrip()
    return input().strip()


def read_input_rich(args, user_name, multiline: bool):
    """
    Prompt for input through Rich's `Prompt.ask`.

    `args` and `multiline` are accepted for signature parity and are unused.
    """
    return Prompt.ask("\n" + col_user + user_name + col_default)


def read_input_ptk(args, user_name, multiline: bool, prefix: str = None):
    """
    Prompt for input through prompt_toolkit.

    `prefix`, when given, is pre-filled as the editable default text.
    `args` is accepted for signature parity and is unused.
    """
    print()
    return ptk_prompt(
        ANSI(col_user + user_name + col_default + ": "),
        multiline = multiline,
        default = prefix or ""
    )


class Streamer_basic:
    """
    Context manager that streams plain text to stdout after a coloured bot label.

    The think-tag, update-rate and think arguments are accepted so the
    constructor matches Streamer_rich, but they are not used here.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        self.all_text = ""
        self.args = args
        self.bot_name = bot_name

    def __enter__(self):
        print()
        print(col_bot + self.bot_name + ": " + col_default, end = "")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Finish the line unless the streamed text already ended with a newline
        if not self.all_text.endswith("\n"):
            print()

    def stream(self, text: str):
        """Print a chunk immediately, dropping one leading space from the very first chunk."""
        if self.all_text or not text.startswith(" "):
            print_text = text
        else:
            print_text = text[1:]
        self.all_text += text
        print(print_text, end = "", flush = True)


class MarkdownConsoleStream:
    """
    Incrementally redraws Markdown in the terminal.

    Each update renders the full Markdown text to lines via a Rich console
    capture, then rewrites only the lines after the common prefix with the
    previous render, using ANSI cursor-up and erase-line sequences.
    """

    def __init__(self, console: Console = None):
        # Make the Rich console a little narrower to prevent overflows from extra-wide emojis
        c, r = shutil.get_terminal_size(fallback = (80, 24))
        c -= 2
        self.console = console or Console(emoji_variant = "text", width = c)
        self.height = r - 2
        self._last_lines = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False

    def update(self, markdown_text) -> None:
        """Re-render `markdown_text` and redraw the lines that differ from the last render."""
        new_lines = self._render_to_lines(markdown_text)
        old_lines = self._last_lines

        prefix_length = self._common_prefix_length(old_lines, new_lines)
        # Never rewind or redraw more than `self.height` lines
        prefix_length = max(prefix_length, len(old_lines) - self.height, len(new_lines) - self.height)

        old_suffix_len = len(old_lines) - prefix_length
        new_suffix_len = len(new_lines) - prefix_length

        # Move the cursor up to the first line that needs rewriting
        if old_suffix_len > 0:
            print(f"{ESC}[{old_suffix_len}A", end = "")

        changed_count = max(old_suffix_len, new_suffix_len)
        for i in range(changed_count):
            if i < new_suffix_len:
                print(f"{ESC}[2K", end = "")  # Clear entire line
                print(new_lines[prefix_length + i].rstrip())
            else:
                # NOTE(review): the cursor is not advanced here, so only the current
                # line is erased when the new render is shorter; the original had a
                # commented-out print() at this point - confirm intended behavior.
                print(f"{ESC}[2K", end = "")

        self._last_lines = new_lines

    def _render_to_lines(self, markdown_text: str):
        """Render Markdown to right-stripped lines, collapsing runs of blank lines to one."""
        # Capture Rich's output to a string, then split by lines.
        with self.console.capture() as cap:
            self.console.print(Markdown(markdown_text))
        rendered = cap.get()

        split = []
        for s in [r.rstrip() for r in rendered.rstrip("\n").split("\n")]:
            # Keep a blank line only if it is the first line or follows a non-blank one
            if s or len(split) == 0 or split[-1]:
                split.append(s)
        return split

    @staticmethod
    def _common_prefix_length(a, b) -> int:
        """Return the number of leading items that are equal in `a` and `b`."""
        i = 0
        for x, y in zip(a, b):
            if x != y:
                break
            i += 1
        return i


class Streamer_rich:
    """
    Context manager that streams model output as live-rendered Markdown.

    When `think` is set and a `think_tag` is given, text is first printed raw
    under a "Thinking" label until `end_think_tag` has been seen; after that
    (or immediately otherwise) text is accumulated and redrawn through
    MarkdownConsoleStream at most `updates_per_second` times per second.
    """

    def __init__(self, args, bot_name, think_tag, end_think_tag, updates_per_second, think):
        self.all_text = ""
        self.think_text = ""
        self.bot_name = bot_name
        self.all_print_text = col_bot + self.bot_name + col_default + ": "
        self.args = args
        self.live = None
        self.is_live = False
        self.think_tag = think_tag
        self.end_think_tag = end_think_tag
        self.updates_per_second = updates_per_second
        self.last_update = time.time()
        self.think = think

    def begin(self):
        """Start the live Markdown display and draw the bot label."""
        self.live = MarkdownConsoleStream()
        self.live.__enter__()
        self.live.update(self.all_print_text)
        self.is_live = True

    def __enter__(self):
        if self.think and self.think_tag is not None:
            print()
            print(col_think1 + "Thinking" + col_default + ": " + col_think2, end = "")
        else:
            print()
            self.begin()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Force a final redraw so throttled text is not lost
        self.stream("", True)
        if self.is_live:
            self.live.__exit__(exc_type, exc_value, traceback)
        else:
            # Thinking never finished: restore the terminal colour set in __enter__
            print(col_default, end = "", flush = True)

    def stream(self, text: str, force: bool = False):
        """
        Append a chunk of generated text.

        `force` bypasses the update-rate throttle for the live Markdown display.
        """
        if self.think and self.think_tag is not None and not self.is_live:
            # Thinking phase: print raw text until the end-think tag shows up
            print_text = text
            if not self.think_text:
                print_text = print_text.lstrip()
            self.think_text += print_text
            if self.end_think_tag is not None and self.end_think_tag in self.think_text:
                print(print_text.rstrip(), flush = True)
                print()
                self.begin()
            else:
                print(print_text, end = "", flush = True)
            return

        print_text = text
        if not self.all_text.strip():
            # Nothing visible yet: drop leading whitespace, and start a leading
            # code fence on its own line so it does not share a line with the label
            print_text = print_text.lstrip()
            if print_text.startswith("```"):
                print_text = "\n" + print_text
        self.all_text += text
        self.all_print_text += print_text

        # Wrap the tags in backticks so they are displayed as inline code
        formatted_text = self.all_print_text
        if self.think_tag is not None:
            formatted_text = formatted_text.replace(self.think_tag, f"`{self.think_tag}`")
            # end_think_tag may be None independently of think_tag; str.replace(None, ...) raises
            if self.end_think_tag is not None:
                formatted_text = formatted_text.replace(self.end_think_tag, f"`{self.end_think_tag}`")

        now = time.time()
        if now - self.last_update > 1.0 / self.updates_per_second or force:
            self.last_update = now
            self.live.update(formatted_text)


class KeyReader:
    """
    Cross-platform, non-blocking key reader.

    Usage:
        with KeyReader() as keys:
            k = keys.getkey()         # non-blocking (returns None if no key)
            k = keys.getkey(timeout)  # wait up to `timeout` seconds

    If the terminal cannot be put into cbreak mode on POSIX, the reader sets
    `disabled` and getkey() always returns None.
    """

    def __init__(self):
        self._platform = 'nt' if os.name == 'nt' else 'posix'
        self._entered = False
        # POSIX fields
        self._termios = None
        self._tty = None
        self._fd = None
        self._old_termios = None
        self.disabled = False

    def __enter__(self):
        if self._platform == 'posix':
            import termios
            import tty
            self._termios = termios
            self._tty = tty
            try:
                self._fd = sys.stdin.fileno()
                self._old_termios = termios.tcgetattr(self._fd)
                # cbreak mode lets us read single characters without Enter
                tty.setcbreak(self._fd)
            except (termios.error, OSError, ValueError):
                # stdin is not a usable terminal (termios.error), or has no file
                # descriptor at all (io.UnsupportedOperation is an OSError/ValueError)
                self.disabled = True
                return self
        self._entered = True
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the terminal settings saved in __enter__
        if not self.disabled and self._platform == 'posix' and self._old_termios:
            self._termios.tcsetattr(self._fd, self._termios.TCSADRAIN, self._old_termios)
        self._entered = False

    def getkey(self, timeout=0.0):
        """
        Returns a single-character lowercase string if a key is pressed, or None.
        `timeout` is seconds to wait (float). 0.0 => non-blocking poll.
        None => wait until a key is available.

        Raises RuntimeError when called outside a `with` block (unless disabled).
        """
        if self.disabled:
            return None
        if not self._entered:
            raise RuntimeError("Use KeyReader as a context manager")

        if self._platform == 'nt':
            import msvcrt
            # busy-wait with tiny sleeps if a timeout is requested
            end = None if timeout is None else (time.time() + timeout)
            while True:
                if msvcrt.kbhit():
                    b = msvcrt.getch()
                    # Handle special keys (arrows, function keys) which come as a prefix + code
                    if b in (b'\x00', b'\xe0'):  # special prefix
                        msvcrt.getch()  # consume the second byte
                        return None  # ignore special keys for simplicity
                    # Undecodable bytes decode to "", reported as None like the POSIX branch
                    return b.decode('utf-8', errors='ignore').lower() or None
                if timeout == 0.0:
                    return None
                if end is not None and time.time() >= end:
                    return None
                time.sleep(0.005)

        import select
        r, _, _ = select.select([sys.stdin], [], [], timeout)
        if not r:
            return None
        try:
            ch = sys.stdin.read(1)
        except (IOError, OSError):
            return None
        return (ch or "").lower() or None


def char_width(c):
    """Return 2 for East Asian fullwidth/wide characters, otherwise 1."""
    return 2 if unicodedata.east_asian_width(c) in ('F', 'W') else 1


def display_width(s):
    """Return the number of terminal columns `s` occupies, counting wide characters as 2."""
    return sum(char_width(c) for c in s)


def ljust_truncate(s, width):
    """
    Pad `s` with spaces to `width` display columns, or truncate it with a
    trailing '…' (and pad) when it is wider than `width`.
    """
    dw = display_width(s)
    if dw <= width:
        return s + ' ' * (width - dw)

    max_content = width - 1  # reserve 1 column for '…'
    cols = 0
    truncated = []
    for c in s:
        cw = char_width(c)
        if cols + cw > max_content:
            break
        truncated.append(c)
        cols += cw

    # A wide character that did not fit can leave one spare column to pad
    padding = width - cols - 1  # 1 for the ellipsis
    return ''.join(truncated) + '…' + ' ' * padding


def print_tokens(
    ids: list,
    vocab: list,
    ids_per_line = 10,
):
    """
    Print token IDs next to their vocabulary pieces, `ids_per_line` per row.

    Each row is prefixed with the position of its first token. Spaces in pieces
    are shown as '␣' and truncation ellipses are highlighted.
    """
    print()
    line = ""
    last = len(ids) - 1
    for pos, t in enumerate(ids):
        # repr() makes control characters visible; [1:-1] strips the quotes
        p = repr(vocab[t])[1:-1].replace(" ", "␣")
        line += f"{col_user}{t:6}{col_default} "
        line += f"{ljust_truncate(p, 10)} "
        if (pos + 1) % ids_per_line == 0 or pos == last:
            # Colourise after padding so escape codes do not affect column widths
            line = line.replace("␣", f"{col_bot}␣{col_default}").replace("…", f"{col_error}…{col_default}")
            ppos = pos // ids_per_line * ids_per_line
            print(f"{col_info}{ppos:6} {col_default}: {line}")
            line = ""


def print_probs(
    saved_topk: list,
    saved_probs: list,
    saved_samples: list,
    vocab: list,
    ids_per_line = 5,
    minimum_p = 1e-6,
):
    """
    Print top-k candidates per sampled token as side-by-side columns.

    Each column shows the sampled token as a header, a separator, then one row
    per candidate (ID, piece, probability). The candidate equal to the sampled
    token is highlighted; candidates below `minimum_p` are left blank.
    `ids_per_line` columns are printed per group.
    """
    if len(saved_topk) == 0:
        return

    # Two header rows (sampled token, separator) plus one row per candidate
    num = len(saved_topk[0])
    lines = [" "] * (num + 2)
    ids_this_line = 0

    for ids, probs, sample in zip(saved_topk, saved_probs, saved_samples):
        tss = repr(vocab[sample[0]])[1:-1].replace(" ", "␣")
        lines[0] += f"{col_sysprompt}" + ljust_truncate(tss, 27) + f"{col_default} "
        lines[1] += f"{col_dark}" + "─" * 27 + f"{col_default} "
        # NOTE(review): header cells are 28 columns wide, candidate rows 27 and
        # blank padding 29 - confirm the intended column alignment.
        for i, (t, p) in enumerate(zip(ids, probs)):
            if p < minimum_p:
                lines[i + 2] += " " * 29
            else:
                hl = col_b if sample[0] == t else ""
                ts = repr(vocab[t])[1:-1].replace(" ", "␣")
                lines[i + 2] += f"{col_user}{hl} {t:6}{col_default}{hl} "
                lines[i + 2] += f"{ljust_truncate(ts, 10)} "
                lines[i + 2] += f"{col_think1}{hl}{p:7.5f}{col_default} "

        ids_this_line += 1
        if ids_this_line == ids_per_line:
            # LRO/PDF force left-to-right layout so RTL pieces do not reorder the table
            print(f"\n{LRO}" + "\n".join(lines) + f"{PDF}")
            lines = [" "] * (num + 2)
            ids_this_line = 0

    # Flush the final, partially filled group
    if ids_this_line > 0:
        print(f"\n{LRO}" + "\n".join(lines) + f"{PDF}")