diff --git a/examples/chat.py b/examples/chat.py
index 8493f13..be88e61 100644
--- a/examples/chat.py
+++ b/examples/chat.py
@@ -121,6 +121,7 @@ def main(args):
                     "/think    Toggle reasoning mode",
                     "/tps      Toggle tokens/second output",
                     "/x        Exit",
+                    "/b        Random benchmark question",
                 ]))
                 continue

@@ -129,6 +130,27 @@
                 print_info("Exiting")
                 break

+            # Random benchmark question: "/b" lists sources, "/b <name>" or "/b <index>"
+            # pre-fills the prompt with a randomly sampled question from that source.
+            case "/b":
+                source = c[1] if len(c) > 1 else None
+                sources = get_sample_sources()
+                if source and source.isnumeric():
+                    i = int(source)
+                    source = list(sources)[i - 1] if 1 <= i <= len(sources) else None
+                if source not in sources:
+                    print_info(
+                        "Available sample sources:\n\n" +
+                        "\n".join([f"{i + 1}. {t}" for i, t in enumerate(get_sample_sources())])
+                    )
+                    continue
+                question = sample_question(source)
+                print_info(f"Question from {source}; multi-line mode, press Alt-Enter to submit or Ctrl-C to abort")
+                try:
+                    user_prompt = read_input_fn(args, "Prompt", True, question)
+                except KeyboardInterrupt:
+                    print_info("Aborted")
+                    continue
+
            # Copy codeblock to clipboard
            case "/cc":
                try:
@@ -178,35 +200,47 @@

             # Edit last response
             case "/e":
-                print_info("Press Alt+Enter to submit")
+                print_info("Press Alt-Enter to submit")
                 user_prompt = context[-1][0]
                 last_reply = context[-1][-1]
-                prefix = read_input_fn(args, bot_name, True, last_reply)
-                context = context[:-1]
-                enable_healing = True
+                try:
+                    prefix = read_input_fn(args, bot_name, True, last_reply)
+                    context = context[:-1]
+                    enable_healing = True
+                except KeyboardInterrupt:
+                    print_info("Exiting")
+                    break

             # Edit system prompt
             case "/sp":
-                print_info("Press Alt+Enter to submit")
-                system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
-                continue
+                print_info("Press Alt-Enter to submit")
+                try:
+                    system_prompt = read_input_fn(args, "System prompt", True, system_prompt)
+                    continue
+                except KeyboardInterrupt:
+                    print_info("Exiting")
+                    break

             # Edit banned strings
             case "/ban":
-                print_info("Write each string on a new line and enclose in \"double quotes\", press Alt+Enter to submit")
+                print_info("Write each string on a new line and enclose in \"double quotes\", press Alt-Enter to submit")
                 bans = "\n".join(f"\"{b}\"" for b in banned_strings)
-                bans = read_input_fn(args, "Banned strings", True, bans)
-                bans = [b.strip() for b in bans.split("\n")]
-                bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
-                d = len(bans) - len(banned_strings)
-                banned_strings = bans
-                if d < 0:
-                    print_info(f"{-d} string(s) removed")
-                elif d > 0:
-                    print_info(f"{d} string(s) added")
-                else:
-                    print_info("Strings updated")
-                continue
+                try:
+                    bans = read_input_fn(args, "Banned strings", True, bans)
+                    bans = [b.strip() for b in bans.split("\n")]
+                    bans = [b[1:-1] for b in bans if b.startswith("\"") and b.endswith("\"")]
+                    d = len(bans) - len(banned_strings)
+                    banned_strings = bans
+                    if d < 0:
+                        print_info(f"{-d} string(s) removed")
+                    elif d > 0:
+                        print_info(f"{d} string(s) added")
+                    else:
+                        print_info("Strings updated")
+                    continue
+                except KeyboardInterrupt:
+                    print_info("Exiting")
+                    break

             # Save conversation
             case "/save":
@@ -369,7 +403,7 @@ if __name__ == "__main__":
     parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
     parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
     parser.add_argument("-bn", "--bot_name", type = str, default = "Assistant", help = "Bot name (raw mode only)")
-    parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt+Enter to submit input)")
+    parser.add_argument("-mli", "--multiline", action = "store_true", help = "Enable multi line input (use Alt-Enter to submit input)")
     parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
     parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
     parser.add_argument("-basic", "--basic_console", action = "store_true", help = "Use basic console output (no markdown and fancy prompt input")
diff --git a/examples/chat_util.py b/examples/chat_util.py
index 1a7a8a2..d113ac4 100644
--- a/examples/chat_util.py
+++ b/examples/chat_util.py
@@ -1,5 +1,9 @@
+from __future__ import annotations
 import re
 import pyperclip
+import json
+import random
+from exllamav3.util.file import disk_lru_cache

 def copy_last_codeblock(text: str, num) -> str | None:
     pattern = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)
@@ -32,4 +36,97 @@ def extract_svg(s: str, begin: str = "<svg", end: str = "</svg>") -> str | None:
     if i == -1: return None
     j = s.find(end, i)
     if j == -1: return None
     return s[i:j + len(end)]
+
+
+def _fetch_github_json(url: str) -> list[dict]:
+    # Fetch and decode a JSON document; a User-Agent header is required or
+    # raw.githubusercontent.com may reject the request.
+    import urllib.request
+    req = urllib.request.Request(url, headers = {"User-Agent": "benchmark-sampler"})
+    with urllib.request.urlopen(req) as resp:
+        return json.loads(resp.read().decode())
+
+
+@disk_lru_cache("_load_truthfulqa")
+def _load_truthfulqa() -> list[str]:
+    """817 questions spanning health, law, finance, politics, etc."""
+    from datasets import load_dataset
+    ds = load_dataset("truthfulqa/truthful_qa", "generation", split = "validation")
+    return [row["question"] for row in ds]
+
+
+@disk_lru_cache("_load_simpleqa")
+def _load_simpleqa() -> list[str]:
+    from datasets import load_dataset
+    ds = load_dataset("basicv8vc/SimpleQA", split = "test")
+    return [row["problem"] for row in ds]
+
+
+@disk_lru_cache("_load_arc_challenge")
+def _load_arc_challenge() -> list[str]:
+    from datasets import load_dataset
+    ds = load_dataset("allenai/ai2_arc", "ARC-Challenge", split = "test")
+    return [row["question"] for row in ds]
+
+
+@disk_lru_cache("_load_commonsenseqa")
+def _load_commonsenseqa() -> list[str]:
+    from datasets import load_dataset
+    ds = load_dataset("tau/commonsense_qa", split = "validation")
+    return [row["question"] for row in ds]
+
+
+@disk_lru_cache("_load_bullshitbench_v2")
+def _load_bullshitbench_v2() -> list[str]:
+    url = "https://raw.githubusercontent.com/petergpt/bullshit-benchmark/main/questions.v2.json"
+    data = _fetch_github_json(url)
+    questions = []
+    for t in data["techniques"]:
+        # Fall back to an empty list (not the whole document) when a
+        # technique has no "questions" key.
+        questions += [q["question"] for q in t.get("questions", []) if "question" in q]
+    return questions
+
+
+@disk_lru_cache("_load_mt_bench")
+def _load_mt_bench() -> list[str]:
+    from datasets import load_dataset
+    ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split = "train")
+    return [row["prompt"][0] for row in ds]
+
+
+@disk_lru_cache("_load_wildbench")
+def _load_wildbench() -> list[str]:
+    from datasets import load_dataset
+    ds = load_dataset("allenai/WildBench", "v2", split = "test")
+    questions = []
+    for row in ds:
+        turns = row["conversation_input"]
+        for turn in turns:
+            if turn["role"] == "user":
+                if turn.get("toxic") or turn.get("redacted"):
+                    break
+                text = turn["content"].strip()
+                if text:
+                    questions.append(text)
+                break  # only first user turn
+    return questions
+
+
+bench_sources = {
+    "truthfulqa": _load_truthfulqa,
+    "simpleqa": _load_simpleqa,
+    "arc": _load_arc_challenge,
+    "commonsenseqa": _load_commonsenseqa,
+    "bullshitbench": _load_bullshitbench_v2,
+    "mtbench": _load_mt_bench,
+    "wildbench": _load_wildbench,
+}
+
+
+def get_sample_sources():
+    return bench_sources.keys()
+
+
+def sample_question(source: str):
+    if source not in bench_sources:
+        return None
+    questions = bench_sources[source]()
+    return random.choice(questions)