add dry sampler (#513)
* add dry sampler
* use vocab instead of model in dry_init function
* fix compile error for build test

Co-authored-by: firecoperana <firecoperana>
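In outline, the commit wires DRY through three layers: the new --dry-* flags are parsed into the sampling parameters (common/common.cpp), llama_sampling_init now receives the model vocabulary and builds a llama_sampler_dry whenever DRY appears in the sampler sequence, and sampler_queue dispatches to llama_sample_dry, with DRY placed first in the default sampler order (common/sampling.cpp, common/sampling.h).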
common/common.cpp
@@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.top_n_sigma = std::stof(argv[i]);
         return true;
     }
+    if (arg == "--dry-multiplier") {
+        CHECK_ARG
+        sparams.dry_multiplier = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-base") {
+        CHECK_ARG
+        sparams.dry_base = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-allowed-length") {
+        CHECK_ARG
+        sparams.dry_allowed_length = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-penalty-last-n") {
+        CHECK_ARG
+        sparams.dry_penalty_last_n = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-sequence-breaker") {
+        CHECK_ARG
+        // the first use of this flag replaces the built-in default breakers
+        static bool defaults_cleared = false;
+        if (!defaults_cleared) {
+            params.sparams.dry_sequence_breakers.clear();
+            defaults_cleared = true;
+        }
+        std::string value = std::string(argv[i]);
+        if (value == "none") {
+            params.sparams.dry_sequence_breakers.clear();
+        } else {
+            // each character of the argument becomes its own single-character breaker
+            for (size_t j = 0; j < value.size(); j++) {
+                params.sparams.dry_sequence_breakers.emplace_back(std::string(1, value[j]));
+            }
+        }
+        return true;
+    }
     if (arg == "--cfg-negative-prompt") {
         CHECK_ARG
         sparams.cfg_negative_prompt = argv[i];
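One behavior worth noting in the --dry-sequence-breaker handler above: the argument is split character by character, so each character becomes its own one-character breaker, and multi-character breakers cannot be supplied through this flag. A minimal standalone sketch of that loop (the flag value here is illustrative, not from the commit):

#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> breakers;
    std::string value = ":*";            // hypothetical flag value
    if (value == "none") {
        breakers.clear();                // "none" disables all breakers
    } else {
        // each character becomes its own single-character breaker
        for (size_t j = 0; j < value.size(); j++) {
            breakers.emplace_back(std::string(1, value[j]));
        }
    }
    std::printf("%zu breakers\n", breakers.size());  // prints: 2 breakers
}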
@@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
     }

+    if (params.sparams.dry_penalty_last_n == -1) {
+        LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sparams.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
@@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+    fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+    fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+    fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
common/sampling.cpp
@@ -1,8 +1,9 @@
 #define LLAMA_API_INTERNAL
 #include "sampling.h"
+#include "llama-vocab.h"
 #include <random>

-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_vocab * vocab, const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();

     result->params = params;
@@ -36,13 +37,32 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
         }
         result->grammar = grammar;
     }

     result->prev.resize(params.n_prev);

     result->n_valid = 0;

+    // init DRY
+    for (const auto & cnstr : params.samplers_sequence) {
+        switch (cnstr) {
+            case llama_sampler_type::DRY:
+            {
+                std::vector<const char *> c_breakers;
+                c_breakers.reserve(params.dry_sequence_breakers.size());
+                for (const auto & str : params.dry_sequence_breakers) {
+                    c_breakers.push_back(str.c_str());
+                }
+                result->smpl = llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base,
+                                                      params.dry_allowed_length, params.dry_penalty_last_n,
+                                                      c_breakers.data(), c_breakers.size());
+                break;
+            }
+            default:
+                break;
+        }
+    }
     llama_sampling_set_rng_seed(result, params.seed);

     return result;
 }
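The c_breakers construction above is the usual bridge from std::vector<std::string> to the const char ** shape a C-style API expects. A small self-contained sketch of the pattern; note the pointers borrow from the strings, so the source vector must outlive any use of the array:

#include <string>
#include <vector>

// collect borrowed C-string pointers; valid only while `in` is alive and unmodified
static std::vector<const char *> to_c_strings(const std::vector<std::string> & in) {
    std::vector<const char *> out;
    out.reserve(in.size());
    for (const auto & s : in) {
        out.push_back(s.c_str());
    }
    return out;
}

int main() {
    std::vector<std::string> breakers = { "\n", ":", "\"", "*" }; // the DRY defaults
    auto c_breakers = to_c_strings(breakers);
    return c_breakers.size() == breakers.size() ? 0 : 1;
}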
@@ -50,7 +70,8 @@ void llama_sampling_free(struct llama_sampling_context * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
     }
+    if (ctx->smpl != NULL) {
+        llama_sampler_dry_free(ctx->smpl);
+    }
     delete ctx;
 }
@@ -75,6 +96,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
     ctx->n_valid = 0;
+    llama_sampler_dry_reset(ctx->smpl);
 }

 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -95,6 +117,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
     }

     dst->prev = src->prev;
+    dst->smpl = llama_sampler_dry_clone(src->smpl);
 }

 llama_token llama_sampling_last(llama_sampling_context * ctx) {
@@ -149,6 +172,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) {

 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
     switch (sampler_type) {
+        case llama_sampler_type::DRY:       return "dry";
         case llama_sampler_type::TOP_K:     return "top_k";
         case llama_sampler_type::TFS_Z:     return "tfs_z";
         case llama_sampler_type::TYPICAL_P: return "typical_p";
@@ -163,6 +187,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {

 std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
     std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
+        {"dry",       llama_sampler_type::DRY},
         {"top_k",     llama_sampler_type::TOP_K},
         {"top_p",     llama_sampler_type::TOP_P},
         {"typical_p", llama_sampler_type::TYPICAL_P},
@@ -176,6 +201,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
     // since samplers names are written multiple ways
     // make it ready for both system names and input names
     std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
+        {"dry",     llama_sampler_type::DRY},
         {"top-k",   llama_sampler_type::TOP_K},
         {"top-p",   llama_sampler_type::TOP_P},
         {"nucleus", llama_sampler_type::TOP_P},
@@ -215,6 +241,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto

 std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
     std::unordered_map<char, llama_sampler_type> sampler_name_map {
+        {'d', llama_sampler_type::DRY},
         {'k', llama_sampler_type::TOP_K},
         {'p', llama_sampler_type::TOP_P},
         {'y', llama_sampler_type::TYPICAL_P},
@@ -238,25 +265,28 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin

 // no reasons to expose this function in header
 static void sampler_queue(
         struct llama_context * ctx_main,
         const llama_sampling_params & params,
+        llama_sampling_context * ctx_sampling,
         llama_token_data_array & cur_p,
         size_t min_keep) {
     const float   temp              = params.temp;
     const float   dynatemp_range    = params.dynatemp_range;
     const float   dynatemp_exponent = params.dynatemp_exponent;
     const int32_t top_k             = params.top_k;
     const float   top_p             = params.top_p;
     const float   min_p             = params.min_p;
     const float   tfs_z             = params.tfs_z;
     const float   typical_p         = params.typical_p;
     const float   xtc_probability   = params.xtc_probability;
     const float   xtc_threshold     = params.xtc_threshold;
     const float   top_n_sigma       = params.top_n_sigma;

     const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;

     for (auto sampler_type : samplers_sequence) {
         switch (sampler_type) {
+            case llama_sampler_type::DRY      : llama_sample_dry      (ctx_main, ctx_sampling->smpl, &cur_p); break;
             case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k, min_keep); break;
             case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
             case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
@@ -317,7 +347,7 @@ static llama_token llama_sampling_sample_impl(
     // temperature sampling
     size_t min_keep = std::max(1, params.min_keep);

-    sampler_queue(ctx_main, params, cur_p, min_keep);
+    sampler_queue(ctx_main, params, ctx_sampling, cur_p, min_keep);

     id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
@@ -472,4 +502,5 @@ void llama_sampling_accept(
     if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
+    llama_sampler_dry_accept(ctx_sampling->smpl, id);
 }
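Note the placement: llama_sampler_dry_accept runs unconditionally for every accepted token, outside the grammar branch, so the DRY sampler's record of recent history tracks what was actually generated whether or not a grammar is active.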
common/sampling.h
@@ -35,11 +35,16 @@ typedef struct llama_sampling_params {
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range    = 0.00f; // 0.0 = disabled
     float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
     int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   penalty_repeat    = 1.00f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled
     float   penalty_present   = 0.00f; // 0.0 = disabled
+    float   dry_multiplier    = 0.0f;  // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
+    float   dry_base          = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2;    // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1;   // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t total_context_size = 16840;
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
     float   xtc_probability   = 0.0f;  // xtc probability
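The comment on dry_base above gives the penalty rule: a token that would extend a repetition of length len, with len greater than dry_allowed_length, is penalized by multiplier * base ^ (len - allowed_length), so the penalty grows exponentially as the repeat gets longer. A worked example with the default base and allowed length (the non-zero multiplier is hypothetical, since the default 0.0 disables DRY):

#include <cmath>
#include <cstdio>

int main() {
    const float multiplier = 0.8f;  // hypothetical; the default 0.0f disables DRY
    const float base       = 1.75f; // default dry_base
    const int   allowed    = 2;     // default dry_allowed_length
    for (int len = 3; len <= 6; ++len) {
        // len = length of the repeated sequence the candidate token would extend
        std::printf("len=%d penalty=%.2f\n", len, multiplier * std::pow(base, (float)(len - allowed)));
    }
    // prints penalties 1.40, 2.45, 4.29, 7.50 -- exponential growth in len
}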
@@ -48,12 +53,16 @@ typedef struct llama_sampling_params {
     bool penalize_nl = false;           // consider newlines as a repeatable token
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

+    std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY
+
     std::vector<llama_sampler_type> samplers_sequence = {
+        llama_sampler_type::DRY,
         llama_sampler_type::TOP_K,
         llama_sampler_type::TFS_Z,
         llama_sampler_type::TYPICAL_P,
         llama_sampler_type::TOP_P,
         llama_sampler_type::MIN_P,
         llama_sampler_type::XTC,
         llama_sampler_type::TOP_N_SIGMA,
         llama_sampler_type::TEMPERATURE
     };
@@ -88,6 +97,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+    llama_sampler_dry * smpl;

     size_t n_valid; // Number of correct top tokens with correct probabilities.

     std::mt19937 rng;
@@ -96,7 +107,7 @@ struct llama_sampling_context {
 #include "common.h"

 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_vocab * vocab, const struct llama_sampling_params & params);

 void llama_sampling_free(struct llama_sampling_context * ctx);