Add mtmd: builds successfully

This commit is contained in:
Iwan Kawrakow
2025-09-25 15:40:46 +03:00
parent 6b0c8e02a8
commit 24618e301b
8 changed files with 120 additions and 20 deletions

View File

@@ -1,4 +1,4 @@
#include "arg.h"
//#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
@@ -63,6 +63,60 @@ static void sigint_handler(int signo) {
}
#endif
// ======================= compat ================================
using common_init_result = llama_init_result;
using common_sampler = llama_sampling_context;
using llama_tokens = std::vector<llama_token>;
using common_params = gpt_params;
inline common_init_result common_init_from_params(gpt_params & params) {
return llama_init_from_gpt_params(params);
}
inline llama_sampling_context * common_sampler_init(const llama_model * model, const llama_sampling_params & sparams) {
return llama_sampling_init(llama_get_model_vocab(model), sparams);
}
inline std::vector<llama_token> common_tokenize(const llama_context * ctx, const std::string & text, bool add_special, bool parse_special = false) {
return llama_tokenize(ctx, text, add_special, parse_special);
}
inline void common_sampler_free(common_sampler * smpl) {
llama_sampling_free(smpl);
}
inline llama_token common_sampler_sample(common_sampler * gsmpl, llama_context * ctx, int idx, [[maybe_unused]] bool grammar_first = false) {
return llama_sampling_sample(gsmpl, ctx, nullptr, idx);
}
inline void common_sampler_accept(common_sampler * gsmpl, llama_context * ctx, llama_token token, bool accept_grammar) {
llama_sampling_accept(gsmpl, ctx, token, accept_grammar);
}
inline std::string common_token_to_piece(const llama_context * ctx, llama_token token, bool special = true) {
return llama_token_to_piece(ctx, token, special);
}
inline void common_batch_clear(llama_batch & batch) {
llama_batch_clear(batch);
}
inline void common_batch_add(llama_batch & batch, llama_token id, llama_pos pos, const std::vector<llama_seq_id> & seq_ids, bool logits) {
llama_batch_add(batch, id, pos, seq_ids, logits);
}
void common_init() {
#ifdef NDEBUG
const char * build_type = "";
#else
const char * build_type = " (debug)";
#endif
LOG("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
}
#ifndef LOG_ERR
#define LOG_ERR LOG
#endif
#ifndef LOG_INF
#define LOG_INF LOG
#endif
#ifndef LOG_DBG
#define LOG_DBG LOG
#endif
// ======================= end compat ================================
struct mtmd_cli_context {
mtmd::context_ptr ctx_vision;
common_init_result llama_init;
@@ -87,11 +141,11 @@ struct mtmd_cli_context {
llama_pos n_past = 0;
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get();
lctx = llama_init.context.get();
model = llama_init.model; //.get();
lctx = llama_init.context; //.get();
vocab = llama_model_get_vocab(model);
smpl = common_sampler_init(model, params.sampling);
n_threads = params.cpuparams.n_threads;
smpl = common_sampler_init(model, params.sparams); //sampling);
n_threads = params.n_threads;
batch = llama_batch_init(1, 0, 1); // batch for next token generation
n_batch = params.n_batch;
@@ -130,7 +184,7 @@ struct mtmd_cli_context {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params.mmproj_use_gpu;
mparams.print_timings = true;
mparams.n_threads = params.cpuparams.n_threads;
mparams.n_threads = params.n_threads;
mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
if (!ctx_vision.get()) {
@@ -170,7 +224,7 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
generated_tokens.push_back(token_id);
common_sampler_accept(ctx.smpl, token_id, true);
common_sampler_accept(ctx.smpl, ctx.lctx, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
LOG("\n");
@@ -249,11 +303,14 @@ int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
params.sparams.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
//if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
// return 1;
//}
common_init();
@@ -264,7 +321,7 @@ int main(int argc, char ** argv) {
}
mtmd_cli_context ctx(params);
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
LOG("%s: loading model: %s\n", __func__, params.model.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
@@ -342,7 +399,8 @@ int main(int argc, char ** argv) {
}
if (line == "/clear") {
ctx.n_past = 0;
llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1);
//llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
LOG("Chat history cleared\n\n");
continue;
}
@@ -381,6 +439,7 @@ int main(int argc, char ** argv) {
}
if (g_is_interrupted) LOG("\nInterrupted by user\n");
LOG("\n\n");
llama_perf_context_print(ctx.lctx);
llama_print_timings(ctx.lctx);
//llama_perf_context_print(ctx.lctx);
return g_is_interrupted ? 130 : 0;
}