chat.py: Add some more commands

This commit is contained in:
turboderp
2025-07-10 18:32:27 +02:00
parent 51e9922b5b
commit e12a1a5d0b
3 changed files with 117 additions and 8 deletions

View File

@@ -6,6 +6,7 @@ from exllamav3 import Generator, Job, model_init
from exllamav3.generator.sampler import ComboSampler
from chat_templates import *
from chat_util import *
from chat_io import *
import torch
from chat_console import *
@@ -64,6 +65,8 @@ def main(args):
# Main loop
print("\n" + col_sysprompt + system_prompt.strip() + col_default)
context = []
tt = prompt_format.thinktag()
banned_strings = [tt[0], tt[1]] if args.no_think else []
response = ""
while True:
@@ -74,6 +77,7 @@ def main(args):
# Get user prompt
user_prompt = read_input_fn(args, user_name, multiline)
prefix = ""
# Intercept commands
if user_prompt.startswith("/"):
@@ -123,6 +127,65 @@ def main(args):
print_info("Disabled tokens/second output")
continue
# Retry last response
case "/r":
user_prompt = context[-1][0]
context = context[:-1]
# Edit last response
case "/e":
print_info("Press Alt+Enter to submit")
user_prompt = context[-1][0]
last_reply = context[-1][-1]
prefix = read_input_fn(args, bot_name, True, last_reply)
context = context[:-1]
# Edit system prompt
case "/sp":
print_info("Press Alt+Enter to submit")
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
continue
# Edit banned strings
case "/ban":
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
bans = "\n".join(f"\"{b}\"" for b in banned_strings)
bans = read_input_fn(args, "Banned strings", True, bans)
bans = [b.strip() for b in bans.split("\n")]
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
d = len(bans) - len(banned_strings)
banned_strings = bans
if d < 0:
print_info(f"{-d} string(s) removed")
elif d > 0:
print_info(f"{d} string(s) added")
else:
print_info("Strings updated")
continue
# Save conversation
case "/save":
if len(c) == 1:
c.append("~/chat_py_session.json")
save_session(c[1], system_prompt, banned_strings, context)
print_info(f"Saved session to: {c[1]}")
continue
# Load conversation
case "/load":
if len(c) == 1:
c.append("~/chat_py_session.json")
try:
(
system_prompt,
banned_strings,
context
) = load_session(c[1])
print_info(f"Loaded session from: {c[1]}")
except:
print_error(f"Error loading {c[1]}")
continue
case _:
print_error(f"Unknown command: {c[0]}")
continue
@@ -131,34 +194,37 @@ def main(args):
context.append((user_prompt, None))
# Tokenize context and trim from head if too long
def get_input_ids():
def get_input_ids(_prefix):
frm_context = prompt_format.format(system_prompt, context)
if args.think:
if _prefix:
frm_context += prefix
elif args.think:
frm_context += prompt_format.thinktag()[0]
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
return ids_, exp_len_
ids, exp_len = get_input_ids()
ids, exp_len = get_input_ids(prefix)
if exp_len > context_length:
while exp_len > context_length - 2 * max_response_tokens:
context = context[1:]
ids, exp_len = get_input_ids()
ids, exp_len = get_input_ids(prefix)
# Inference
tt = prompt_format.thinktag()
job = Job(
input_ids = ids,
max_new_tokens = max_response_tokens,
stop_conditions = stop_conditions,
sampler = sampler,
banned_strings = [tt[0], tt[1]] if args.no_think else None
banned_strings = banned_strings
)
generator.enqueue(job)
# Stream response
ctx_exceeded = False
with streamer_cm(args, bot_name) as s:
if prefix:
s.stream(prefix, tt[0], tt[1])
while generator.num_remaining_jobs():
for r in generator.iterate():
chunk = r.get("text", "")

View File

@@ -36,9 +36,13 @@ def read_input_rich(args, user_name, multiline: bool):
user_prompt = Prompt.ask("\n" + col_user + user_name + col_default)
return user_prompt
def read_input_ptk(args, user_name, multiline: bool):
def read_input_ptk(args, user_name, multiline: bool, prefix: str = None):
print()
user_prompt = ptk_prompt(ANSI(col_user + user_name + col_default + ": "), multiline = multiline)
user_prompt = ptk_prompt(
ANSI(col_user + user_name + col_default + ": "),
multiline = multiline,
default = prefix or ""
)
return user_prompt
class Streamer_basic:

39
examples/chat_io.py Normal file
View File

@@ -0,0 +1,39 @@
import json
from pathlib import Path
def save_session(
filename: str,
system_prompt: str,
banned_strings: list[str],
context: list[tuple[str, str | None]]
):
"""
Save a single string and a list of strings to the given filename.
Ensures the directory exists before writing. Expands ~ to your home dir.
"""
path = Path(filename).expanduser()
if path.parent:
path.parent.mkdir(parents = True, exist_ok = True)
payload = {
"system_prompt": system_prompt,
"banned_strings": banned_strings,
"context": context
}
with path.open("w", encoding = "utf-8") as f:
json.dump(payload, f, ensure_ascii = False, indent=2)
def load_session(filename: str):
"""
Load and return (text, strings) from the given filename.
Expands ~ to your home dir.
"""
path = Path(filename).expanduser()
with path.open("r", encoding = "utf-8") as f:
data = json.load(f)
return (
data["system_prompt"],
data["banned_strings"],
data["context"]
)