mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-05-11 08:20:05 +00:00
chat.py: Random benchmark question feature
This commit is contained in:
@@ -121,6 +121,7 @@ def main(args):
|
||||
"/think Toggle reasoning mode",
|
||||
"/tps Toggle tokens/second output",
|
||||
"/x Exit",
|
||||
"/b <source> Random benchmark question",
|
||||
]))
|
||||
continue
|
||||
|
||||
@@ -129,6 +130,27 @@ def main(args):
|
||||
print_info("Exiting")
|
||||
break
|
||||
|
||||
# Random benchmark question
|
||||
case "/b":
|
||||
source = c[1] if len(c) > 1 else None
|
||||
sources = get_sample_sources()
|
||||
if source and source.isnumeric():
|
||||
i = int(source)
|
||||
source = list(sources)[i - 1] if 1 <= i <= len(sources) else None
|
||||
if source not in sources:
|
||||
print_info(
|
||||
"Available sample sources:\n\n" +
|
||||
"\n".join([f"{i + 1}. {t}" for i, t in enumerate(get_sample_sources())])
|
||||
)
|
||||
continue
|
||||
question = sample_question(source)
|
||||
print_info(f"Question from {source}; multi-line mode, press Alt-Enter to submit or Ctrl-C to abort")
|
||||
try:
|
||||
user_prompt = read_input_fn(args, "Prompt", True, question)
|
||||
except KeyboardInterrupt:
|
||||
print_info("Aborted")
|
||||
continue
|
||||
|
||||
# Copy codeblock to clipboard
|
||||
case "/cc":
|
||||
try:
|
||||
@@ -178,35 +200,47 @@ def main(args):
|
||||
|
||||
# Edit last response
|
||||
case "/e":
|
||||
print_info("Press Alt+Enter to submit")
|
||||
print_info("Press Alt-Enter to submit")
|
||||
user_prompt = context[-1][0]
|
||||
last_reply = context[-1][-1]
|
||||
prefix = read_input_fn(args, bot_name, True, last_reply)
|
||||
context = context[:-1]
|
||||
enable_healing = True
|
||||
try:
|
||||
prefix = read_input_fn(args, bot_name, True, last_reply)
|
||||
context = context[:-1]
|
||||
enable_healing = True
|
||||
except KeyboardInterrupt:
|
||||
print_info("Exiting")
|
||||
break
|
||||
|
||||
# Edit system prompt
|
||||
case "/sp":
|
||||
print_info("Press Alt+Enter to submit")
|
||||
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
|
||||
continue
|
||||
print_info("Press Alt-Enter to submit")
|
||||
try:
|
||||
system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
print_info("Exiting")
|
||||
break
|
||||
|
||||
# Edit banned strings
|
||||
case "/ban":
|
||||
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
|
||||
print_info("Write each string on a new line and enclose in \"double quotes\", press Alt-Enter to submit")
|
||||
bans = "\n".join(f"\"{b}\"" for b in banned_strings)
|
||||
bans = read_input_fn(args, "Banned strings", True, bans)
|
||||
bans = [b.strip() for b in bans.split("\n")]
|
||||
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
|
||||
d = len(bans) - len(banned_strings)
|
||||
banned_strings = bans
|
||||
if d < 0:
|
||||
print_info(f"{-d} string(s) removed")
|
||||
elif d > 0:
|
||||
print_info(f"{d} string(s) added")
|
||||
else:
|
||||
print_info("Strings updated")
|
||||
continue
|
||||
try:
|
||||
bans = read_input_fn(args, "Banned strings", True, bans)
|
||||
bans = [b.strip() for b in bans.split("\n")]
|
||||
bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
|
||||
d = len(bans) - len(banned_strings)
|
||||
banned_strings = bans
|
||||
if d < 0:
|
||||
print_info(f"{-d} string(s) removed")
|
||||
elif d > 0:
|
||||
print_info(f"{d} string(s) added")
|
||||
else:
|
||||
print_info("Strings updated")
|
||||
continue
|
||||
except KeyboardInterrupt:
|
||||
print_info("Exiting")
|
||||
break
|
||||
|
||||
# Save conversation
|
||||
case "/save":
|
||||
@@ -369,7 +403,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
|
||||
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
|
||||
parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
|
||||
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
|
||||
parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt-Enter to submit input)")
|
||||
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
|
||||
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
|
||||
parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input")
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import pyperclip
|
||||
import json
|
||||
import random
|
||||
from exllamav3.util.file import disk_lru_cache
|
||||
|
||||
def copy_last_codeblock(text: str, num) -> str | None:
|
||||
pattern = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)
|
||||
@@ -32,4 +36,96 @@ def extract_svg(s: str, begin: str = "<svg", end: str = "</svg>"):
|
||||
return None
|
||||
|
||||
_, start, stop = best
|
||||
return s[start:stop]
|
||||
return s[start:stop]
|
||||
|
||||
|
||||
def _fetch_github_json(url: str) -> list[dict]:
|
||||
import urllib.request
|
||||
req = urllib.request.Request(url, headers = {"User-Agent": "benchmark-sampler"})
|
||||
with urllib.request.urlopen(req) as resp:
|
||||
return json.loads(resp.read().decode())
|
||||
|
||||
|
||||
@disk_lru_cache("_load_truthfulqa")
def _load_truthfulqa() -> list[str]:
    """Return the 817 TruthfulQA questions (health, law, finance, politics, etc.)."""
    from datasets import load_dataset
    rows = load_dataset("truthfulqa/truthful_qa", "generation", split = "validation")
    return [r["question"] for r in rows]
|
||||
|
||||
|
||||
@disk_lru_cache("_load_simpleqa")
def _load_simpleqa() -> list[str]:
    """Return every problem statement from the SimpleQA test split."""
    from datasets import load_dataset
    questions = []
    for record in load_dataset("basicv8vc/SimpleQA", split = "test"):
        questions.append(record["problem"])
    return questions
|
||||
|
||||
|
||||
@disk_lru_cache("_load_arc_challenge")
def _load_arc_challenge() -> list[str]:
    """Return the question text of each ARC-Challenge test item."""
    from datasets import load_dataset
    test_split = load_dataset("allenai/ai2_arc", "ARC-Challenge", split = "test")
    return [record["question"] for record in test_split]
|
||||
|
||||
|
||||
@disk_lru_cache("_load_commonsenseqa")
def _load_commonsenseqa() -> list[str]:
    """Return the question text of each CommonsenseQA validation item."""
    from datasets import load_dataset
    validation = load_dataset("tau/commonsense_qa", split = "validation")
    questions = []
    for record in validation:
        questions.append(record["question"])
    return questions
|
||||
|
||||
|
||||
@disk_lru_cache("_load_bullshitbench_v2_")
def _load_bullshitbench_v2() -> list[str]:
    """Flatten all questions across techniques of the BullshitBench v2 payload.

    Fetched from the petergpt/bullshit-benchmark GitHub repository; result is
    cached on disk by the decorator.
    """
    url = "https://raw.githubusercontent.com/petergpt/bullshit-benchmark/main/questions.v2.json"
    data = _fetch_github_json(url)
    questions = []
    for technique in data["techniques"]:
        # Default to [] — the original fell back to the whole payload, which
        # would iterate the top-level dict's string keys instead of records.
        questions += [q["question"] for q in technique.get("questions", []) if "question" in q]
    return questions
|
||||
|
||||
|
||||
@disk_lru_cache("_load_mt_bench")
def _load_mt_bench() -> list[str]:
    """Return the first-turn prompt of each MT-Bench item."""
    from datasets import load_dataset
    prompts = load_dataset("HuggingFaceH4/mt_bench_prompts", split = "train")
    result = []
    for record in prompts:
        result.append(record["prompt"][0])
    return result
|
||||
|
||||
|
||||
@disk_lru_cache("_load_wildbench")
def _load_wildbench() -> list[str]:
    """First user turn of each WildBench v2 conversation, skipping
    conversations whose first user turn is flagged toxic or redacted."""
    from datasets import load_dataset
    rows = load_dataset("allenai/WildBench", "v2", split = "test")
    collected = []
    for row in rows:
        # Only the first user turn of each conversation is considered.
        first_user = next(
            (t for t in row["conversation_input"] if t["role"] == "user"),
            None,
        )
        if first_user is None or first_user.get("toxic") or first_user.get("redacted"):
            continue
        content = first_user["content"].strip()
        if content:
            collected.append(content)
    return collected
|
||||
|
||||
|
||||
# Registry mapping user-facing source names (as typed after /b) to their
# disk-cached loader functions. Insertion order defines the numeric indices
# shown by the /b source listing.
bench_sources = {
    "truthfulqa": _load_truthfulqa,
    "simpleqa": _load_simpleqa,
    "commonsenseqa": _load_commonsenseqa,
    "bullshitbench": _load_bullshitbench_v2,
    "mtbench": _load_mt_bench,
    "wildbench": _load_wildbench,
    # _load_arc_challenge was defined but never registered (dead code);
    # appended last so existing numeric /b indices keep their meaning.
    "arc": _load_arc_challenge,
}
|
||||
|
||||
|
||||
def get_sample_sources():
    """Return the names of all registered benchmark sources (a dict keys view)."""
    return bench_sources.keys()
|
||||
|
||||
|
||||
def sample_question(source: str):
    """Return a random question from the named benchmark source.

    Returns None when *source* is not a registered benchmark, or — consistent
    with that path — when the loader yields no questions (previously
    random.choice would raise IndexError on an empty list).
    """
    if source not in bench_sources:
        return None
    questions = bench_sources[source]()
    return random.choice(questions) if questions else None
|
||||
Reference in New Issue
Block a user