# Mirror of https://github.com/turboderp-org/exllamav3.git
# Synced 2026-03-15 00:07:24 +00:00
import sys, os
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
import argparse
|
|
from exllamav3 import Generator, Job, model_init
|
|
from chat_templates import *
|
|
import torch
|
|
|
|
# ANSI SGR color codes used to distinguish chat roles in terminal output
col_default = "\u001b[0m"       # Reset to terminal default
col_user = "\u001b[33;1m"       # Bright yellow (user prompt)
col_bot = "\u001b[34;1m"        # Bright blue (model response)
col_error = "\u001b[31;1m"      # Bright red (truncation warning)
col_sysprompt = "\u001b[37;1m"  # Bright white (system prompt banner)
|
|
|
|
@torch.inference_mode()
def main(args):
    """
    Run an interactive terminal chat session.

    Loads the model/cache/tokenizer described by the command-line args,
    then loops forever: read a user turn from stdin, stream the model's
    reply to stdout, and append the exchange to the running context.
    When the tokenized context would overflow the cache, whole turns are
    dropped from the head of the history.

    With -modes, lists the available prompt formats and returns instead.

    :param args:
        Parsed argparse namespace (see the __main__ block plus
        model_init.add_args for the accepted options).
    """

    # Prompt format: -modes just lists the registered formats and exits
    if args.modes:
        print("Available modes:")
        for k, v in prompt_formats.items():
            print(f" - {k:16} {v.description}")
        return

    user_name = args.user_name
    bot_name = args.bot_name
    prompt_format = prompt_formats[args.mode](user_name, bot_name)
    # Explicit -sp overrides the format's built-in system prompt
    system_prompt = prompt_format.default_system_prompt() if not args.system_prompt else args.system_prompt
    add_bos = prompt_format.add_bos()
    max_response_tokens = args.max_response_tokens

    # Load model
    model, config, cache, tokenizer = model_init.init(args)
    context_length = cache.max_num_tokens

    # Generator
    generator = Generator(
        model = model,
        cache = cache,
        tokenizer = tokenizer,
    )
    stop_conditions = prompt_format.stop_conditions(tokenizer)

    # Main loop
    print("\n" + col_sysprompt + system_prompt.strip() + col_default)
    # List of (user_prompt, response) turns; the current turn's response is
    # None until generation for it completes.
    context = []

    while True:

        # Get user prompt and add to context. -mli reads multi-line input
        # until EOF; otherwise a single line.
        print("\n" + col_user + user_name + ": " + col_default, end = '', flush = True)
        if args.mli:
            user_prompt = sys.stdin.read().rstrip()
        else:
            user_prompt = input().strip()
        context.append((user_prompt, None))

        # Tokenize the full formatted context; exp_len is the worst-case
        # sequence length if the model uses its entire response budget.
        def get_input_ids():
            frm_context = prompt_format.format(system_prompt, context)
            ids_ = tokenizer.encode(frm_context, add_bos = add_bos, encode_special_tokens = True)
            exp_len_ = ids_.shape[-1] + max_response_tokens + 1
            return ids_, exp_len_

        ids, exp_len = get_input_ids()
        if exp_len > context_length:
            # Trim whole turns from the head until there is comfortable
            # headroom (two response budgets). Stop once only the current
            # turn remains: trimming an empty/single-turn context can never
            # shrink exp_len further and previously looped forever when a
            # single prompt exceeded the threshold.
            while exp_len > context_length - 2 * max_response_tokens and len(context) > 1:
                context = context[1:]
                ids, exp_len = get_input_ids()

        # Inference
        print("\n" + col_bot + bot_name + ": " + col_default, end = "")
        job = Job(
            input_ids = ids,
            max_new_tokens = max_response_tokens,
            stop_conditions = stop_conditions
        )
        generator.enqueue(job)

        # Stream response chunks as they arrive
        response = ""
        while generator.num_remaining_jobs():
            for r in generator.iterate():
                chunk = r.get("text", "")
                # Suppress a single leading space on the very first chunk
                # (common tokenizer artifact); keep it in `response` so the
                # stored context is unaffected.
                if not response and chunk.startswith(" "):
                    print(chunk[1:], end = "", flush = True)
                else:
                    print(chunk, end = "", flush = True)
                response += chunk
                if r["eos"] and r["eos_reason"] == "max_new_tokens":
                    print("\n" + col_error + f" !! Response exceeded {max_response_tokens} tokens and was cut short." + col_default)
        if not response.endswith("\n"):
            print()

        # Add response to context, completing the current turn
        context[-1] = (user_prompt, response.strip())
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: assemble the argument parser (model/cache options come
    # from model_init, chat-specific options below) and run the chat loop.
    parser = argparse.ArgumentParser()
    model_init.add_args(parser, cache = True)
    add = parser.add_argument
    add("-mode", "--mode", type=str, help="Prompt mode", required=True)
    add("-modes", "--modes", action="store_true", help="List available prompt modes and exit")
    add("-un", "--user_name", type=str, default="User", help="User name (raw mode only)")
    add("-bn", "--bot_name", type=str, default="Assistant", help="Bot name (raw mode only)")
    add("-mli", "--mli", action="store_true", help="Enable multi line input")
    add("-sp", "--system_prompt", type=str, help="Use custom system prompt")
    add("-maxr", "--max_response_tokens", type=int, default=1000, help="Max tokens per response, default = 1000")
    # TODO: Sampling options
    main(parser.parse_args())
|