Server: rename functions and refactor code

- rename functions
- refactor update slots
- rename params_base
- rename timings
Author: firecoperana
Date:   2026-01-13 12:02:58 -06:00
Parent: cb1063f6cd
Commit: b43b22b68a
39 changed files with 609 additions and 595 deletions
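All of the hunks below appear to come from a single example program: the ngram_data / tokens_j structures and the W + G + 1 sequence layout match the lookahead-decoding example (examples/lookahead/lookahead.cpp in llama.cpp-based trees). The changes are mechanical renames with no behavior change. The old-to-new mapping, as inferred from the hunks:

    llama_token_to_piece                          -> common_token_to_piece
    llama_sampling_init/_sample/_accept/_free     -> common_sampler_init/_sample/_accept/_free
    llama_batch_clear / llama_batch_add           -> common_batch_clear / common_batch_add
    llama_kv_cache_seq_cp / llama_kv_cache_seq_rm -> llama_memory_seq_cp / llama_memory_seq_rm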


@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "\n\n");
 
     for (auto id : inp) {
-        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+        fprintf(stderr, "%s", common_token_to_piece(ctx, id).c_str());
     }
 
     fflush(stderr);
@@ -96,7 +96,7 @@ int main(int argc, char ** argv) {
     llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
 
     for (int s = 1; s < W + G + 1; ++s) {
-        llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+        llama_memory_seq_cp(ctx, 0, s, -1, -1);
     }
 
     const auto t_enc_end = ggml_time_us();
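This hunk is the prompt-sharing step: the prompt is decoded once into sequence 0, and its KV cells are then copied into the W + G extra sequences used for the lookahead levels and verification n-grams, so no branch has to re-decode the prefix. A minimal sketch of the renamed call, assuming it keeps the old (ctx, seq_src, seq_dst, p0, p1) argument order shown here:

    // Share sequence 0's KV cells with sequences 1 .. W+G so every
    // lookahead/verification branch starts from the same decoded prefix.
    // p0 = -1 and p1 = -1 select the whole position range.
    for (int s = 1; s < W + G + 1; ++s) {
        llama_memory_seq_cp(ctx, /*seq_src=*/0, /*seq_dst=*/s, /*p0=*/-1, /*p1=*/-1);
    }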
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
+    struct llama_sampling_context * ctx_sampling = common_sampler_init(llama_get_model_vocab(model), params.sparams);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
@@ -159,12 +159,12 @@ int main(int argc, char ** argv) {
 
         // sample first token
         {
-            id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+            id = common_sampler_sample(ctx_sampling, ctx, NULL, 0);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            common_sampler_accept(ctx_sampling, ctx, id, true);
 
             {
-                const std::string token_str = llama_token_to_piece(ctx, id);
+                const std::string token_str = common_token_to_piece(ctx, id);
 
                 printf("%s", token_str.c_str());
                 fflush(stdout);
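Only the prefix changes in these sampler calls; this fork keeps the four-argument signatures (upstream llama.cpp's common_sampler_* helpers take different arguments). A minimal sketch of the renamed lifecycle, using the signatures exactly as they appear in this diff:

    // Renamed sampler lifecycle (argument lists as shown in this diff).
    struct llama_sampling_context * ctx_sampling =
        common_sampler_init(llama_get_model_vocab(model), params.sparams);

    llama_token id = common_sampler_sample(ctx_sampling, ctx, NULL, 0); // sample from logits at batch index 0
    common_sampler_accept(ctx_sampling, ctx, id, true);                 // update sampler (and grammar) state

    printf("%s", common_token_to_piece(ctx, id).c_str());
    // ... decode/sample loop ...

    common_sampler_free(ctx_sampling);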
@@ -204,10 +204,10 @@ int main(int argc, char ** argv) {
         //  V  V  V  V  V  V
         //     id
         {
-            llama_batch_clear(batch);
+            common_batch_clear(batch);
 
             // current token - first token of the first level
-            llama_batch_add(batch, id, n_past, seq_id_all, true);
+            common_batch_add(batch, id, n_past, seq_id_all, true);
 
             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
             {
@@ -232,7 +232,7 @@ int main(int argc, char ** argv) {
                         ngrams_cur[g].tokens [j + 1] = t;
                         ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
 
-                        llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
                     }
                 }
             }
@@ -244,13 +244,13 @@ int main(int argc, char ** argv) {
                     seq_id_look[j] = i + j + 1;
                 }
 
-                llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
             }
 
             // fill the rest of the levels
             for (int j = 1; j < N - 1; j++) {
                 for (int i = 0; i < W; i++) {
-                    llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
                 }
             }
         }
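common_batch_clear and common_batch_add are thin helpers over llama_batch; the rename does not change what they do. For orientation, this is roughly their implementation in upstream llama.cpp's common library (a sketch, not necessarily this fork's exact code):

    // Sketch of the batch helpers, modeled on upstream llama.cpp's common library.
    void common_batch_clear(struct llama_batch & batch) {
        batch.n_tokens = 0;
    }

    void common_batch_add(struct llama_batch & batch, llama_token id, llama_pos pos,
                          const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = seq_ids.size();
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            batch.seq_id[batch.n_tokens][i] = seq_ids[i];
        }
        batch.logits  [batch.n_tokens] = logits;

        batch.n_tokens++;
    }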
@@ -284,13 +284,13 @@ int main(int argc, char ** argv) {
                 }
 
                 // sample the next token
-                id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+                id = common_sampler_sample(ctx_sampling, ctx, NULL, i_batch);
 
-                llama_sampling_accept(ctx_sampling, ctx, id, true);
+                common_sampler_accept(ctx_sampling, ctx, id, true);
 
                 // print
                 {
-                    const std::string token_str = llama_token_to_piece(ctx, id);
+                    const std::string token_str = common_token_to_piece(ctx, id);
 
                     if (v == 0) {
                         printf("%s", token_str.c_str());
@@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
             // print known n-grams starting with token id (debug)
             if (0 && v == 0) {
                 if (ngrams_observed.cnt[id] > 0) {
-                    printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
+                    printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str());
                 }
 
                 for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
@@ -339,7 +339,7 @@ int main(int argc, char ** argv) {
                     const int idx = id*(N - 1)*G + i*(N - 1);
 
                     for (int j = 0; j < N - 1; j++) {
-                        const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
+                        const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
                         printf("%s", token_str.c_str());
                     }
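The index expression above implies that ngrams_observed stores its token pool as one flat array: up to G continuations of length N - 1 per first token. A hedged reading of that layout, inferred only from the indexing in this hunk:

    // Assumed layout of ngrams_observed.tokens (from idx = id*(N-1)*G + i*(N-1) + j):
    //   id : first token of the observed n-gram
    //   i  : which of the up-to-G stored continuations (cnt[id] of them are valid)
    //   j  : position within the (N-1)-token continuation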
@@ -361,7 +361,7 @@ int main(int argc, char ** argv) {
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                    tokens_j[N - 2][i] = common_sampler_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                 }
             } else {
                 for (int i = 0; i < W; i++) {
@@ -438,17 +438,17 @@ int main(int argc, char ** argv) {
 
             // KV cache management
             // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
-            llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
+            llama_memory_seq_rm(ctx, -1, n_past, -1);
 
             if (seq_id_best != 0) {
                 // if a verification token matched, we keep the best sequence and remove the rest
                 // this leads to some KV cache fragmentation
                 llama_kv_cache_seq_keep(ctx, seq_id_best);
-                llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
-                llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
+                llama_memory_seq_cp (ctx, seq_id_best, 0, -1, -1);
+                llama_memory_seq_rm (ctx, seq_id_best, -1, -1);
 
                 for (int s = 1; s < W + G + 1; ++s) {
-                    llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+                    llama_memory_seq_cp(ctx, 0, s, -1, -1);
                 }
             }
         }
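The KV-cache cleanup logic is unchanged by the rename: all speculative cells past n_past are dropped, and if a verification branch won, it is grafted back onto sequence 0 before the shared prefix is re-broadcast. Condensed, with the renamed calls (note that llama_kv_cache_seq_keep keeps its old name in this hunk):

    // Condensed sketch of the post-verification KV cleanup.
    llama_memory_seq_rm(ctx, -1, n_past, -1);               // all seqs: drop cells with pos >= n_past

    if (seq_id_best != 0) {
        llama_kv_cache_seq_keep(ctx, seq_id_best);          // keep only the winning sequence
        llama_memory_seq_cp (ctx, seq_id_best, 0, -1, -1);  // graft it onto sequence 0
        llama_memory_seq_rm (ctx, seq_id_best, -1, -1);     // drop the duplicate copy

        for (int s = 1; s < W + G + 1; ++s) {
            llama_memory_seq_cp(ctx, 0, s, -1, -1);         // re-share the prefix
        }
    }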
@@ -471,7 +471,7 @@ int main(int argc, char ** argv) {
 
     llama_print_timings(ctx);
 
     llama_kv_cache_view_free(&kvc_view);
-    llama_sampling_free(ctx_sampling);
+    common_sampler_free(ctx_sampling);
 
     llama_batch_free(batch);