mirror of https://github.com/ikawrakow/ik_llama.cpp.git
304 lines
11 KiB
C++
#ifndef MTMD_H
#define MTMD_H

#include "ggml.h"
#include "llama.h"

#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
#include <string>
#include <vector>
#include <cinttypes>
#include <memory>
#endif

/**
 * libmtmd: A library for multimodal support in llama.cpp.
 *
 * WARNING: This API is experimental and subject to many BREAKING CHANGES.
 *          Issues related to API usage may receive lower priority support.
 *
 * For the usage, see an example in mtmd-cli.cpp
 */

#ifdef LLAMA_SHARED
#    if defined(_WIN32) && !defined(__MINGW32__)
#        ifdef LLAMA_BUILD
#            define MTMD_API __declspec(dllexport)
#        else
#            define MTMD_API __declspec(dllimport)
#        endif
#    else
#        define MTMD_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define MTMD_API
#endif

// deprecated marker, use mtmd_default_marker() instead
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"

#ifdef __cplusplus
extern "C" {
#endif

enum mtmd_input_chunk_type {
    MTMD_INPUT_CHUNK_TYPE_TEXT,
    MTMD_INPUT_CHUNK_TYPE_IMAGE,
    MTMD_INPUT_CHUNK_TYPE_AUDIO,
};

// opaque types
struct mtmd_context;
struct mtmd_bitmap;
struct mtmd_image_tokens;
struct mtmd_input_chunk;
struct mtmd_input_chunks;

struct mtmd_input_text {
    const char * text;
    bool add_special;
    bool parse_special;
};

//
// C API
//

typedef struct mtmd_context mtmd_context;
typedef struct mtmd_bitmap mtmd_bitmap;
typedef struct mtmd_image_tokens mtmd_image_tokens;
typedef struct mtmd_input_chunk mtmd_input_chunk;
typedef struct mtmd_input_chunks mtmd_input_chunks;
typedef struct mtmd_input_text mtmd_input_text;

struct mtmd_context_params {
    bool use_gpu;
    bool print_timings;
    int n_threads;
    enum ggml_log_level verbosity;
    const char * image_marker; // deprecated, use media_marker instead
    const char * media_marker;
    enum llama_flash_attn_type flash_attn_type;

    // limit number of image tokens, only for vision models with dynamic resolution
    int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
    int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
};

MTMD_API const char * mtmd_default_marker(void);

MTMD_API struct mtmd_context_params mtmd_context_params_default(void);

// initialize the mtmd context
// return nullptr on failure
MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
                                            const struct llama_model * text_model,
                                            const struct mtmd_context_params ctx_params);

MTMD_API void mtmd_free(mtmd_context * ctx);
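
// example (a minimal sketch; `model` is assumed to be a llama_model loaded
// elsewhere and "mmproj.gguf" is a placeholder path to the projector file):
//
//   mtmd_context_params params = mtmd_context_params_default();
//   params.n_threads = 4;
//   mtmd_context * mctx = mtmd_init_from_file("mmproj.gguf", model, params);
//   if (mctx == NULL) { /* handle failure */ }
//   ...
//   mtmd_free(mctx);
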

// whether we need to set non-causal mask before llama_decode
MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);

// whether the current model uses M-RoPE for llama_decode
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);

// whether the current model supports vision input
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);

// whether the current model supports audio input
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);

// get audio bitrate in Hz, for example 16000 for Whisper
// return -1 if audio is not supported
MTMD_API int mtmd_get_audio_bitrate(mtmd_context * ctx);
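
// example (sketch): query capabilities before building the prompt
//
//   if (mtmd_support_audio(mctx)) {
//       int rate = mtmd_get_audio_bitrate(mctx); // e.g. 16000; resample your input audio to this rate
//   }
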

// mtmd_bitmap
//
// if bitmap is image:
//     length of data must be nx * ny * 3
//     the data is in RGBRGBRGB... format
// if bitmap is audio:
//     length of data must be n_samples * sizeof(float)
//     the data is in float format (PCM F32)
MTMD_API mtmd_bitmap *         mtmd_bitmap_init           (uint32_t nx, uint32_t ny, const unsigned char * data);
MTMD_API mtmd_bitmap *         mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
MTMD_API uint32_t              mtmd_bitmap_get_nx     (const mtmd_bitmap * bitmap);
MTMD_API uint32_t              mtmd_bitmap_get_ny     (const mtmd_bitmap * bitmap);
MTMD_API const unsigned char * mtmd_bitmap_get_data   (const mtmd_bitmap * bitmap);
MTMD_API size_t                mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
MTMD_API bool                  mtmd_bitmap_is_audio   (const mtmd_bitmap * bitmap);
MTMD_API void                  mtmd_bitmap_free       (mtmd_bitmap * bitmap);
// bitmap ID is optional, but useful for KV cache tracking
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
MTMD_API void         mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id);
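
// example (sketch; `rgb` is assumed to point to nx*ny*3 bytes of interleaved
// RGB pixels and `samples` to n_samples PCM F32 values, both produced by your
// own image/audio decoder):
//
//   mtmd_bitmap * img = mtmd_bitmap_init(nx, ny, rgb);
//   mtmd_bitmap * aud = mtmd_bitmap_init_from_audio(n_samples, samples);
//   mtmd_bitmap_set_id(img, "image-0"); // optional, e.g. a hash of the pixel data
//   ...
//   mtmd_bitmap_free(img);
//   mtmd_bitmap_free(aud);
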

// mtmd_input_chunks
//
// this is simply a list of mtmd_input_chunk
// the elements can only be populated via mtmd_tokenize()
MTMD_API mtmd_input_chunks *      mtmd_input_chunks_init(void);
MTMD_API size_t                   mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get (const mtmd_input_chunks * chunks, size_t idx);
MTMD_API void                     mtmd_input_chunks_free(mtmd_input_chunks * chunks);

// mtmd_input_chunk
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunks
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type        (const mtmd_input_chunk * chunk);
MTMD_API const llama_token *        mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
MTMD_API const mtmd_image_tokens *  mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
MTMD_API size_t                     mtmd_input_chunk_get_n_tokens    (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char *               mtmd_input_chunk_get_id          (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos                  mtmd_input_chunk_get_n_pos       (const mtmd_input_chunk * chunk);

// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
// you can move the chunk ownership to your own code by copying it
// remember to free the chunk when you are done with it
MTMD_API mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk);
MTMD_API void               mtmd_input_chunk_free(mtmd_input_chunk * chunk);
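
// example (sketch; `chunks` is assumed to have been filled by mtmd_tokenize()):
//
//   for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
//       const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
//       if (mtmd_input_chunk_get_type(chunk) == MTMD_INPUT_CHUNK_TYPE_TEXT) {
//           size_t n_tokens;
//           const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
//           // feed these tokens to llama_decode() as usual
//       } else {
//           // image/audio chunk: encode it with mtmd_encode_chunk() (see below)
//       }
//   }
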

// mtmd_image_tokens
//
// the instance will be constructed via mtmd_tokenize()
// it will be freed along with mtmd_input_chunk
MTMD_API size_t       mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
MTMD_API size_t       mtmd_image_tokens_get_nx      (const mtmd_image_tokens * image_tokens);
MTMD_API size_t       mtmd_image_tokens_get_ny      (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id      (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
MTMD_API llama_pos    mtmd_image_tokens_get_n_pos   (const mtmd_image_tokens * image_tokens); // TODO: deprecate

// tokenize an input text prompt and a list of bitmaps (images/audio)
// the prompt must have the input image marker (default: "<__media__>") in it
// the default marker is defined by mtmd_default_marker()
// the marker will be replaced with the image/audio chunk
// for example:
//   "here is an image: <__media__>\ndescribe it in detail."
//   this will give 3 chunks:
//   1. "here is an image: <start_of_image>"
//   2. (image/audio tokens)
//   3. "<end_of_image>\ndescribe it in detail."
// number of bitmaps must be equal to the number of markers in the prompt
// this function is thread-safe (shared ctx)
// return values:
//   0 on success
//   1 if the number of bitmaps does not match the number of markers in the prompt
//   2 on image preprocessing error
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
                               mtmd_input_chunks * output,
                               const mtmd_input_text * text,
                               const mtmd_bitmap ** bitmaps,
                               size_t n_bitmaps);
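
// example (sketch; `mctx` and `img` are assumed to come from the init and
// bitmap examples above):
//
//   mtmd_input_text text;
//   text.text          = "here is an image: <__media__>\ndescribe it in detail.";
//   text.add_special   = true;
//   text.parse_special = true;
//
//   mtmd_input_chunks * chunks    = mtmd_input_chunks_init();
//   const mtmd_bitmap * bitmaps[] = { img };
//   int32_t res = mtmd_tokenize(mctx, chunks, &text, bitmaps, 1);
//   // res == 0 on success (see return values above)
//   ...
//   mtmd_input_chunks_free(chunks);
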

// returns 0 on success
// TODO: deprecate
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
                             const mtmd_image_tokens * image_tokens);

// returns 0 on success
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
                                   const mtmd_input_chunk * chunk);

// get output embeddings from the last encode pass
// the reading size (in bytes) is equal to:
//   llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
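
// example (sketch; `chunk` is assumed to be an image/audio chunk and `model`
// the text model passed to mtmd_init_from_file()):
//
//   if (mtmd_encode_chunk(mctx, chunk) == 0) {
//       const float * embd    = mtmd_get_output_embd(mctx);
//       size_t        n_bytes = llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float);
//       // `embd` points to n_bytes of data; pass it to llama_decode as embeddings input
//   }
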

/////////////////////////////////////////

// test function, to be used in test-mtmd-c-api.c
MTMD_API mtmd_input_chunks * mtmd_test_create_input_chunks(void);

#ifdef __cplusplus
} // extern "C"
#endif

//
// C++ wrappers
//

#ifdef __cplusplus

namespace mtmd {

struct mtmd_context_deleter {
    void operator()(mtmd_context * val) { mtmd_free(val); }
};
using context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;

struct mtmd_bitmap_deleter {
    void operator()(mtmd_bitmap * val) { mtmd_bitmap_free(val); }
};
using bitmap_ptr = std::unique_ptr<mtmd_bitmap, mtmd_bitmap_deleter>;

struct mtmd_input_chunks_deleter {
    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
};
using input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;

struct mtmd_input_chunk_deleter {
    void operator()(mtmd_input_chunk * val) { mtmd_input_chunk_free(val); }
};
using input_chunk_ptr = std::unique_ptr<mtmd_input_chunk, mtmd_input_chunk_deleter>;

struct bitmap {
    bitmap_ptr ptr;
    bitmap() : ptr(nullptr) {}
    bitmap(mtmd_bitmap * bitmap) : ptr(bitmap) {}
    bitmap(bitmap && other) noexcept : ptr(std::move(other.ptr)) {}
    bitmap(uint32_t nx, uint32_t ny, const unsigned char * data) {
        ptr.reset(mtmd_bitmap_init(nx, ny, data));
    }
    ~bitmap() = default;
    uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
    uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
    const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
    size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
    std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
    void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
};

struct bitmaps {
    std::vector<bitmap> entries;
    ~bitmaps() = default;
    // return list of pointers to mtmd_bitmap
    // example:
    //   auto bitmaps_c_ptr = bitmaps.c_ptr();
    //   int32_t res = mtmd_tokenize(... bitmaps_c_ptr.data(), bitmaps_c_ptr.size());
    std::vector<const mtmd_bitmap *> c_ptr() {
        std::vector<const mtmd_bitmap *> res(entries.size());
        for (size_t i = 0; i < entries.size(); i++) {
            res[i] = entries[i].ptr.get();
        }
        return res;
    }
};

struct input_chunks {
    input_chunks_ptr ptr;
    input_chunks() = default;
    input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {}
    ~input_chunks() = default;
    size_t size() { return mtmd_input_chunks_size(ptr.get()); }
    const mtmd_input_chunk * operator[](size_t idx) {
        return mtmd_input_chunks_get(ptr.get(), idx);
    }
};
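
// example (sketch; ties the RAII wrappers above together, with `mctx` and
// `text` set up via the C API as in the earlier examples):
//
//   mtmd::bitmaps media;
//   media.entries.emplace_back(nx, ny, rgb); // takes ownership of the new mtmd_bitmap
//
//   mtmd::input_chunks chunks(mtmd_input_chunks_init());
//   auto bitmaps_c_ptr = media.c_ptr();
//   int32_t res = mtmd_tokenize(mctx, chunks.ptr.get(), &text,
//                               bitmaps_c_ptr.data(), bitmaps_c_ptr.size());
//   // `media` and `chunks` free the underlying C objects when they go out of scope
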

} // namespace mtmd

#endif

#endif