mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
chat.py: Add some more commands
This commit is contained in:
@@ -6,6 +6,7 @@ from exllamav3 import Generator, Job, model_init
|
||||
from exllamav3.generator.sampler import ComboSampler
|
||||
from chat_templates import *
|
||||
from chat_util import *
|
||||
from chat_io import *
|
||||
import torch
|
||||
from chat_console import *
|
||||
|
||||
@@ -64,6 +65,8 @@ def main(args):
|
||||
# Main loop
|
||||
print("\n" + col_sysprompt + system_prompt.strip() + col_default)
|
||||
context = []
|
||||
tt = prompt_format.thinktag()
|
||||
banned_strings = [tt[0], tt[1]] if args.no_think else []
|
||||
response = ""
|
||||
|
||||
while True:
|
||||
@@ -74,6 +77,7 @@ def main(args):
|
||||
|
||||
# Get user prompt
|
||||
user_prompt = read_input_fn(args, user_name, multiline)
|
||||
prefix = ""
|
||||
|
||||
# Intercept commands
|
||||
if user_prompt.startswith("/"):
|
||||
@@ -123,6 +127,65 @@ def main(args):
|
||||
print_info("Disabled tokens/second output")
|
||||
continue
|
||||
|
||||
# Retry last response
|
||||
case "/r":
|
||||
user_prompt = context[-1][0]
|
||||
context = context[:-1]
|
||||
|
||||
# Edit last response
|
||||
case "/e":
|
||||
print_info("Press Alt+Enter to submit")
|
||||
user_prompt = context[-1][0]
|
||||
last_reply = context[-1][-1]
|
||||
prefix = read_input_fn(args, bot_name, True, last_reply)
|
||||
context = context[:-1]
|
||||
|
||||
# Edit system prompt
|
||||
case "/sp":
|
||||
print_info("Press Alt+Enter to submit")
|
||||
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
|
||||
continue
|
||||
|
||||
# Edit banned strings
|
||||
case "/ban":
|
||||
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
|
||||
bans = "\n".join(f"\"{b}\"" for b in banned_strings)
|
||||
bans = read_input_fn(args, "Banned strings", True, bans)
|
||||
bans = [b.strip() for b in bans.split("\n")]
|
||||
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
|
||||
d = len(bans) - len(banned_strings)
|
||||
banned_strings = bans
|
||||
if d < 0:
|
||||
print_info(f"{-d} string(s) removed")
|
||||
elif d > 0:
|
||||
print_info(f"{d} string(s) added")
|
||||
else:
|
||||
print_info("Strings updated")
|
||||
continue
|
||||
|
||||
# Save conversation
|
||||
case "/save":
|
||||
if len(c) == 1:
|
||||
c.append("~/chat_py_session.json")
|
||||
save_session(c[1], system_prompt, banned_strings, context)
|
||||
print_info(f"Saved session to: {c[1]}")
|
||||
continue
|
||||
|
||||
# Load conversation
|
||||
case "/load":
|
||||
if len(c) == 1:
|
||||
c.append("~/chat_py_session.json")
|
||||
try:
|
||||
(
|
||||
system_prompt,
|
||||
banned_strings,
|
||||
context
|
||||
) = load_session(c[1])
|
||||
print_info(f"Loaded session from: {c[1]}")
|
||||
except:
|
||||
print_error(f"Error loading {c[1]}")
|
||||
continue
|
||||
|
||||
case _:
|
||||
print_error(f"Unknown command: {c[0]}")
|
||||
continue
|
||||
@@ -131,34 +194,37 @@ def main(args):
|
||||
context.append((user_prompt, None))
|
||||
|
||||
# Tokenize context and trim from head if too long
|
||||
def get_input_ids():
|
||||
def get_input_ids(_prefix):
|
||||
frm_context = prompt_format.format(system_prompt, context)
|
||||
if args.think:
|
||||
if _prefix:
|
||||
frm_context += prefix
|
||||
elif args.think:
|
||||
frm_context += prompt_format.thinktag()[0]
|
||||
ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
|
||||
exp_len_ = ids_.shape[-1] + max_response_tokens + 1
|
||||
return ids_, exp_len_
|
||||
|
||||
ids, exp_len = get_input_ids()
|
||||
ids, exp_len = get_input_ids(prefix)
|
||||
if exp_len > context_length:
|
||||
while exp_len > context_length - 2 * max_response_tokens:
|
||||
context = context[1:]
|
||||
ids, exp_len = get_input_ids()
|
||||
ids, exp_len = get_input_ids(prefix)
|
||||
|
||||
# Inference
|
||||
tt = prompt_format.thinktag()
|
||||
job = Job(
|
||||
input_ids = ids,
|
||||
max_new_tokens = max_response_tokens,
|
||||
stop_conditions = stop_conditions,
|
||||
sampler = sampler,
|
||||
banned_strings = [tt[0], tt[1]] if args.no_think else None
|
||||
banned_strings = banned_strings
|
||||
)
|
||||
generator.enqueue(job)
|
||||
|
||||
# Stream response
|
||||
ctx_exceeded = False
|
||||
with streamer_cm(args, bot_name) as s:
|
||||
if prefix:
|
||||
s.stream(prefix, tt[0], tt[1])
|
||||
while generator.num_remaining_jobs():
|
||||
for r in generator.iterate():
|
||||
chunk = r.get("text", "")
|
||||
|
||||
@@ -36,9 +36,13 @@ def read_input_rich(args, user_name, multiline: bool):
|
||||
user_prompt = Prompt.ask("\n" + col_user + user_name + col_default)
|
||||
return user_prompt
|
||||
|
||||
def read_input_ptk(args, user_name, multiline: bool, prefix: str | None = None):
    """
    Read one prompt from the terminal via ptk_prompt, optionally pre-filled with editable text.

    :param args:
        Not used by this function; kept so the signature lines up with the other read_input_* functions.

    :param user_name:
        Label shown (in the user colour) in front of the input field, followed by ": ".

    :param multiline:
        Forwarded unchanged to ptk_prompt as its multiline argument.

    :param prefix:
        Text placed in the input buffer as the initial value. None or an empty string
        gives an empty buffer.

    :return:
        Whatever ptk_prompt returns (presumably the submitted text as a str - confirm against
        the prompt_toolkit import at the top of the file).
    """
    # Blank line to separate the prompt from whatever was printed before it
    print()
    label = ANSI(col_user + user_name + col_default + ": ")
    user_prompt = ptk_prompt(
        label,
        multiline = multiline,
        # ptk_prompt is given a string in all cases, so map None to ""
        default = prefix or ""
    )
    return user_prompt
|
||||
|
||||
class Streamer_basic:
|
||||
|
||||
39
examples/chat_io.py
Normal file
39
examples/chat_io.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
def save_session(
|
||||
filename: str,
|
||||
system_prompt: str,
|
||||
banned_strings: list[str],
|
||||
context: list[tuple[str, str | None]]
|
||||
):
|
||||
"""
|
||||
Save a single string and a list of strings to the given filename.
|
||||
Ensures the directory exists before writing. Expands ~ to your home dir.
|
||||
"""
|
||||
path = Path(filename).expanduser()
|
||||
if path.parent:
|
||||
path.parent.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
payload = {
|
||||
"system_prompt": system_prompt,
|
||||
"banned_strings": banned_strings,
|
||||
"context": context
|
||||
}
|
||||
with path.open("w", encoding = "utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii = False, indent=2)
|
||||
|
||||
|
||||
def load_session(filename: str):
    """
    Read a chat session from a JSON file in the layout written by save_session.

    A leading ~ in the filename is expanded to the home directory.

    :param filename:
        Path of the session file to read.

    :return:
        Tuple of (system_prompt, banned_strings, context). Each context entry is converted
        back to a tuple, since JSON stores them as arrays and they would otherwise be
        returned as lists.

    :raises OSError:
        If the file cannot be opened.

    :raises json.JSONDecodeError:
        If the file is not valid JSON.

    :raises KeyError:
        If one of the keys "system_prompt", "banned_strings" or "context" is missing.
    """
    path = Path(filename).expanduser()
    with path.open("r", encoding = "utf-8") as f:
        data = json.load(f)

    # Look up all three keys before converting, so a malformed file fails with KeyError first
    system_prompt = data["system_prompt"]
    banned_strings = data["banned_strings"]
    # Restore the tuple shape that save_session's signature declares for context entries
    context = [tuple(turn) for turn in data["context"]]

    return (
        system_prompt,
        banned_strings,
        context
    )
|
||||
Reference in New Issue
Block a user