Refactor sampler args for examples

This commit is contained in:
turboderp
2026-01-11 12:33:27 +01:00
parent 27c68d4e65
commit 288a98f5e3
2 changed files with 42 additions and 23 deletions

View File

@@ -3,7 +3,6 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import argparse
from exllamav3 import Generator, Job, model_init
from exllamav3.generator.sampler import ComboSampler
from chat_templates import *
from chat_util import *
from chat_io import *
@@ -57,18 +56,7 @@ def main(args):
stop_conditions += config.eos_token_id_list
# Sampler
sampler = ComboSampler(
rep_p = args.repetition_penalty,
pres_p = args.presence_penalty,
freq_p = args.frequency_penalty,
rep_sustain_range = args.penalty_range,
rep_decay_range = args.penalty_range,
temperature = args.temperature,
min_p = args.min_p,
top_k = args.top_k,
top_p = args.top_p,
temp_last = not args.temperature_first,
)
sampler = model_init.get_arg_sampler(args)
# Single prompt mode
single_prompt = args.prompt
@@ -324,7 +312,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
model_init.add_args(parser, cache = True)
model_init.add_args(parser, cache = True, add_sampling_args = True)
parser.add_argument("-mode", "--mode", type = str, help = "Prompt mode", default = None)
parser.add_argument("-modes", "--modes", action = "store_true", help = "List available prompt modes and exit")
parser.add_argument("-un", "--user_name", type = str, default = "User", help = "User name (raw mode only)")
@@ -337,15 +325,6 @@ if __name__ == "__main__":
parser.add_argument("-no_think", "--no_think", action = "store_true", help = "Suppress think tags (won't necessarily stop reasoning model from reasoning anyway)")
parser.add_argument("-think_budget", "--think_budget", type = int, help = "Thinking budget for supported models", default = None)
parser.add_argument("-amnesia", "--amnesia", action = "store_true", help = "Forget context with every new prompt")
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature", default = 0.8)
parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty, HF style, 1 to disable (default: disabled)", default = 1.0)
parser.add_argument("-presp", "--presence_penalty", type = float, help = "Presence penalty, 0 to disable (default: disabled)", default = 0.0)
parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024) ", default = 1024)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
parser.add_argument("-tps", "--show_tps", action = "store_true", help = "Show tokens/second after every reply")
parser.add_argument("-prompt", "--prompt", type = str, help = "Run single prompt, then exit")
parser.add_argument("-save", "--save", type = str, help = "Save output to file (use with --prompt)")

View File

@@ -1,6 +1,7 @@
from . import Model, Config, Cache, Tokenizer
from .loader import SafetensorsCollection, VariantSafetensorsCollection
from .cache import CacheLayer_fp16, CacheLayer_quant
from .generator.sampler import ComboSampler
from argparse import ArgumentParser
import yaml
@@ -8,6 +9,7 @@ def add_args(
parser: ArgumentParser,
cache: bool = True,
default_cache_size = 8192,
add_sampling_args: bool = False
):
"""
Add standard model loading arguments to command line parser
@@ -20,6 +22,9 @@ def add_args(
:param default_cache_size:
Default value for -cs / --cache_size argument
:param add_sampling_args:
bool, add sampling arguments
"""
parser.add_argument("-m", "--model_dir", type = str, help = "Path to model directory", required = True)
parser.add_argument("-gs", "--gpu_split", type = str, help = "Maximum amount of VRAM to use per device, in GB.")
@@ -36,11 +41,46 @@ def add_args(
parser.add_argument("-lv", "--load_verbose", action = "store_true", help = "Verbose output while loading")
if add_sampling_args:
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature", default = 0.8)
parser.add_argument("-temp_first", "--temperature_first", action = "store_true", help = "Apply temperature before truncation")
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Repetition penalty, HF style, 1 to disable (default: disabled)", default = 1.0)
parser.add_argument("-presp", "--presence_penalty", type = float, help = "Presence penalty, 0 to disable (default: disabled)", default = 0.0)
parser.add_argument("-freqp", "--frequency_penalty", type = float, help = "Frequency penalty, 0 to disable (default: disabled)", default = 0.0)
parser.add_argument("-penr", "--penalty_range", type = int, help = "Range for penalties, in tokens (default: 1024) ", default = 1024)
parser.add_argument("-minp", "--min_p", type = float, help = "Min-P truncation, 0 to disable (default: 0.08)", default = 0.08)
parser.add_argument("-topk", "--top_k", type = int, help = "Top-K truncation, 0 to disable (default: disabled)", default = 0)
parser.add_argument("-topp", "--top_p", type = float, help = "Top-P truncation, 1 to disable (default: disabled)", default = 1.0)
if cache:
parser.add_argument("-cs", "--cache_size", type = int, help = f"Total cache size in tokens, default: {default_cache_size}", default = default_cache_size)
parser.add_argument("-cq", "--cache_quant", type = str, help = "Use quantized cache. Specify either kv_bits or k_bits,v_bits pair")
def get_arg_sampler(args):
"""
Create CompoSampler from default args above
:param args:
args from ArgumentParser
:return:
ComboSampler
"""
return ComboSampler(
rep_p = args.repetition_penalty,
pres_p = args.presence_penalty,
freq_p = args.frequency_penalty,
rep_sustain_range = args.penalty_range,
rep_decay_range = args.penalty_range,
temperature = args.temperature,
min_p = args.min_p,
top_k = args.top_k,
top_p = args.top_p,
temp_last = not args.temperature_first,
)
def init(
args,
load_tokenizer: bool = True,