mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-20 05:04:11 +00:00
spec : add self speculative decoding, ngram and refactor (#1261)
* spec : add self speculative decoding and ngram-mod and refactor common : use common_ prefix for common library function llama : use LLAMA_TOKEN_NULL spec : add self speculative decoding (no draft model required) + refactor spec : add ngram-mod spec : various improvements ton ngram-map + docs spec : fix the check-rate logic of ngram-simple common : add common_speculative_is_compat() spec : simplify time measurement using common_time_meas refactor common_sampler_init refactor common_token_to_piece refactor and fix cur_p bug clean up * spec : remove check rate * spec: show warnings instead of abort --------- Co-authored-by: firecoperana <firecoperana> Co-authored-by: Sascha Rogmann <59577610+srogmann@users.noreply.github.com>
This commit is contained in:
@@ -72,7 +72,7 @@ struct client {
|
||||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = nullptr;
|
||||
struct common_sampler * ctx_sampling = nullptr;
|
||||
};
|
||||
|
||||
static void print_date_time() {
|
||||
@@ -161,11 +161,11 @@ int main(int argc, char ** argv) {
|
||||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.ctx_sampling = common_sampler_init(llama_get_model_vocab(model), params.sparams);
|
||||
client.ctx_sampling = common_sampler_init(model, params.sparams);
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
tokens_system = ::llama_tokenize(ctx, k_system, true);
|
||||
tokens_system = ::common_tokenize(ctx, k_system, true);
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
@@ -253,11 +253,11 @@ int main(int argc, char ** argv) {
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
common_sampler_reset(llama_get_model_vocab(model), client.ctx_sampling);
|
||||
common_sampler_reset(client.ctx_sampling);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
|
||||
tokens_prompt = ::common_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
|
||||
@@ -341,7 +341,7 @@ int main(int argc, char ** argv) {
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||
|
||||
const llama_token id = common_sampler_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||
const llama_token id = common_sampler_sample(client.ctx_sampling, ctx, client.i_batch - i);
|
||||
|
||||
common_sampler_accept(client.ctx_sampling, ctx, id, true);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user