mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -35,11 +35,16 @@ typedef struct llama_sampling_params {
|
||||
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||
float penalty_present = 0.00f; // 0.0 = disabled
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||
int32_t total_context_size = 16840;
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
float xtc_probability = 0.0f; // xtc probability
|
||||
@@ -48,12 +53,16 @@ typedef struct llama_sampling_params {
|
||||
bool penalize_nl = false; // consider newlines as a repeatable token
|
||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
|
||||
|
||||
std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY
|
||||
|
||||
std::vector<llama_sampler_type> samplers_sequence = {
|
||||
llama_sampler_type::DRY,
|
||||
llama_sampler_type::TOP_K,
|
||||
llama_sampler_type::TFS_Z,
|
||||
llama_sampler_type::TYPICAL_P,
|
||||
llama_sampler_type::TOP_P,
|
||||
llama_sampler_type::MIN_P,
|
||||
llama_sampler_type::XTC,
|
||||
llama_sampler_type::TOP_N_SIGMA,
|
||||
llama_sampler_type::TEMPERATURE
|
||||
};
|
||||
@@ -88,6 +97,8 @@ struct llama_sampling_context {
|
||||
// TODO: replace with ring-buffer
|
||||
std::vector<llama_token> prev;
|
||||
std::vector<llama_token_data> cur;
|
||||
llama_sampler_dry* smpl;
|
||||
|
||||
size_t n_valid; // Number of correct top tokens with correct probabilities.
|
||||
|
||||
std::mt19937 rng;
|
||||
@@ -96,7 +107,7 @@ struct llama_sampling_context {
|
||||
#include "common.h"
|
||||
|
||||
// Create a new sampling context instance.
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params);
|
||||
|
||||
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user