Made it compile with ik_llama

Author: Saood Karim
Date: 2025-02-22 21:50:50 -06:00
parent e7c8b0df6c
commit dfaf65109a
2 changed files with 41 additions and 31 deletions


@@ -269,6 +269,8 @@ struct gpt_params {
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
+    bool sweep_bench_output_jsonl = false;
 };
 void gpt_params_handle_hf_token(gpt_params & params);
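
Only one change lands in the shared header: a new sweep_bench_output_jsonl member on gpt_params. The sweep-bench example in the second file of this commit reads it to choose between the human-readable table and JSONL output. A minimal sketch of that pattern, using only the member name from this diff (the command-line handling that would set the flag is not part of this commit):

    #include "common.h"  // gpt_params (ik_llama)
    #include <cstdio>

    // Illustrative only: assumes the flag was already set during argument parsing.
    static void print_speeds(const gpt_params & params, float speed_pp, float speed_tg) {
        if (params.sweep_bench_output_jsonl) {
            // one JSON object per measurement, easy to post-process
            printf("{\"speed_pp\": %f, \"speed_tg\": %f}\n", speed_pp, speed_tg);
        } else {
            // plain table-style row for reading in a terminal
            printf("pp: %f t/s | tg: %f t/s\n", speed_pp, speed_tg);
        }
    }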


@@ -1,7 +1,15 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "ggml.h"
#include "llama.h"
#include "common.h"
#include "llama-vocab.h"
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include <algorithm>
#include <cstdlib>
@@ -16,14 +24,14 @@ static void print_usage(int, char ** argv) {
 }
 int main(int argc, char ** argv) {
-    common_params params;
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) {
+    gpt_params params;
+    if (!gpt_params_parse(argc, argv, params)) {
+        print_usage(argc, argv);
         return 1;
     }
-    common_init();
     // init LLM
     llama_backend_init();
@@ -31,18 +39,18 @@ int main(int argc, char ** argv) {
     // initialize the model
-    llama_model_params model_params = common_model_params_to_llama(params);
+    llama_model_params model_params = llama_model_params_from_gpt_params(params);
-    llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
-    llama_context_params ctx_params = common_context_params_to_llama(params);
+    llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
-    llama_context * ctx = llama_init_from_model(model, ctx_params);
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -50,10 +58,13 @@ int main(int argc, char ** argv) {
     }
     const unsigned int n_kv_max = llama_n_ctx(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-    const unsigned int n_vocab = llama_vocab_n_tokens(vocab);
-    const llama_token bos = llama_vocab_bos(vocab);
-    const llama_token eos = llama_vocab_eos(vocab);
+    const llama_vocab * vocab = llama_get_vocab(ctx);
+    llama_token bos = llama_token_bos_impl(*vocab);
+    //llama_token eos = llama_token_eos_impl(*vocab);
+    const unsigned int n_vocab = llama_n_vocab(model);
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -72,7 +83,7 @@ int main(int argc, char ** argv) {
             const int ret = llama_decode(ctx, batch_view);
             if (ret != 0) {
-                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                LOG("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                 return false;
             }
@@ -85,7 +96,7 @@ int main(int argc, char ** argv) {
     const unsigned int pp = params.n_ubatch;
     const unsigned int tg = params.n_ubatch / 4;
-    if (!params.batched_bench_output_jsonl) {
+    if (!params.sweep_bench_output_jsonl) {
         LOG("\n");
         LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
@@ -97,16 +108,15 @@ int main(int argc, char ** argv) {
     // warm up
     {
-        common_batch_add(batch, bos, 0, { 0 }, false);
-        common_batch_add(batch, eos, 1, { 0 }, false);
+        llama_batch_add(batch, bos, 0, { 0 }, false);
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            LOG("%s: llama_decode() failed\n", __func__);
             return 1;
         }
     }
-    common_batch_clear(batch);
+    llama_batch_clear(batch);
     llama_kv_cache_clear(ctx);
     for (unsigned int n_kv = 0; n_kv < n_kv_max; n_kv += params.n_ubatch) {
@@ -117,11 +127,11 @@ int main(int argc, char ** argv) {
         const auto t_tg_start = ggml_time_us();
         for (unsigned int i = 0; i < tg; ++i) {
-            common_batch_clear(batch);
-            common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true);
+            llama_batch_clear(batch);
+            llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, true);
             if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                LOG_ERR("%s: llama_decode() failed\n", __func__);
+                LOG("%s: llama_decode() failed\n", __func__);
                 return 1;
             }
         }
@@ -132,10 +142,10 @@ int main(int argc, char ** argv) {
         llama_kv_cache_seq_rm(ctx, 0, n_kv, -1);
         // prepare batch of pp size for prompt processing performance measurement
-        common_batch_clear(batch);
+        llama_batch_clear(batch);
         for (unsigned int i = 0; i < pp; ++i) {
-            common_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false);
+            llama_batch_add(batch, std::rand() % n_vocab, n_kv + i, { 0 }, false);
         }
         batch.logits[batch.n_tokens - 1] = true;
@@ -143,7 +153,7 @@ int main(int argc, char ** argv) {
         const auto t_pp_start = ggml_time_us();
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+            LOG("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -156,7 +166,7 @@ int main(int argc, char ** argv) {
         const float speed_pp = pp / t_pp;
         const float speed_tg = tg / t_tg;
-        if(params.batched_bench_output_jsonl) {
+        if(params.sweep_bench_output_jsonl) {
             LOG(
                 "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, "
                 "\"pp\": %d, \"tg\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f }\n",
@@ -168,12 +178,10 @@ int main(int argc, char ** argv) {
         }
     }
-    llama_perf_context_print(ctx);
     llama_batch_free(batch);
     llama_free(ctx);
-    llama_model_free(model);
+    llama_free_model(model);
     llama_backend_free();
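
Taken together, the edits map the newer upstream llama.cpp helper names onto the older-style API that ik_llama still exposes: common_params becomes gpt_params, common_batch_clear/common_batch_add become llama_batch_clear/llama_batch_add, llama_init_from_model becomes llama_new_context_with_model, LOG_ERR falls back to LOG, and llama_model_free becomes llama_free_model. A minimal sketch of the resulting init/teardown sequence, using only function names that appear in this diff (error handling and the benchmark loop trimmed; illustrative, not the exact file contents):

    #include "llama.h"
    #include "common.h"

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        llama_backend_init();

        // model and context creation via the gpt_params conversion helpers
        llama_model_params model_params = llama_model_params_from_gpt_params(params);
        llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
        llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
        llama_context * ctx = llama_new_context_with_model(model, ctx_params);

        // ... sweep-bench measurement loop would go here ...

        // teardown in the same order as the end of the diff
        llama_free(ctx);
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }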