Add typical setting to chat example.

This commit is contained in:
turboderp
2023-09-26 19:50:44 +02:00
parent 92a9828450
commit ba5f6191c8
2 changed files with 11 additions and 15 deletions

View File

@@ -29,6 +29,7 @@ parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom sys
parser.add_argument("-temp", "--temperature", type = float, default = 0.95, help = "Sampler temperature, default = 0.95 (1 to disable)")
parser.add_argument("-topk", "--top_k", type = int, default = 50, help = "Sampler top-K, default = 50 (0 to disable)")
parser.add_argument("-topp", "--top_p", type = float, default = 0.8, help = "Sampler top-P, default = 0.8 (0 to disable)")
parser.add_argument("-typical", "--typical", type = float, default = 0.0, help = "Sampler typical threshold, default = 0.0 (0 to disable)")
parser.add_argument("-repp", "--repetition_penalty", type = float, default = 1.1, help = "Sampler repetition penalty, default = 1.1 (1 to disable)")
parser.add_argument("-maxr", "--max_response_tokens", type = int, default = 1000, help = "Max tokens per response, default = 1000")
parser.add_argument("-resc", "--response_chunk", type = int, default = 250, help = "Space to reserve in context for reply, default = 250")
@@ -149,6 +150,7 @@ settings = ExLlamaV2Sampler.Settings()
settings.temperature = args.temperature
settings.top_k = args.top_k
settings.top_p = args.top_p
settings.typical = args.typical
settings.token_repetition_penalty = args.repetition_penalty
max_response_tokens = args.max_response_tokens
@@ -162,7 +164,7 @@ if mode == "llama" or mode == "codellama":
if mode == "raw":
generator.set_stop_conditions([username + ":", username[0:1] + ":"])
generator.set_stop_conditions([username + ":", username[0:1] + ":", username.upper() + ":", username.lower() + ":", tokenizer.eos_token_id])
# ANSI color codes

View File

@@ -14,22 +14,16 @@ import torch
# Models to test
model_base = "/mnt/str/models/_exl2/llama2-7b-chat-exl2/"
model_base = "/mnt/str/models/_exl2/llama2-70b-chat-exl2/"
variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
# variants = [v for v in os.listdir(model_base) if os.path.isdir(os.path.join(model_base, v))]
# variants = \
# [
# "2.5bpw",
# "3.0bpw",
# "3.5bpw",
# "4.0bpw",
# "4.65bpw",
# "5.0bpw",
# "6.0bpw",
# "8.0bpw",
# "fp16",
# ]
variants = \
[
"3.0bpw",
"4.0bpw",
"4.65bpw",
]
gpu_split = (19.5, 24)