add dry sampler (#513)

* add dry sampler

* use vocab instead of model in dry_init function

* fix compile error for build test

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-06-19 02:24:53 -05:00
committed by GitHub
parent c5368148cf
commit 3f111ad7bb
21 changed files with 743 additions and 36 deletions

View File

@@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.top_n_sigma = std::stof(argv[i]);
return true;
}
if (arg == "--dry-multiplier") {
CHECK_ARG
sparams.dry_multiplier = std::stof(argv[i]);
return true;
}
if (arg == "--dry-base") {
CHECK_ARG
sparams.dry_base = std::stof(argv[i]);
return true;
}
if (arg == "--dry-allowed-length") {
CHECK_ARG
sparams.dry_allowed_length = std::stof(argv[i]);
return true;
}
if (arg == "--dry-penalty-last-n") {
CHECK_ARG
sparams.dry_penalty_last_n = std::stof(argv[i]);
return true;
}
if (arg == "--dry-sequence-breaker") {
CHECK_ARG
static bool defaults_cleared = false;
if (!defaults_cleared) {
params.sparams.dry_sequence_breakers.clear();
defaults_cleared = true;
}
std::string value= std::string(argv[i]);
if (value == "none") {
params.sparams.dry_sequence_breakers.clear();
}
else {
for (size_t i; i < value.size(); i++)
{
params.sparams.dry_sequence_breakers.emplace_back(""+value[i]);
}
}
return true;
}
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
@@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
if (params.sparams.dry_penalty_last_n == -1) {
LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sparams.dry_penalty_last_n = llama_n_ctx(lctx);
}
if (params.warmup) {
LOG("warming up the model with an empty run\n");
@@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);