# Patch overview (free text before the first "diff --git" header is skipped by git apply / patch).
#
# - common/chat.{h,cpp}: common_chat_format_example() takes a third argument,
#   chat_template_kwargs, and copies it into inputs.chat_template_kwargs.
# - common/common.{h,cpp}: adds enum llama_example and struct model_paths;
#   gpt_params::mmproj becomes a model_paths (plus mmproj_use_gpu / no_mmproj);
#   "--mmproj" now fills mmproj.path and the new "--mmproj-url" fills mmproj.url.
# - examples/main, examples/server: existing callers pass {} for the new argument.
# - examples/mtmd/mtmd-cli.cpp: adds a "compat" section mapping the common_* names onto the
#   llama_* / gpt_params API, and adjusts the call sites to match.
# - include/llama.h: declares llama_vocab_is_eog().
#
# NOTE(review): this copy was rebuilt from a whitespace-collapsed dump. Line breaks follow the
# hunk line counts; indentation and the template arguments (std::map<std::string, std::string>,
# std::vector<...>) were restored by convention -- TODO confirm against the original commit,
# or apply with `git apply --ignore-whitespace`.
# NOTE(review): common_init() in the compat section is the only helper there without `inline`;
# consider `static` so it cannot clash with another definition at link time.
#
diff --git a/common/chat.cpp b/common/chat.cpp
index 3358ac05..f384bfa7 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -489,11 +489,12 @@ std::string common_chat_format_single(
     return ss.str();
 }
 
-std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja, const std::map<std::string, std::string> & chat_template_kwargs) {
     common_chat_templates_inputs inputs;
     inputs.use_jinja = use_jinja;
     inputs.add_bos = tmpls->add_bos;
     inputs.add_eos = tmpls->add_eos;
+    inputs.chat_template_kwargs = chat_template_kwargs;
     auto add_simple_msg = [&](auto role, auto content) {
         common_chat_msg msg;
         msg.role = role;
diff --git a/common/chat.h b/common/chat.h
index 55180e31..ef6d53c4 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -188,7 +188,8 @@ std::string common_chat_format_single(
 // Returns an example of formatted chat
 std::string common_chat_format_example(
     const struct common_chat_templates * tmpls,
-    bool use_jinja);
+    bool use_jinja,
+    const std::map<std::string, std::string> & chat_template_kwargs);
 
 const char* common_chat_format_name(common_chat_format format);
 const char* common_reasoning_format_name(common_reasoning_format format);
diff --git a/common/common.cpp b/common/common.cpp
index 024228a9..6027b192 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -899,7 +899,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--mmproj") {
         CHECK_ARG
-        params.mmproj = argv[i];
+        params.mmproj.path = argv[i];
+        return true;
+    }
+    if (arg == "--mmproj-url") {
+        CHECK_ARG
+        params.mmproj.url = argv[i];
         return true;
     }
     if (arg == "--image") {
diff --git a/common/common.h b/common/common.h
index 639771c7..2b4d1540 100644
--- a/common/common.h
+++ b/common/common.h
@@ -68,6 +68,29 @@ struct llama_control_vector_load_info;
 int32_t cpu_get_num_physical_cores();
 int32_t cpu_get_num_math();
 
+enum llama_example {
+    LLAMA_EXAMPLE_COMMON,
+    LLAMA_EXAMPLE_SPECULATIVE,
+    LLAMA_EXAMPLE_MAIN,
+    LLAMA_EXAMPLE_EMBEDDING,
+    LLAMA_EXAMPLE_PERPLEXITY,
+    LLAMA_EXAMPLE_RETRIEVAL,
+    LLAMA_EXAMPLE_PASSKEY,
+    LLAMA_EXAMPLE_IMATRIX,
+    LLAMA_EXAMPLE_BENCH,
+    LLAMA_EXAMPLE_SERVER,
+    LLAMA_EXAMPLE_CVECTOR_GENERATOR,
+    LLAMA_EXAMPLE_EXPORT_LORA,
+    LLAMA_EXAMPLE_MTMD,
+    LLAMA_EXAMPLE_LOOKUP,
+    LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_DIFFUSION,
+    LLAMA_EXAMPLE_FINETUNE,
+
+    LLAMA_EXAMPLE_COUNT,
+};
+
 //
 // CLI argument parsing
 //
@@ -86,6 +109,14 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+struct model_paths {
+    std::string path = ""; // model local path // NOLINT
+    std::string url = ""; // model url to download // NOLINT
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+    std::string docker_repo = ""; // Docker repo // NOLINT
+};
+
 struct gpt_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 
@@ -230,8 +261,10 @@ struct gpt_params {
     std::string cache_type_k_draft = ""; // KV cache data type for K for the draft model
     std::string cache_type_v_draft = ""; // KV cache data type for V for the draft model
 
-    // multimodal models (see examples/llava)
-    std::string mmproj = ""; // path to multimodal projector
+    // multimodal models (see examples/mtmd)
+    model_paths mmproj;
+    bool mmproj_use_gpu = true; // use GPU for multimodal model
+    bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
     // embedding
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 850f33d5..d8f295f2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
     if (params.conversation) {
         if (params.enable_chat_template) {
             //LOG_TEE("%s: chat template example: %s\n", __func__, common_chat_format_example(model, *chat_templates.template_default, params.use_jinja).c_str());
-            LOG_TEE("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
+            LOG_TEE("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja, {}).c_str());
         } else {
             LOG_TEE("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
diff --git a/examples/mtmd/mtmd-cli.cpp b/examples/mtmd/mtmd-cli.cpp
index 5fde6ca0..8ff8d7ab 100644
--- a/examples/mtmd/mtmd-cli.cpp
+++ b/examples/mtmd/mtmd-cli.cpp
@@ -1,4 +1,4 @@
-#include "arg.h"
+//#include "arg.h"
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
@@ -63,6 +63,60 @@ static void sigint_handler(int signo) {
 }
 #endif
 
+// ======================= compat ================================
+using common_init_result = llama_init_result;
+using common_sampler = llama_sampling_context;
+using llama_tokens = std::vector<llama_token>;
+using common_params = gpt_params;
+
+inline common_init_result common_init_from_params(gpt_params & params) {
+    return llama_init_from_gpt_params(params);
+}
+inline llama_sampling_context * common_sampler_init(const llama_model * model, const llama_sampling_params & sparams) {
+    return llama_sampling_init(llama_get_model_vocab(model), sparams);
+}
+inline std::vector<llama_token> common_tokenize(const llama_context * ctx, const std::string & text, bool add_special, bool parse_special = false) {
+    return llama_tokenize(ctx, text, add_special, parse_special);
+}
+inline void common_sampler_free(common_sampler * smpl) {
+    llama_sampling_free(smpl);
+}
+inline llama_token common_sampler_sample(common_sampler * gsmpl, llama_context * ctx, int idx, [[maybe_unused]] bool grammar_first = false) {
+    return llama_sampling_sample(gsmpl, ctx, nullptr, idx);
+}
+inline void common_sampler_accept(common_sampler * gsmpl, llama_context * ctx, llama_token token, bool accept_grammar) {
+    llama_sampling_accept(gsmpl, ctx, token, accept_grammar);
+}
+inline std::string common_token_to_piece(const llama_context * ctx, llama_token token, bool special = true) {
+    return llama_token_to_piece(ctx, token, special);
+}
+inline void common_batch_clear(llama_batch & batch) {
+    llama_batch_clear(batch);
+}
+inline void common_batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
+    llama_batch_add(batch, id, pos, seq_ids, logits);
+}
+void common_init() {
+#ifdef NDEBUG
+    const char * build_type = "";
+#else
+    const char * build_type = " (debug)";
+#endif
+    LOG("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
+}
+
+
+#ifndef LOG_ERR
+#define LOG_ERR LOG
+#endif
+#ifndef LOG_INF
+#define LOG_INF LOG
+#endif
+#ifndef LOG_DBG
+#define LOG_DBG LOG
+#endif
+// ======================= end compat ================================
+
 struct mtmd_cli_context {
     mtmd::context_ptr ctx_vision;
     common_init_result llama_init;
@@ -87,11 +141,11 @@ struct mtmd_cli_context {
     llama_pos n_past = 0;
 
     mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
-        model = llama_init.model.get();
-        lctx = llama_init.context.get();
+        model = llama_init.model; //.get();
+        lctx = llama_init.context; //.get();
         vocab = llama_model_get_vocab(model);
-        smpl = common_sampler_init(model, params.sampling);
-        n_threads = params.cpuparams.n_threads;
+        smpl = common_sampler_init(model, params.sparams); //sampling);
+        n_threads = params.n_threads;
         batch = llama_batch_init(1, 0, 1); // batch for next token generation
         n_batch = params.n_batch;
 
@@ -130,7 +184,7 @@ struct mtmd_cli_context {
         mtmd_context_params mparams = mtmd_context_params_default();
         mparams.use_gpu = params.mmproj_use_gpu;
         mparams.print_timings = true;
-        mparams.n_threads = params.cpuparams.n_threads;
+        mparams.n_threads = params.n_threads;
         mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
@@ -170,7 +224,7 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
 
         llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
         generated_tokens.push_back(token_id);
-        common_sampler_accept(ctx.smpl, token_id, true);
+        common_sampler_accept(ctx.smpl, ctx.lctx, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             LOG("\n");
@@ -249,11 +303,14 @@ int main(int argc, char ** argv) {
     ggml_time_init();
 
     common_params params;
-    params.sampling.temp = 0.2; // lower temp by default for better quality
+    params.sparams.temp = 0.2; // lower temp by default for better quality
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
+    if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
+    //if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
+    //    return 1;
+    //}
 
     common_init();
 
@@ -264,7 +321,7 @@ int main(int argc, char ** argv) {
     }
 
     mtmd_cli_context ctx(params);
-    LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
+    LOG("%s: loading model: %s\n", __func__, params.model.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
@@ -342,7 +399,8 @@ int main(int argc, char ** argv) {
             }
             if (line == "/clear") {
                 ctx.n_past = 0;
-                llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
+                llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1);
+                //llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
                 LOG("Chat history cleared\n\n");
                 continue;
             }
@@ -381,6 +439,7 @@ int main(int argc, char ** argv) {
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");
     LOG("\n\n");
-    llama_perf_context_print(ctx.lctx);
+    llama_print_timings(ctx.lctx);
+    //llama_perf_context_print(ctx.lctx);
     return g_is_interrupted ? 130 : 0;
 }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 546b4c63..e7ba2bfe 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1235,7 +1235,7 @@ struct server_context {
 
         chat_templates = common_chat_templates_init(model, params.chat_template);
         try {
-            common_chat_format_example(chat_templates.get(), params.use_jinja);
+            common_chat_format_example(chat_templates.get(), params.use_jinja, {});
         }
         catch (const std::exception& e) {
             LOG_WARNING("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
@@ -3778,7 +3778,7 @@ int main(int argc, char ** argv) {
     });
 
     LOG_INFO("chat template", {
-        {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params.use_jinja).c_str()
+        {"chat_example", common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params.use_jinja, {}).c_str()
         },
         {"built_in", params.chat_template.empty()},
     });
diff --git a/include/llama.h b/include/llama.h
index 0784ef7b..8e6bfad2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1026,6 +1026,7 @@ extern "C" {
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
+    LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token);
 
     // Identify if Token Id is a control token or a render-able token
     LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);