add dry sampler (#513)

* add dry sampler

* use vocab instead of model in dry_init function

* fix compile error for build test

---------

Co-authored-by: firecoperana <firecoperana>
This commit is contained in:
firecoperana
2025-06-19 02:24:53 -05:00
committed by GitHub
parent c5368148cf
commit 3f111ad7bb
21 changed files with 743 additions and 36 deletions

View File

@@ -1,7 +1,7 @@
#pragma once
#include "llama-impl.h"
#include <unordered_map>
struct llama_sampling {
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
@@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
struct llama_sampler_dry {
int32_t total_context_size;
const float dry_multiplier;
const float dry_base;
const int32_t dry_allowed_length;
const int32_t dry_penalty_last_n;
std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
std::vector<int> dry_repeat_count;
std::unordered_map<llama_token, int> dry_max_token_repeat;
ring_buffer<llama_token> last_tokens;
};
struct llama_sampler_dry * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
@@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);