mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-10 16:30:12 +00:00
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
@@ -20849,6 +20849,10 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
|
||||
return model->vocab.type;
|
||||
}
|
||||
|
||||
const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
|
||||
return &model->vocab;
|
||||
}
|
||||
|
||||
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
switch (model->arch) {
|
||||
// these models do not use RoPE
|
||||
@@ -23280,6 +23284,11 @@ void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array
|
||||
llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
|
||||
}
|
||||
|
||||
|
||||
void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) {
|
||||
llama_sampler_dry_apply(smpl, candidates_p);
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
@@ -23327,6 +23336,42 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
|
||||
}
|
||||
|
||||
void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
|
||||
smpl->last_tokens.clear();
|
||||
smpl->dry_repeat_count.clear();
|
||||
smpl->dry_max_token_repeat.clear();
|
||||
}
|
||||
|
||||
void llama_sampler_dry_free(struct llama_sampler_dry* smpl) {
|
||||
delete smpl;
|
||||
}
|
||||
|
||||
struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) {
|
||||
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
||||
auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0);
|
||||
// Copy the state, including the processed breakers
|
||||
{
|
||||
auto* result_ctx = smpl;
|
||||
result_ctx->dry_processed_breakers = smpl->dry_processed_breakers;
|
||||
result_ctx->dry_repeat_count = smpl->dry_repeat_count;
|
||||
result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat;
|
||||
result_ctx->last_tokens = smpl->last_tokens;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
|
||||
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
smpl->last_tokens.push_back(token);
|
||||
}
|
||||
|
||||
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
||||
std::string str_split_path(split_path);
|
||||
char postfix[32];
|
||||
|
||||
Reference in New Issue
Block a user