chat.py: Random benchmark question feature

This commit is contained in:
turboderp
2026-03-13 04:31:19 +01:00
parent aaf6337f12
commit ebd2efb6bd
2 changed files with 152 additions and 22 deletions

View File

@@ -121,6 +121,7 @@ def main(args):
"/think Toggle reasoning mode",
"/tps Toggle tokens/second output",
"/x Exit",
"/b <source> Random benchmark question",
]))
continue
@@ -129,6 +130,27 @@ def main(args):
print_info("Exiting")
break
# Random benchmark question
case "/b":
source = c[1] if len(c) > 1 else None
sources = get_sample_sources()
if source and source.isnumeric():
i = int(source)
source = list(sources)[i - 1] if 1 <= i <= len(sources) else None
if source not in sources:
print_info(
"Available sample sources:\n\n" +
"\n".join([f"{i + 1}. {t}" for i, t in enumerate(get_sample_sources())])
)
continue
question = sample_question(source)
print_info(f"Question from {source}; multi-line mode, press Alt-Enter to submit or Ctrl-C to abort")
try:
user_prompt = read_input_fn(args, "Prompt", True, question)
except KeyboardInterrupt:
print_info("Aborted")
continue
# Copy codeblock to clipboard
case "/cc":
try:
@@ -178,35 +200,47 @@ def main(args):
# Edit last response
case "/e":
print_info("Press Alt+Enter to submit")
print_info("Press Alt-Enter to submit")
user_prompt = context[-1][0]
last_reply = context[-1][-1]
prefix = read_input_fn(args, bot_name, True, last_reply)
context = context[:-1]
enable_healing = True
try:
prefix = read_input_fn(args, bot_name, True, last_reply)
context = context[:-1]
enable_healing = True
except KeyboardInterrupt:
print_info("Exiting")
break
# Edit system prompt
case "/sp":
print_info("Press Alt+Enter to submit")
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
continue
print_info("Press Alt-Enter to submit")
try:
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
continue
except KeyboardInterrupt:
print_info("Exiting")
break
# Edit banned strings
case "/ban":
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt-Enter to submit")
bans = "\n".join(f"\"{b}\"" for b in banned_strings)
bans = read_input_fn(args, "Banned strings", True, bans)
bans = [b.strip() for b in bans.split("\n")]
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
d = len(bans) - len(banned_strings)
banned_strings = bans
if d < 0:
print_info(f"{-d} string(s) removed")
elif d > 0:
print_info(f"{d} string(s) added")
else:
print_info("Strings updated")
continue
try:
bans = read_input_fn(args, "Banned strings", True, bans)
bans = [b.strip() for b in bans.split("\n")]
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
d = len(bans) - len(banned_strings)
banned_strings = bans
if d < 0:
print_info(f"{-d} string(s) removed")
elif d > 0:
print_info(f"{d} string(s) added")
else:
print_info("Strings updated")
continue
except KeyboardInterrupt:
print_info("Exiting")
break
# Save conversation
case "/save":
@@ -369,7 +403,7 @@ if __name__ == "__main__":
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt-Enter to submit input)")
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input")

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
import re
import pyperclip
import json
import random
from exllamav3.util.file import disk_lru_cache
def copy_last_codeblock(text: str, num) -> str | None:
pattern = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)
@@ -32,4 +36,96 @@ def extract_svg(s: str, begin: str = "<svg", end: str = "</svg>"):
return None
_, start, stop = best
return s[start:stop]
return s[start:stop]
def _fetch_github_json(url: str) -> list[dict]:
import urllib.request
req = urllib.request.Request(url, headers = {"User-Agent": "benchmark-sampler"})
with urllib.request.urlopen(req) as resp:
return json.loads(resp.read().decode())
@disk_lru_cache("_load_truthfulqa")
def _load_truthfulqa() -> list[str]:
"""817 questions spanning health, law, finance, politics, etc."""
from datasets import load_dataset
ds = load_dataset("truthfulqa/truthful_qa", "generation", split = "validation")
return [row["question"] for row in ds]
@disk_lru_cache("_load_simpleqa")
def _load_simpleqa() -> list[str]:
from datasets import load_dataset
ds = load_dataset("basicv8vc/SimpleQA", split = "test")
return [row["problem"] for row in ds]
@disk_lru_cache("_load_arc_challenge")
def _load_arc_challenge() -> list[str]:
from datasets import load_dataset
ds = load_dataset("allenai/ai2_arc", "ARC-Challenge", split = "test")
return [row["question"] for row in ds]
@disk_lru_cache("_load_commonsenseqa")
def _load_commonsenseqa() -> list[str]:
from datasets import load_dataset
ds = load_dataset("tau/commonsense_qa", split = "validation")
return [row["question"] for row in ds]
@disk_lru_cache("_load_bullshitbench_v2_")
def _load_bullshitbench_v2() -> list[str]:
url = "https://raw.githubusercontent.com/petergpt/bullshit-benchmark/main/questions.v2.json"
data = _fetch_github_json(url)
questions = []
for t in data["techniques"]:
questions += [q["question"] for q in t.get("questions", data) if "question" in q]
return questions
@disk_lru_cache("_load_mt_bench")
def _load_mt_bench() -> list[str]:
from datasets import load_dataset
ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split = "train")
return [row["prompt"][0] for row in ds]
@disk_lru_cache("_load_wildbench")
def _load_wildbench() -> list[str]:
from datasets import load_dataset
ds = load_dataset("allenai/WildBench", "v2", split = "test")
questions = []
for row in ds:
turns = row["conversation_input"]
for turn in turns:
if turn["role"] == "user":
if turn.get("toxic") or turn.get("redacted"):
break
text = turn["content"].strip()
if text:
questions.append(text)
break # only first user turn
return questions
bench_sources = {
"truthfulqa": _load_truthfulqa,
"simpleqa": _load_simpleqa,
"commonsenseqa": _load_commonsenseqa,
"bullshitbench": _load_bullshitbench_v2,
"mtbench": _load_mt_bench,
"wildbench": _load_wildbench,
}
def get_sample_sources():
return bench_sources.keys()
def sample_question(source: str):
if source not in bench_sources:
return None
questions = bench_sources[source]()
return random.choice(questions)