Added draft token count as parameter to chat.py (#635)

This commit is contained in:
Sinan
2024-09-24 11:16:30 +02:00
committed by GitHub
parent 15e54046ba
commit 7c7b1993b4

View File

@@ -33,6 +33,7 @@ prompt_formats_list = list(prompt_formats.keys())
# Command-line interface for the chat example.
parser = argparse.ArgumentParser(description = "Simple Llama2 chat example for ExLlamaV2")
# Draft-model options: where to load the draft model from and whether to skip
# alpha (NTK) scaling when its context is smaller than the main model's.
parser.add_argument("-dm", "--draft_model_dir", type = str, default = None, help = "Path to draft model directory")
parser.add_argument("-nds", "--no_draft_scale", action = "store_true", help = "If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it")
# Number of tokens to speculate ahead (integer, default 5); read later as
# args.draft_n_tokens when the generator is constructed.
parser.add_argument("-dn", "--draft_n_tokens", type = int, default = 5, help = "How many tokens to speculate ahead (defaults to 5)")
# Chat mode selection: -modes lists the available modes, -mode picks one of the
# names in prompt_formats_list.
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available modes and exit.")
parser.add_argument("-mode", "--mode", choices = prompt_formats_list, help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
@@ -219,7 +220,7 @@ def get_tokenized_context(max_len):
# Generator
# NOTE(review): the next two statements appear to be the removed and the added
# version of the same line in this diff hunk (the +/- markers are not shown).
# They differ only in the num_speculative_tokens keyword; presumably only the
# second one exists in the resulting file -- confirm against the real chat.py.
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer, draft_model, draft_cache)
# Passes the --draft_n_tokens command-line value as num_speculative_tokens.
generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer, draft_model, draft_cache, num_speculative_tokens=args.draft_n_tokens)
# Copy the ngram_decoding argument onto the generator's speculative_ngram attribute.
generator.speculative_ngram = args.ngram_decoding
settings = ExLlamaV2Sampler.Settings(