diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 789154e8..2c7e65f6 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -74,6 +74,7 @@ add_library(${TARGET} STATIC train.cpp ngram-cache.h ngram-cache.cpp + speculative.cpp ) if (BUILD_SHARED_LIBS) diff --git a/common/sampling.cpp b/common/sampling.cpp index 7d460b57..9c5580e8 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -442,7 +442,9 @@ static llama_token_data_array llama_sampling_prepare_impl( cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + ctx_sampling->cur_p = { cur.data(), cur.size(), false }; + + llama_token_data_array & cur_p = ctx_sampling->cur_p; // apply penalties const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; @@ -507,6 +509,10 @@ void llama_sampling_accept( } } +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling) { + return &ctx_sampling->cur_p; +} + std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft) { std::vector result; diff --git a/common/sampling.h b/common/sampling.h index 405f5a63..d209a59f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -101,6 +101,8 @@ struct llama_sampling_context { size_t n_valid; // Number of correct top tokens with correct probabilities. + llama_token_data_array cur_p; // current candidates + std::mt19937 rng; }; @@ -178,5 +180,8 @@ void llama_sampling_accept( bool apply_grammar); // returns at least 1 token, up to draft.size() +// access the internal list of current candidate tokens +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling_context * ctx_sampling); + std::vector llama_sampling_sample_and_accept_n(struct llama_sampling_context * gsmpl, struct llama_context * ctx, const std::vector & draft); diff --git a/common/speculative.cpp b/common/speculative.cpp index aa7592b5..ae326be4 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,8 +1,8 @@ #include "speculative.h" -#include "log.h" #include "common.h" #include "sampling.h" +#include "llama-impl.h" #include #include @@ -10,17 +10,17 @@ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -struct common_speculative { +struct llama_speculative { struct llama_context * ctx; - struct common_sampler * smpl; + struct llama_sampling_context * smpl; llama_batch batch; std::vector prompt; }; -struct common_speculative * common_speculative_init( +struct llama_speculative * llama_speculative_init( struct llama_context * ctx_dft) { - auto * result = new common_speculative { + auto * result = new llama_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), @@ -30,7 +30,7 @@ struct common_speculative * common_speculative_init( // TODO: optimize or pass from outside? #if 0 { - common_params_sampling params; + llama_sampling_params params; params.no_perf = false; params.top_k = 40; @@ -42,90 +42,87 @@ struct common_speculative * common_speculative_init( COMMON_SAMPLER_TYPE_INFILL, }; - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + result->smpl = llama_sampler_init(llama_get_model(ctx_dft), params); } #else { - common_params_sampling params; - params.no_perf = false; - + llama_sampling_params params; params.top_k = 10; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, + params.samplers_sequence = { + llama_sampler_type::TOP_K, }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + const auto *model_dft = llama_get_model(ctx_dft); + result->smpl = llama_sampling_init(llama_get_model_vocab(model_dft), params); } #endif return result; } -void common_speculative_free(struct common_speculative * spec) { +void llama_speculative_free(struct llama_speculative * spec) { if (spec == nullptr) { return; } - common_sampler_free(spec->smpl); + llama_sampling_free(spec->smpl); llama_batch_free(spec->batch); delete spec; } -bool common_speculative_are_compatible( +bool llama_speculative_are_compatible( const struct llama_context * ctx_tgt, const struct llama_context * ctx_dft) { const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_dft = llama_get_model(ctx_dft); - const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + const struct llama_vocab * vocab_tgt = llama_get_model_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_get_model_vocab(model_dft); - const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); - LLAMA_LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + const bool vocab_type_tgt = llama_vocab_type(model_tgt); + LLAMA_LOG_INFO("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); - const bool vocab_type_dft = llama_vocab_type(vocab_dft); - LLAMA_LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + const bool vocab_type_dft = llama_vocab_type(model_dft); + LLAMA_LOG_INFO("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - LLAMA_LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + LLAMA_LOG_ERROR("%s: draft model vocab type must match target model to use speculation but " "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); return false; } - if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || - llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || - llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || - llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { - LLAMA_LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); - LLAMA_LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); - LLAMA_LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || + llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || + llama_token_bos(model_tgt) != llama_token_bos(model_dft) || + llama_token_eos(model_tgt) != llama_token_eos(model_dft)) { + LLAMA_LOG_ERROR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LLAMA_LOG_ERROR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt)); + LLAMA_LOG_ERROR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft)); return false; } { - const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); - const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + const int n_vocab_tgt = llama_n_vocab(model_tgt); + const int n_vocab_dft = llama_n_vocab(model_dft); - const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + const int model_diff = std::abs(n_vocab_tgt - n_vocab_dft); - if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - LLAMA_LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + if (model_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LLAMA_LOG_ERROR("%s: draft model vocab must closely match target model to use speculation but " "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + __func__, n_vocab_tgt, n_vocab_dft, model_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); - const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + const char * token_text_tgt = llama_token_get_text(model_tgt, i); + const char * token_text_dft = llama_token_get_text(model_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - LLAMA_LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " + LLAMA_LOG_ERROR("%s: draft vocab vocab must match target vocab to use speculation but " "token %d content differs - target '%s', draft '%s'\n", __func__, i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); + llama_token_to_piece(ctx_tgt, i).c_str(), + llama_token_to_piece(ctx_dft, i).c_str()); return false; } } @@ -134,18 +131,16 @@ bool common_speculative_are_compatible( return true; } -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt_tgt, +std::vector llama_speculative_gen_draft( + struct llama_speculative * spec, + struct llama_speculative_params params, + const std::vector & prompt_tgt, llama_token id_last) { auto & batch = spec->batch; auto & ctx = spec->ctx; auto & smpl = spec->smpl; auto & prompt = spec->prompt; - auto * mem = llama_get_memory(ctx); - int reuse_i = 0; int reuse_n = 0; @@ -169,13 +164,13 @@ llama_tokens common_speculative_gen_draft( } } - LLAMA_LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + LLAMA_LOG_INFO("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); - llama_tokens result; + std::vector result; result.reserve(params.n_draft); if (reuse_n == 0) { - llama_memory_clear(mem, false); + llama_kv_cache_clear(ctx, false); prompt.clear(); } else { @@ -194,68 +189,68 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - llama_memory_seq_rm (mem, 0, 0, reuse_i); - llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i); } if (reuse_n < (int) prompt.size()) { - llama_memory_seq_rm (mem, 0, reuse_n, -1); + llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); prompt.erase(prompt.begin() + reuse_n, prompt.end()); } } // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); + llama_batch_clear(batch); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { - //LLAMA_LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + //LLAMA_LOG_INFO("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + llama_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); prompt.push_back(prompt_tgt[i]); } // we should rarely end-up here during normal decoding if (batch.n_tokens > 0) { - //LLAMA_LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + //LLAMA_LOG_INFO("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); } const llama_pos n_past = prompt.size(); - LLAMA_LOG_DBG("%s: n_past = %d\n", __func__, n_past); + LLAMA_LOG_INFO("%s: n_past = %d\n", __func__, n_past); - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); + llama_batch_clear(batch); + llama_batch_add (batch, id_last, n_past, { 0 }, true); prompt.push_back(id_last); - //LLAMA_LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + //LLAMA_LOG_INFO("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); llama_decode(ctx, batch); - common_sampler_reset(smpl); + llama_sampling_reset(smpl); // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { - common_batch_clear(batch); + llama_batch_clear(batch); - common_sampler_sample(smpl, ctx, 0, true); + llama_sampling_sample(smpl, ctx, 0, true); - const auto * cur_p = common_sampler_get_candidates(smpl); + const auto * cur_p = llama_sampling_get_candidates(smpl); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LLAMA_LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + LLAMA_LOG_INFO(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx, cur_p->data[k].id).c_str()); } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; - common_sampler_accept(smpl, id, true); + llama_sampling_accept(smpl, ctx, id, true); result.push_back(id); @@ -268,7 +263,7 @@ llama_tokens common_speculative_gen_draft( break; } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + llama_batch_add(batch, id, n_past + i + 1, { 0 }, true); // evaluate the drafted tokens on the draft model llama_decode(ctx, batch); diff --git a/common/speculative.h b/common/speculative.h index 75f2e311..faa6ee54 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -1,28 +1,29 @@ #pragma once #include "llama.h" -#include "common.h" -struct common_speculative; +#include -struct common_speculative_params { +struct llama_speculative; + +struct llama_speculative_params { int n_draft = 16; // max drafted tokens int n_reuse = 256; float p_min = 0.75f; // min probability required to accept a token in the draft }; -struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); +struct llama_speculative * llama_speculative_init(struct llama_context * ctx_dft); -void common_speculative_free(struct common_speculative * spec); +void llama_speculative_free(struct llama_speculative * spec); -bool common_speculative_are_compatible( +bool llama_speculative_are_compatible( const struct llama_context * ctx_tgt, const struct llama_context * ctx_dft); // sample up to n_draft tokens and add them to the batch using the draft model -std::vector common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, +std::vector llama_speculative_gen_draft( + struct llama_speculative * spec, + struct llama_speculative_params params, const std::vector & prompt, llama_token id_last);