add dry sampler (#513)

* add dry sampler

* use vocab instead of model in dry_init function

* fix compile error for build test

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-06-19 02:24:53 -05:00
committed by GitHub
parent c5368148cf
commit 3f111ad7bb
21 changed files with 743 additions and 36 deletions

View File

@@ -20849,6 +20849,10 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
return &model->vocab;
}
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
switch (model->arch) {
// these models do not use RoPE
@@ -23280,6 +23284,11 @@ void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array
llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
}
void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) {
llama_sampler_dry_apply(smpl, candidates_p);
}
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
@@ -23327,6 +23336,42 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
return 0;
}
struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
}
void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
smpl->last_tokens.clear();
smpl->dry_repeat_count.clear();
smpl->dry_max_token_repeat.clear();
}
void llama_sampler_dry_free(struct llama_sampler_dry* smpl) {
delete smpl;
}
struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) {
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0);
// Copy the state, including the processed breakers
{
auto* result_ctx = smpl;
result_ctx->dry_processed_breakers = smpl->dry_processed_breakers;
result_ctx->dry_repeat_count = smpl->dry_repeat_count;
result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat;
result_ctx->last_tokens = smpl->last_tokens;
}
return result;
}
void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
return;
}
smpl->last_tokens.push_back(token);
}
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
std::string str_split_path(split_path);
char postfix[32];