From 2a43d3069d577f1ee757a672bd87550cda3a7ee7 Mon Sep 17 00:00:00 2001
From: Sinan Akkoyun
Date: Fri, 29 Sep 2023 15:57:28 +0200
Subject: [PATCH] Added codeblock highlighting to chatcode.py

---
 examples/chatcode.py | 398 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 398 insertions(+)
 create mode 100644 examples/chatcode.py

diff --git a/examples/chatcode.py b/examples/chatcode.py
new file mode 100644
index 0000000..49da5d0
--- /dev/null
+++ b/examples/chatcode.py
@@ -0,0 +1,398 @@
import os
import re
import sys
from io import StringIO

from pygments import highlight
from pygments.formatter import Formatter
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.style import Style
from pygments.styles.default import DefaultStyle
from pygments.token import Token
from pygments.util import ClassNotFound

import shutil

# Append the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    model_init,
)

import argparse
import torch

from exllamav2.generator import (
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler
)

# Options

parser = argparse.ArgumentParser(description = "Simple Llama2 chat example for ExLlamaV2")
parser.add_argument("-mode", "--mode", choices = ["llama", "raw", "codellama"], help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
parser.add_argument("-un", "--username", type = str, default = "User", help = "Username when using raw chat mode")
parser.add_argument("-bn", "--botname", type = str, default = "Chatbort", help = "Bot name when using raw chat mode")
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")

parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
parser.add_argument("-typical", "--typical", type = float, default = 0.0, help = "Sampler typical threshold, default = 0.0 (0 to disable)")
parser.add_argument("-repp", "--repetition_penalty", type = float, default = 1.1, help = "Sampler repetition penalty, default = 1.1 (1 to disable)")
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
parser.add_argument("-resc", "--response_chunk", type = int, default = 250, help = "Space to reserve in context for reply, default = 250")

# Initialize model and tokenizer

model_init.add_args(parser)
args = parser.parse_args()
model_init.check_args(args)
model_init.print_options(args)
model, tokenizer = model_init.init(args)

# Create cache

cache = ExLlamaV2Cache(model)

# Prompt templates

username = args.username
botname = args.botname
system_prompt = args.system_prompt
mode = args.mode

if mode == "llama" or mode == "codellama":

    if not system_prompt:

        if mode == "llama":

            system_prompt = \
            """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + \
            """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + \
            """Please ensure that your responses are socially unbiased and positive in nature."""

        elif mode == "codellama":

            system_prompt = \
            """You are a helpful coding assistant. Always answer as helpfully as possible."""

    first_prompt = \
    """[INST] <<SYS>>\n<|system_prompt|>\n<</SYS>>\n\n<|user_prompt|> [/INST]"""

    subs_prompt = \
    """[INST] <|user_prompt|> [/INST]"""

elif mode == "raw":

    if not system_prompt:

        system_prompt = \
        f"""This is a conversation between a helpful AI assistant named {botname} and a """ + (f"""user named {username}.""" if username != "User" else """user.""")

    first_prompt = \
    f"""<|system_prompt|>\n{username}: <|user_prompt|>\n{botname}:"""

    subs_prompt = \
    f"""{username}: <|user_prompt|>\n{botname}:"""

else:

    print(" ## Error: Incorrect/no mode specified.")
    sys.exit()

# Chat context

def format_prompt(user_prompt, first):
    global system_prompt, first_prompt, subs_prompt

    if first:
        return first_prompt \
            .replace("<|system_prompt|>", system_prompt) \
            .replace("<|user_prompt|>", user_prompt)
    else:
        return subs_prompt \
            .replace("<|user_prompt|>", user_prompt)

def encode_prompt(text):
    global tokenizer, mode

    if mode == "llama" or mode == "codellama":
        return tokenizer.encode(text, add_bos = True)

    if mode == "raw":
        return tokenizer.encode(text)

user_prompts = []
responses_ids = []

def get_tokenized_context(max_len):
    global user_prompts, responses_ids

    while True:

        context = torch.empty((1, 0), dtype=torch.long)

        for turn in range(len(user_prompts)):

            up_ids = encode_prompt(format_prompt(user_prompts[turn], context.shape[-1] == 0))
            context = torch.cat([context, up_ids], dim=-1)

            if turn < len(responses_ids):
                context = torch.cat([context, responses_ids[turn]], dim=-1)

        if context.shape[-1] < max_len: return context

        # If the context is too long, remove the first Q/A pair and try again. The system prompt will be moved to
        # the first entry in the truncated context

        user_prompts = user_prompts[1:]
        responses_ids = responses_ids[1:]


# Generator

generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = args.temperature
settings.top_k = args.top_k
settings.top_p = args.top_p
settings.typical = args.typical
settings.token_repetition_penalty = args.repetition_penalty

max_response_tokens = args.max_response_tokens
min_space_in_context = args.response_chunk

# Stop conditions

if mode == "llama" or mode == "codellama":

    generator.set_stop_conditions([tokenizer.eos_token_id])

if mode == "raw":

    generator.set_stop_conditions([username + ":", username[0:1] + ":", username.upper() + ":", username.lower() + ":", tokenizer.eos_token_id])

# ANSI color codes

col_default = "\u001b[0m"
col_user = "\u001b[34;1m"  # Blue
col_bot = "\u001b[31;1m"  # Bright Red
col_error = "\u001b[31;1m"  # Bright Red (same as col_bot)

# Code block syntax helpers

in_code_block = False
code_block_text = ""
lines_printed = 0

code_pad = 2
block_pad_left = 1

# Code block formatter for black background

class BlackBackgroundTerminalFormatter(TerminalFormatter):
    def format(self, tokensource, outfile):
        global code_pad, block_pad_left
        # Create a buffer to capture the parent class's output
        buffer = StringIO()
        # Call the parent class's format method
        super().format(tokensource, buffer)
        # Get the content from the buffer
        content = buffer.getvalue()

        # Pad the code on both sides of each line
        lines = content.split('\n')
        padded_lines = [f"{lines[0]}{' '*code_pad*2}"] + [f"{' '*code_pad}{line}{' '*code_pad}" for line in lines[1:-1]] + [lines[-1]]
        content = '\n'.join(padded_lines)

        # Modify the ANSI codes to include a black background
        modified_content = self.add_black_background(content)

        # Offset the code block from the left edge of the terminal
        modified_content = '\n'.join([modified_content.split('\n')[0]] + [f"{' '*block_pad_left}{line}" for line in modified_content.split('\n')[1:]])

        # Relay the modified content to the outfile
        outfile.write(modified_content)

    def add_black_background(self, content):
        # Split the content into lines
        lines = content.split('\n')

        # Process each line to ensure it has a black background
        processed_lines = []
        for line in lines:
            # Split the line into tokens based on ANSI escape sequences
            tokens = re.split(r'(\033\[[^m]*m)', line)
            # Process each token to ensure it has a black background
            processed_tokens = []
            for token in tokens:
                # If the token is an ANSI escape sequence
                if re.match(r'\033\[[^m]*m', token):
                    # Append the black background code to the existing ANSI code
                    processed_tokens.append(f'{token}\033[40m')
                else:
                    # If the token is not an ANSI escape sequence, wrap it in a black background code and a reset
                    processed_tokens.append(f'\033[40m{token}\033[0m')

            # Join the processed tokens back into a single line
            processed_line = ''.join(processed_tokens)
            # Add the ANSI reset code to the end of the line
            processed_line += '\033[0m'
            processed_lines.append(processed_line)

        # Join the processed lines back into a single string
        modified_content = '\n'.join(processed_lines)

        return modified_content

# Print a code block, updating the CLI in real-time

def print_code_block(chunk):
    global lines_printed
    global code_block_text
    global code_pad, block_pad_left

    # Clear previously printed lines
    for _ in range(lines_printed):
        # Move cursor up one line
        print('\x1b[1A', end='')
        # Clear line
        print('\x1b[2K', end='')

    terminal_width = shutil.get_terminal_size().columns

    # Check if the chunk will exceed the terminal width on the current line
    current_line_length = len(code_block_text.split('\n')[-1]) + len(chunk) + 2 * 3 + 3  # Including padding and offset
    if current_line_length > terminal_width:
        code_block_text += '\n'

    # Update the code block text
    code_block_text += chunk

    # Split updated text into lines and find the longest line
    lines = code_block_text.split('\n')
    max_length = max(len(line) for line in lines)

    # Pad all lines to match the length of the longest line
    padded_lines = [line.ljust(max_length) for line in lines]

    # Join padded lines into a single string
    padded_text = '\n'.join(padded_lines)

    # Try guessing the lexer for syntax highlighting
    try:
        lexer = guess_lexer(padded_text)
    except ClassNotFound:
        lexer = get_lexer_by_name("text")  # Fall back to plain text if the language isn't supported by pygments

    formatter = BlackBackgroundTerminalFormatter()
    highlighted_text = highlight(padded_text, lexer, formatter)

    # End the background color before every newline so it doesn't bleed across lines
    highlighted_text = highlighted_text.replace('\n', '\033[0m\n')

    # Print the updated padded and highlighted text
    print(highlighted_text, end='')

    # Update the lines_printed counter
    lines_printed = len(lines)

# Main loop

while True:

    # Get user prompt

    print()
    up = input(col_user + username + ": " + col_default).strip()
    print()

    # Add to context

    user_prompts.append(up)

    # Send tokenized context to generator

    active_context = get_tokenized_context(model.config.max_seq_len - min_space_in_context)
    generator.begin_stream(active_context, settings)

    # print("------")
    # print(tokenizer.decode(active_context))
    # print("------")

    # Stream response

    if mode == "raw":

        print(col_bot + botname + ": " + col_default, end = "")

    response_tokens = 0
    response_text = ""
    responses_ids.append(torch.empty((1, 0), dtype = torch.long))

    while True:

        # Get response stream

        chunk, eos, tokens = generator.stream()
        if len(response_text) == 0: chunk = chunk.lstrip()
        response_text += chunk
        responses_ids[-1] = torch.cat([responses_ids[-1], tokens], dim=-1)  # Keep the tokens in the context even inside code blocks

        # Check for code block delimiters
        if chunk.startswith("```"):
            in_code_block = not in_code_block  # Toggle in_code_block flag
            chunk = chunk[3:]  # Remove the delimiter from the chunk
            print('\n')

        if in_code_block:
            print_code_block(chunk)  # Handle code block streaming
        else:
            # If a code block was just closed, reset the code block state
            if code_block_text:
                code_block_text = ""  # Reset code_block_text for the next code block
                lines_printed = 0
                print('\033[0m', end='')  # Reset block color to be certain

            # Continue as normal if not in a code block
            print(chunk, end="")
            sys.stdout.flush()

        # If model has run out of space, rebuild the context and restart stream

        if generator.full():

            active_context = get_tokenized_context(model.config.max_seq_len - min_space_in_context)
            generator.begin_stream(active_context, settings)

        # If response is too long, cut it short, and append EOS if that was a stop condition

        response_tokens += 1
        if response_tokens == max_response_tokens:

            if tokenizer.eos_token_id in generator.stop_tokens:
                responses_ids[-1] = torch.cat([responses_ids[-1], tokenizer.single_token(tokenizer.eos_token_id)], dim = -1)

            print()
            print(col_error + f" !! Response exceeded {max_response_tokens} tokens and was cut short." + col_default)
            break

        # EOS signal returned

        if eos:

            if mode == "llama" or mode == "codellama":
                print()

            break
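
Note on the highlighting approach: the formatter added above follows pygments' usual extension pattern, i.e. subclass a formatter, let the parent class render ANSI-coloured text into a buffer, then post-process that buffer before it reaches the terminal. The snippet below is a rough, standalone sketch of that mechanism only (simplified to a per-line black background rather than the per-token handling in the patch; the class name and sample string are illustrative and not part of the patch):

from io import StringIO

from pygments import highlight
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexers import get_lexer_by_name, guess_lexer
from pygments.util import ClassNotFound

class SimpleBlackBackgroundFormatter(TerminalFormatter):
    def format(self, tokensource, outfile):
        # Let TerminalFormatter produce its normal ANSI output into a buffer
        buffer = StringIO()
        super().format(tokensource, buffer)
        # Re-emit every line wrapped in a black-background code and a reset
        for line in buffer.getvalue().splitlines():
            outfile.write(f"\033[40m{line}\033[0m\n")

code = "def add(a, b):\n    return a + b\n"
try:
    lexer = guess_lexer(code)          # let pygments guess the language
except ClassNotFound:
    lexer = get_lexer_by_name("text")  # fall back to plain text
print(highlight(code, lexer, SimpleBlackBackgroundFormatter()), end="")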