Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-20 21:24:08 +00:00)
spec : add self speculative decoding, ngram and refactor (#1261)
* spec : add self speculative decoding and ngram-mod and refactor

  common : use common_ prefix for common library function
  llama : use LLAMA_TOKEN_NULL
  spec : add self speculative decoding (no draft model required) + refactor
  spec : add ngram-mod
  spec : various improvements to ngram-map + docs
  spec : fix the check-rate logic of ngram-simple
  common : add common_speculative_is_compat()
  spec : simplify time measurement using common_time_meas
  refactor common_sampler_init
  refactor common_token_to_piece
  refactor and fix cur_p bug
  clean up

* spec : remove check rate

* spec : show warnings instead of abort

---------

Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
@@ -6,8 +6,10 @@
 #include <nlohmann/json.hpp>
 using json = nlohmann::ordered_json;

-struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) {
-    struct llama_sampling_context * result = new llama_sampling_context();
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    struct common_sampler * result = new common_sampler();

     result->params = params;
     result->grammar = nullptr;
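With this change, callers construct the sampler from a `llama_model` instead of passing a vocab directly. A minimal sketch of the new lifecycle, assuming default `common_params_sampling` values; the `model` variable and the generation step are illustrative, not part of this diff:

```cpp
// Sketch only: common_sampler_init/common_sampler_free follow the signatures
// in the hunk above; obtaining `model` is assumed from the surrounding API.
common_params_sampling sparams;                       // default sampling params
struct common_sampler * smpl = common_sampler_init(model, sparams);
// ... generate with common_sampler_sample() / common_sampler_accept() ...
common_sampler_free(smpl);
```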
@@ -107,7 +109,7 @@ struct llama_sampling_context * common_sampler_init(const struct llama_vocab* vo
     return result;
 }

-void common_sampler_free(struct llama_sampling_context * ctx) {
+void common_sampler_free(struct common_sampler * ctx) {
     if (ctx->grammar != NULL) {
         llama_grammar_free(ctx->grammar);
     }
@@ -116,7 +118,7 @@ void common_sampler_free(struct llama_sampling_context * ctx) {
     delete ctx;
 }

-static void llama_grammar_reset(llama_sampling_context * ctx) {
+static void llama_grammar_reset(common_sampler * ctx) {
     ctx->prev.clear();
     if (!ctx->grammar) {
         return;
@@ -135,19 +137,19 @@ static void llama_grammar_reset(llama_sampling_context * ctx) {
     ctx->grammar = grammar_new;
 }

-void common_sampler_reset(const struct llama_vocab * vocab, llama_sampling_context * ctx) {
+void common_sampler_reset(common_sampler * ctx) {
     llama_grammar_reset(ctx);
     llama_sampler_dry_reset(ctx->smpl);
 }

-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = std::random_device{}();
     }
     ctx->rng.seed(seed);
 }

-void common_sampler_clone(llama_sampling_context * src, llama_sampling_context * dst) {
+void common_sampler_clone(common_sampler * src, common_sampler * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
         dst->grammar = nullptr;
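The reset path no longer needs a vocab argument, which simplifies reusing one sampler across requests. A hedged sketch of that reuse pattern, using only the function names from this diff (the loop and seeds are illustrative):

```cpp
// Illustrative reuse: reseed and reset between independent generations so
// grammar state and DRY penalty history do not leak across requests.
for (uint32_t seed : {1234u, 5678u}) {
    llama_sampling_set_rng_seed(smpl, seed);
    common_sampler_reset(smpl);          // new signature: no vocab argument
    // ... run one generation with smpl ...
}
```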
@@ -163,11 +165,11 @@ void common_sampler_clone(llama_sampling_context * src, llama_sampling_context *
     dst->smpl = llama_sampler_dry_clone(src->smpl);
 }

-llama_token llama_sampling_last(llama_sampling_context * ctx) {
+llama_token llama_sampling_last(common_sampler * ctx) {
     return ctx->prev.back();
 }

-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
+std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n) {
     const int size = ctx_sampling->prev.size();

     n = std::min(n, size);
@@ -181,7 +183,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }

-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const common_params_sampling & params) {
     char result[1024];

     snprintf(result, sizeof(result),
@@ -199,7 +201,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }

-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const common_params_sampling & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -315,8 +317,8 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
 // no reasons to expose this function in header
 static void sampler_queue(
     struct llama_context* ctx_main,
-    const llama_sampling_params& params,
-    llama_sampling_context * ctx_sampling,
+    const common_params_sampling& params,
+    common_sampler * ctx_sampling,
     llama_token_data_array& cur_p,
     size_t min_keep) {
     const float temp = params.temp;
@@ -343,6 +345,7 @@ static void sampler_queue(
             case llama_sampler_type::MIN_P      : llama_sample_min_p      (ctx_main, &cur_p, min_p, min_keep); break;
             case llama_sampler_type::XTC        : llama_sample_xtc        (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
             case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break;
+            case llama_sampler_type::DIST       : llama_sample_dist       (ctx_main, &cur_p); break;
             case llama_sampler_type::TEMPERATURE:
                 if (dynatemp_range > 0) {
                     float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
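The queue now recognizes a `DIST` sampler type that draws directly from whatever distribution the earlier stages left behind. A hedged configuration sketch, assuming `samplers_sequence` is a `std::vector<llama_sampler_type>` as the iteration in `llama_sampling_order_print` suggests:

```cpp
// Illustrative only: place DIST last so the truncation samplers run first.
common_params_sampling sparams;
sparams.samplers_sequence = {
    llama_sampler_type::TOP_K,
    llama_sampler_type::MIN_P,
    llama_sampler_type::TEMPERATURE,
    llama_sampler_type::DIST,   // new case added in this hunk
};
```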
@@ -364,12 +367,12 @@ static void sampler_queue(
 }

 static llama_token llama_sampling_sample_impl(
-    struct llama_sampling_context * ctx_sampling,
+    struct common_sampler * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
     const int idx,
     bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    const common_params_sampling & params = ctx_sampling->params;

     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -378,7 +381,8 @@ static llama_token llama_sampling_sample_impl(
     const float adaptive_target = params.adaptive_target;

     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    llama_token_data_array & cur_p = ctx_sampling->cur_p;
     if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
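This is the `cur_p` fix called out in the commit message: instead of receiving the candidate array from `llama_sampling_prepare` by value, the caller now binds a reference to the array stored in the sampler context. A sketch of the distinction, with the struct usage inferred from this diff rather than taken from the commit:

```cpp
// Old style: a shallow copy whose .data pointer aliases context-owned storage,
// so the copy's size/sorted/selected fields can drift out of sync with the
// context while both still point at the same buffer.
llama_token_data_array by_value = ctx_sampling->cur_p;
// New style: bind a reference, keeping a single source of truth.
llama_token_data_array & cur_p = ctx_sampling->cur_p;
```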
@@ -414,22 +418,9 @@ static llama_token llama_sampling_sample_impl(
             // temperature sampling
             size_t min_keep = std::max(1, params.min_keep);

-            sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep);
-
+            sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep);
             id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
-
-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
-
-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG(" - %5d: '%12s' (%.3f)\n", id, common_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
-
-            //LOG("sampled token: %5d: '%s'\n", id, common_token_to_piece(ctx_main, id).c_str());
         }
     }

@@ -457,20 +448,19 @@ static llama_token llama_sampling_sample_impl(
             return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }

     ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;

     return id;
 }

 static llama_token_data_array llama_sampling_prepare_impl(
-    struct llama_sampling_context * ctx_sampling,
+    struct common_sampler * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
     const int idx,
     bool apply_grammar,
     std::vector<float> * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    const common_params_sampling & params = ctx_sampling->params;

     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

@@ -541,8 +531,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     return cur_p;
 }

-llama_token common_sampler_sample(
-    struct llama_sampling_context * ctx_sampling,
+llama_token common_sampler_sample_legacy(
+    struct common_sampler * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
     const int idx) {
@@ -550,8 +540,17 @@ llama_token common_sampler_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }

+llama_token common_sampler_sample(
+    struct common_sampler * ctx_sampling,
+    struct llama_context * ctx_main,
+    const int idx,
+    bool grammar_first) {
+    // Call the implementation function with is_resampling set to false by default
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, nullptr, idx, /* is_resampling= */ grammar_first);
+}
+
 llama_token_data_array llama_sampling_prepare(
-    struct llama_sampling_context * ctx_sampling,
+    struct common_sampler * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
     const int idx,
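The old four-argument overload survives as `common_sampler_sample_legacy`, while the new entry point drops the CFG context and takes a `grammar_first` flag instead. A hedged usage sketch of the sample/accept pair; the decode setup and the `smpl`/`ctx` variables are assumed, not shown in the diff:

```cpp
// Illustrative decode step: sample from the logits at batch index 0, then
// accept the token so penalty and grammar state advance with the output.
const llama_token id = common_sampler_sample(smpl, ctx, /* idx= */ 0, /* grammar_first= */ false);
common_sampler_accept(smpl, ctx, id, /* apply_grammar= */ true);
```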
@@ -561,7 +560,7 @@ llama_token_data_array llama_sampling_prepare(
 }

 void common_sampler_accept(
-    struct llama_sampling_context * ctx_sampling,
+    struct common_sampler * ctx_sampling,
     struct llama_context * ctx_main,
     llama_token id,
     bool apply_grammar) {
@@ -579,11 +578,32 @@ void common_sampler_accept(
     }
 }

-llama_token_data_array * common_sampler_get_candidates(struct llama_sampling_context * ctx_sampling) {
-    return &ctx_sampling->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    auto * res = &gsmpl->cur_p;
+
+    if (do_sort && !res->sorted) {
+        // remember the selected token before sorting
+        const llama_token id = res->data[res->selected].id;
+
+        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.p > b.p;
+        });
+
+        // restore the selected token after sorting
+        for (size_t i = 0; i < res->size; ++i) {
+            if (res->data[i].id == id) {
+                res->selected = i;
+                break;
+            }
+        }
+
+        res->sorted = true;
+    }
+
+    return res;
 }

-std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
+std::vector<llama_token> llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
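`common_sampler_get_candidates` now sorts on demand while remembering which candidate was selected. A hedged inspection sketch, assuming a sampler `smpl` that has just produced a token (requires `<algorithm>` and `<cstdio>`):

```cpp
// Illustrative: print the top candidates by probability after a sampling step.
llama_token_data_array * cand = common_sampler_get_candidates(smpl, /* do_sort= */ true);
const size_t n_top = std::min<size_t>(10, cand->size);
for (size_t i = 0; i < n_top; ++i) {
    printf("%2zu: token %6d  p = %.4f\n", i, (int) cand->data[i].id, cand->data[i].p);
}
```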
@@ -592,7 +612,7 @@ std::vector<llama_token> llama_sampling_sample_and_accept_n(struct llama_samplin
     return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
 }

-std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

     std::vector<llama_token> result;
@@ -600,7 +620,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_samplin

     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, nullptr, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

         common_sampler_accept(gsmpl, ctx, id, true);

@@ -612,7 +632,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct llama_samplin
     }

     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, nullptr, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

         common_sampler_accept(gsmpl, ctx, id, true);

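These helpers implement the verification step of speculative decoding: sample once per draft position, accept while the sampled tokens agree with the draft, and stop at the first mismatch. A hedged sketch of the calling side; how `draft` is produced (by the new ngram modules or the self-draft path) and the `gsmpl`/`ctx` setup are assumed:

```cpp
// Illustrative verification call: ctx must already hold logits for every
// draft position, and draft holds the tokens proposed by the drafting stage.
std::vector<llama_token> draft = { /* proposed tokens */ };
std::vector<llama_token> ids = llama_sampling_sample_and_accept_n(gsmpl, ctx, draft);
// ids contains the accepted prefix plus one token sampled by the target
// model, so ids.size() - 1 is the number of draft tokens that were accepted.
```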