sweep_bench: set number of repetitions (#1176)

Kawrakow
2026-01-22 12:28:30 +02:00
committed by GitHub
parent 101fe54797
commit 573e23679d
3 changed files with 77 additions and 54 deletions
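
The core of the change: each timed section of the sweep is now run nrep times back to back, and the elapsed time is divided by nrep, so the reported prompt-processing (t_pp) and token-generation (t_tg) figures are per-repetition averages. As written, params.nrep only applies to the first sweep step (i_loop < 1); all later context depths run a single repetition. Below is a self-contained sketch of this repeat-and-average timing pattern, with a generic stand-in workload rather than the commit's llama.cpp calls:

    // Sketch of the repeat-and-average pattern the diff applies; the workload
    // here is a generic stand-in, not the commit's code.
    #include <chrono>
    #include <cstdio>

    template <typename F>
    static double time_avg_seconds(int nrep, F && work) {
        const auto t0 = std::chrono::steady_clock::now();
        for (int irep = 0; irep < nrep; ++irep) {
            work(); // each repetition must reset its own state (sweep-bench clears the KV cache)
        }
        const auto t1 = std::chrono::steady_clock::now();
        // total elapsed time divided by the repetition count, as the diff does with ggml_time_us()
        return std::chrono::duration<double>(t1 - t0).count() / nrep;
    }

    int main() {
        long sink = 0;
        const double t = time_avg_seconds(5, [&] {
            for (long i = 0; i < 10000000; ++i) sink += i; // stand-in workload
        });
        std::printf("avg: %.6f s/rep (sink=%ld)\n", t, sink);
        return 0;
    }

Dividing the total elapsed time by the repetition count smooths out run-to-run noise without having to store per-repetition samples.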


@@ -31,6 +31,7 @@ int main(int argc, char ** argv) {
         print_usage(argc, argv);
         return 1;
     }
+    if (params.nrep < 1) params.nrep = 1;
 
     // init LLM
@@ -135,49 +136,63 @@ int main(int argc, char ** argv) {
     common_batch_clear(batch);
     llama_kv_cache_clear(ctx);
 
+    int i_loop = 0;
     for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += params.n_ubatch) {
         // clean up KV cache before generation
-        llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
+        //llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
+
+        int nrep = i_loop < 1 ? params.nrep : 1;
 
         // first measure token generation performance at this context size
         const auto t_tg_start = ggml_time_us();
 
-        for (unsigned int i = 0; i < tg; ++i) {
-            common_batch_clear(batch);
-            common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true);
-
-            if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                LOG_TEE("%s: llama_decode() failed\n", __func__);
-                return 1;
+        for (int irep = 0; irep < nrep; ++irep) {
+            llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
+            for (unsigned int i = 0; i < tg; ++i) {
+                common_batch_clear(batch);
+                common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true);
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    LOG_TEE("%s: llama_decode() failed\n", __func__);
+                    return 1;
+                }
             }
         }
 
         const auto t_tg_end = ggml_time_us();
 
-        // clean up KV cache after generation
-        llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
-
-        // prepare batch of pp size for prompt processing performance measurement
-        common_batch_clear(batch);
-
-        for (unsigned int i = 0; i < pp; ++i) {
-            common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false);
-        }
-        batch.logits[batch.n_tokens - 1] = true;
-
         // measure prompt processing performance
         const auto t_pp_start = ggml_time_us();
 
-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_TEE("%s: llama_decode() failed\n", __func__);
-            return 1;
+        for (int irep = 0; irep < nrep; ++irep) {
+            // clean up KV cache after generation
+            llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
+
+            // prepare batch of pp size for prompt processing performance measurement
+            common_batch_clear(batch);
+
+            for (unsigned int i = 0; i < pp; ++i) {
+                common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true;
+
+            if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                LOG_TEE("%s: llama_decode() failed\n", __func__);
+                return 1;
+            }
         }
 
         const auto t_pp_end = ggml_time_us();
 
         // calculate and print metrics
-        const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
-        const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
+        const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f / nrep;
+        const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f / nrep;
 
         const float speed_pp = pp / t_pp;
         const float speed_tg = tg / t_tg;
@@ -192,6 +207,8 @@ int main(int argc, char ** argv) {
         } else {
             LOG_TEE("|%6d | %6d | %6d | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, n_kv, t_pp, speed_pp, t_tg, speed_tg);
         }
+
+        ++i_loop;
     }
 
     llama_batch_free(batch);
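
The diff above covers only the benchmark loop; the commit's other changed files (not shown on this page) are where nrep would be added to the shared params struct and filled in from the command line. A minimal sketch of that wiring, in which the --nrep flag name and the struct are assumptions for illustration, not the commit's actual code:

    // Hypothetical wiring sketch; the commit's real parser lives in files not shown here.
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    struct params_sketch {
        int nrep = 1; // number of repetitions per measured point
    };

    static void parse_args_sketch(int argc, char ** argv, params_sketch & params) {
        for (int i = 1; i + 1 < argc; ++i) {
            if (std::strcmp(argv[i], "--nrep") == 0) {
                params.nrep = std::atoi(argv[++i]);
            }
        }
        if (params.nrep < 1) params.nrep = 1; // same clamp the commit adds in main()
    }

    int main(int argc, char ** argv) {
        params_sketch params;
        parse_args_sketch(argc, argv, params);
        std::printf("nrep = %d\n", params.nrep);
        return 0;
    }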