mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Refactor sampler args for examples
This commit is contained in:
@@ -3,7 +3,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import argparse
|
||||
from exllamav3 import Generator, Job, model_init
|
||||
from exllamav3.generator.sampler import ComboSampler
|
||||
from chat_templates import *
|
||||
from chat_util import *
|
||||
from chat_io import *
|
||||
@@ -57,18 +56,7 @@ def main(args):
|
||||
stop_conditions += config.eos_token_id_list
|
||||
|
||||
# Sampler
|
||||
sampler = ComboSampler(
|
||||
rep_p = args.repetition_penalty,
|
||||
pres_p = args.presence_penalty,
|
||||
freq_p = args.frequency_penalty,
|
||||
rep_sustain_range = args.penalty_range,
|
||||
rep_decay_range = args.penalty_range,
|
||||
temperature = args.temperature,
|
||||
min_p = args.min_p,
|
||||
top_k = args.top_k,
|
||||
top_p = args.top_p,
|
||||
temp_last = not args.temperature_first,
|
||||
)
|
||||
sampler = model_init.get_arg_sampler(args)
|
||||
|
||||
# Single prompt mode
|
||||
single_prompt = args.prompt
|
||||
@@ -324,7 +312,7 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
model_init.add_args(parser, cache = True)
|
||||
model_init.add_args(parser, cache = True, add_sampling_args = True)
|
||||
parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", default = None)
|
||||
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
|
||||
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
|
||||
@@ -337,15 +325,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-no_think", "--no_think", action = "store_true", help = "Suppress think tags (won't necessarily stop reasoning model from reasoning anyway)")
|
||||
parser.add_argument("-think_budget", "--think_budget", type = int, help = "Thinking budget for supported models", default = None)
|
||||
parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context with every new prompt")
|
||||
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature", default = 0.8)
|
||||
parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
|
||||
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty, HF style, 1 to disable (default: disabled)", default = 1.0)
|
||||
parser.add_argument("-presp", "--presence_penalty", type = float, help = "Presence penalty, 0 to disable (default: disabled)", default = 0.0)
|
||||
parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
|
||||
parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024) ", default = 1024)
|
||||
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
|
||||
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
|
||||
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
|
||||
parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
|
||||
parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
|
||||
parser.add_argument("-save", "--save", type = str, help = "Save output to file (use with --prompt)")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from . import Model, Config, Cache, Tokenizer
|
||||
from .loader import SafetensorsCollection, VariantSafetensorsCollection
|
||||
from .cache import CacheLayer_fp16, CacheLayer_quant
|
||||
from .generator.sampler import ComboSampler
|
||||
from argparse import ArgumentParser
|
||||
import yaml
|
||||
|
||||
@@ -8,6 +9,7 @@ def add_args(
|
||||
parser: ArgumentParser,
|
||||
cache: bool = True,
|
||||
default_cache_size = 8192,
|
||||
add_sampling_args: bool = False
|
||||
):
|
||||
"""
|
||||
Add standard model loading arguments to command line parser
|
||||
@@ -20,6 +22,9 @@ def add_args(
|
||||
|
||||
:param default_cache_size:
|
||||
Default value for -cs / --cache_size argument
|
||||
|
||||
:param add_sampling_args:
|
||||
bool, add sampling arguments
|
||||
"""
|
||||
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
|
||||
parser.add_argument("-gs", "--gpu_split", type = str, help = "Maximum amount of VRAM to use per device, in GB.")
|
||||
@@ -36,11 +41,46 @@ def add_args(
|
||||
|
||||
parser.add_argument("-lv", "--load_verbose", action = "store_true", help = "Verbose output while loading")
|
||||
|
||||
if add_sampling_args:
|
||||
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature", default = 0.8)
|
||||
parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
|
||||
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty, HF style, 1 to disable (default: disabled)", default = 1.0)
|
||||
parser.add_argument("-presp", "--presence_penalty", type = float, help = "Presence penalty, 0 to disable (default: disabled)", default = 0.0)
|
||||
parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
|
||||
parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024) ", default = 1024)
|
||||
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
|
||||
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
|
||||
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
|
||||
|
||||
if cache:
|
||||
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
|
||||
parser.add_argument("-cq", "--cache_quant", type = str, help = "Use quantized cache. Specify either kv_bits or k_bits,v_bits pair")
|
||||
|
||||
|
||||
def get_arg_sampler(args):
|
||||
"""
|
||||
Create CompoSampler from default args above
|
||||
|
||||
:param args:
|
||||
args from ArgumentParser
|
||||
|
||||
:return:
|
||||
ComboSampler
|
||||
"""
|
||||
return ComboSampler(
|
||||
rep_p = args.repetition_penalty,
|
||||
pres_p = args.presence_penalty,
|
||||
freq_p = args.frequency_penalty,
|
||||
rep_sustain_range = args.penalty_range,
|
||||
rep_decay_range = args.penalty_range,
|
||||
temperature = args.temperature,
|
||||
min_p = args.min_p,
|
||||
top_k = args.top_k,
|
||||
top_p = args.top_p,
|
||||
temp_last = not args.temperature_first,
|
||||
)
|
||||
|
||||
|
||||
def init(
|
||||
args,
|
||||
load_tokenizer: bool = True,
|
||||
|
||||
Reference in New Issue
Block a user