Feat: add Kimi K2.5 Vision (#1280)

* port Kimi K2.5 vision from upstream

* feat(clip): add support for Kimi K2.5 vision model
Samuel Oliveira Alves
2026-02-19 04:15:03 -03:00
committed by GitHub
parent 04cf685e82
commit 51df09be8a
5 changed files with 453 additions and 140 deletions

View File

@@ -31,6 +31,8 @@
// vision-specific
#define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities
#define KEY_IMAGE_SIZE "clip.vision.image_size"
#define KEY_IMAGE_MIN_PIXELS "clip.vision.image_min_pixels"
#define KEY_IMAGE_MAX_PIXELS "clip.vision.image_max_pixels"
#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
@@ -152,6 +154,7 @@ enum projector_type {
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_LFM2,
PROJECTOR_TYPE_KIMIVL,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
@@ -178,6 +181,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
{ PROJECTOR_TYPE_LFM2, "lfm2"},
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},

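For context: the GGUF file stores the projector as a string under clip.vision.projector_type, and the loader resolves it against the name table above. A minimal sketch of that reverse lookup, assuming a map shaped like PROJECTOR_TYPE_NAMES (the helper name projector_type_from_name is illustrative, not necessarily what clip.cpp uses):

```cpp
#include <map>
#include <string>

enum projector_type {
    PROJECTOR_TYPE_KIMIVL,
    PROJECTOR_TYPE_KIMIK25,
    PROJECTOR_TYPE_UNKNOWN,
};

static const std::map<projector_type, std::string> NAMES = {
    { PROJECTOR_TYPE_KIMIVL,  "kimivl"  },
    { PROJECTOR_TYPE_KIMIK25, "kimik25" },
};

// reverse lookup: map the string from clip.vision.projector_type back to the enum
static projector_type projector_type_from_name(const std::string & name) {
    for (const auto & kv : NAMES) {
        if (kv.second == name) {
            return kv.first;
        }
    }
    return PROJECTOR_TYPE_UNKNOWN;
}
```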
View File

@@ -35,6 +35,8 @@
#include <array>
#include <functional>
#define DEFAULT_INTERPOLATION_MODE (GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)
// TODO: allow passing a callback from user code
struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
@@ -511,6 +513,10 @@ struct clip_ctx {
}
};
//
// clip_graph
//
struct clip_graph {
clip_ctx * ctx;
const clip_model & model;
@@ -1406,6 +1412,75 @@ struct clip_graph {
return gf;
}
ggml_cgraph * build_kimik25() {
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
// Kimi-K2.5 natively uses an interleaved 2D RoPE pattern, but
// Q / K are permuted during conversion to use the split format.
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
return cur;
};
ggml_tensor * inp = build_inp();
// Unclear why, but doing this inside build_vit led to the ggml_add not occurring;
// doing it manually here does work.
inp = ggml_add(ctx0, inp, learned_pos_embd);
ggml_tensor * cur = build_vit(
inp, n_patches,
NORM_TYPE_NORMAL,
hparams.ffn_op,
nullptr,
add_pos);
cb(cur, "vit_out", -1);
{
// patch_merger
const int scale_factor = model.hparams.n_merge;
cur = build_patch_merge_permute(cur, scale_factor);
// projection norm
int proj_inp_dim = cur->ne[0];
int n_merged_patches = cur->ne[1];
cur = ggml_view_2d(ctx0, cur,
n_embd, n_merged_patches * scale_factor * scale_factor,
ggml_row_size(cur->type, n_embd), 0);
cur = ggml_norm(ctx0, cur, hparams.eps);
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
cur = ggml_view_2d(ctx0, cur,
proj_inp_dim, n_merged_patches,
ggml_row_size(cur->type, proj_inp_dim), 0);
cb(cur, "proj_inp_normed", -1);
// projection mlp
cur = build_ffn(cur,
model.mm_1_w, model.mm_1_b,
nullptr, nullptr,
model.mm_2_w, model.mm_2_b,
FFN_GELU,
-1);
cb(cur, "proj_out", -1);
}
// build the graph
ggml_build_forward_expand(gf, cur);
return gf;
}
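A note on the patch-merger block above: the two ggml_view_2d calls temporarily switch between a per-patch layout (n_embd wide, for the projection norm) and a merged layout (n_embd * n_merge^2 wide, for the MLP). A compile-time sketch of the shape bookkeeping, using the 1152-channel / 64x64 grid mentioned later in the diff and an assumed n_merge of 2:

```cpp
// shape bookkeeping for the kimik25 patch merger (illustrative values)
constexpr int n_embd    = 1152;      // ViT hidden size
constexpr int n_merge   = 2;         // spatial merge factor (scale_factor), assumed
constexpr int n_patches = 64 * 64;   // patches before merging

// build_patch_merge_permute concatenates each n_merge x n_merge block of patches
constexpr int proj_inp_dim     = n_embd * n_merge * n_merge;       // 4608
constexpr int n_merged_patches = n_patches / (n_merge * n_merge);  // 1024

// the projection norm runs per original patch, so the tensor is viewed as
// [n_embd, n_merged_patches * n_merge * n_merge] for the norm, then viewed
// back to [proj_inp_dim, n_merged_patches] before the two-layer GELU MLP
static_assert(n_embd * (n_merged_patches * n_merge * n_merge) ==
              proj_inp_dim * n_merged_patches,
              "the two views cover the same number of elements");
```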
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
@@ -1985,23 +2060,20 @@ private:
// utility functions
//
void cb(ggml_tensor * cur0, const char * name, int il) const {
if (ctx->debug_graph) {
ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
ggml_set_name(cur, cur_name.c_str());
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur);
ctx->debug_print_tensors.push_back(cur);
void cb(ggml_tensor * cur, const char * name, int il) const {
if (il >= 0) {
ggml_format_name(cur, "%s-%d", name, il);
} else {
ggml_set_name(cur, name);
}
}
// siglip2 naflex
ggml_tensor * resize_position_embeddings() {
ggml_tensor * resize_position_embeddings(uint32_t interpolation_mode = DEFAULT_INTERPOLATION_MODE) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
const uint32_t mode = interpolation_mode;
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
GGML_ASSERT(pos_embd);
@@ -2054,34 +2126,66 @@ private:
// self-attention
{
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Qcur = nullptr;
ggml_tensor * Kcur = nullptr;
ggml_tensor * Vcur = nullptr;
if (layer.qkv_w != nullptr) {
// fused qkv
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
if (layer.qkv_b != nullptr) {
cur = ggml_add(ctx0, cur, layer.qkv_b);
}
ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ 0);
ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, n_embd));
if (layer.q_norm) {
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
if (layer.k_norm) {
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}
// TODO: q/k norm requires row size == n_embd, while here it's d_head
// we can add support in the future if needed
GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr);
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
} else {
// separate q, k, v
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
if (layer.q_norm) {
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}
if (layer.k_norm) {
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
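For the fused-QKV branch above: cur holds [Q | K | V] back-to-back per position, so the three ggml_view_3d calls only differ in their byte offset (0, n_embd rows, 2 * n_embd rows) while sharing the same per-head stride. A small standalone sketch of those offsets, assuming F32 activations and illustrative head dimensions:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // illustrative dimensions; the real values come from the vision hparams
    const int    n_embd = 1152;
    const int    n_head = 16;
    const int    d_head = n_embd / n_head;        // 72
    const size_t elt    = sizeof(float);          // assuming F32 activations

    // fused projection output per position: [Q | K | V], 3 * n_embd floats
    const size_t nb1   = elt * d_head;            // stride between heads inside Q (or K, or V)
    const size_t off_q = 0;                       // Q starts at the beginning of the row
    const size_t off_k = elt * n_embd;            // K starts right after the Q block
    const size_t off_v = elt * 2 * n_embd;        // V starts after Q and K

    std::printf("per-head stride: %zu bytes, offsets: Q=%zu K=%zu V=%zu\n",
                nb1, off_q, off_k, off_v);
    return 0;
}
```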
@@ -2186,19 +2290,14 @@ private:
? ggml_rms_norm(ctx0, cur, norm_eps)
: ggml_norm(ctx0, cur, norm_eps);
if (mw || mb) {
cb(cur, "norm", il);
}
if (mw) {
cur = ggml_mul(ctx0, cur, mw);
if (mb) {
cb(cur, "norm_w", il);
}
cb(cur, "norm_w", il);
}
if (mb) {
cur = ggml_add(ctx0, cur, mb);
cb(cur, "norm_b", il);
}
return cur;
@@ -2383,8 +2482,8 @@ private:
{
first = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
cur->nb[1],
cur->nb[2],
0);
first = ggml_rope_ext(
ctx0,
@@ -2402,8 +2501,8 @@ private:
{
second = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
cur->nb[1],
cur->nb[2],
n_dim/2 * ggml_element_size(cur));
second = ggml_rope_ext(
ctx0,
@@ -2454,6 +2553,34 @@ private:
return cur;
}
// note: this is similar to resize_position_embeddings; the major difference is that w/h are stored
// in ne[1] and ne[2] instead of being inferred via sqrt. We could try storing the tensor in 2D with w*h instead.
// The permute is also slightly different: (2, 1, 0, 3) instead of (2, 0, 1, 3).
ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const uint32_t mode = interpolation_mode;
GGML_ASSERT(pos_embd);
const int64_t stored_c = pos_embd->ne[0]; // C = 1152
const int64_t orig_w = pos_embd->ne[1]; // W = 64
const int64_t orig_h = pos_embd->ne[2]; // H = 64
GGML_ASSERT(stored_c == n_embd);
if (height == (int)orig_h && width == (int)orig_w) {
// No interpolation needed, just flatten to [C, H*W]
return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
}
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode);
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
return pos_embd;
}
};
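The permute / interpolate / permute sequence in resize_position_embeddings_3d exists because ggml_interpolate resizes ne[0] and ne[1], while the stored embedding keeps channels in ne[0] and the spatial grid in ne[1]/ne[2]. A standalone trace of the logical shapes under that assumption (the 80x60 target grid is purely illustrative):

```cpp
#include <array>
#include <cstdio>

int main() {
    // learned position embedding as stored: [C, W, H, 1] = [1152, 64, 64, 1]
    std::array<long, 4> ne = {1152, 64, 64, 1};
    const long new_w = 80, new_h = 60;                      // illustrative target grid

    // ggml_permute(..., 2, 1, 0, 3): dim 0 -> slot 2, dim 2 -> slot 0
    std::array<long, 4> p = {ne[2], ne[1], ne[0], ne[3]};   // [H, W, C, 1]

    // ggml_interpolate(..., height, width, n_embd, 1, mode) resizes ne0/ne1
    std::array<long, 4> r = {new_h, new_w, p[2], p[3]};     // [new_h, new_w, C, 1]

    // permute back and flatten to the [n_embd, W*H] layout the graph expects
    std::array<long, 4> q = {r[2], r[1], r[0], r[3]};       // [C, new_w, new_h, 1]
    std::printf("flattened pos_embd: [%ld, %ld]\n", q[0], q[1] * q[2]);
    return 0;
}
```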
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
@@ -2505,6 +2632,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_kimivl();
} break;
case PROJECTOR_TYPE_KIMIK25:
{
res = graph.build_kimik25();
} break;
case PROJECTOR_TYPE_JANUS_PRO:
{
res = graph.build_siglip();
@@ -2797,6 +2928,23 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_KIMIK25:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
int min_pixels = 0, max_pixels = 0;
get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false);
get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false);
if (min_pixels > 0 && max_pixels > 0) {
hparams.image_min_pixels = min_pixels;
hparams.image_max_pixels = max_pixels;
hparams.warmup_image_size = static_cast<int>(std::sqrt(max_pixels));
} else {
hparams.set_limit_image_tokens(2, 4096);
}
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
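For the KIMIK25 hparams case above: when both pixel limits are present in the GGUF, the warmup image is made square with an area close to the max budget, otherwise the loader falls back to the generic 2..4096 token limit. A quick sketch of that sizing rule; the concrete pixel limits below are assumptions, not values read from the model:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // illustrative limits, as would be read from clip.vision.image_min_pixels /
    // clip.vision.image_max_pixels (actual values depend on the converted model)
    const int min_pixels = 28 * 28 * 4;      // assumption
    const int max_pixels = 28 * 28 * 4096;   // assumption

    if (min_pixels > 0 && max_pixels > 0) {
        // warmup uses a square image whose area roughly matches the pixel budget
        const int warmup_image_size = static_cast<int>(std::sqrt(static_cast<double>(max_pixels)));
        std::printf("warmup image: %d x %d (~%d pixels, budget %d)\n",
                    warmup_image_size, warmup_image_size,
                    warmup_image_size * warmup_image_size, max_pixels);
    }
    return 0;
}
```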
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@@ -2944,21 +3092,11 @@ struct clip_model_loader {
model.layers.resize(hparams.n_layer);
for (int il = 0; il < hparams.n_layer; ++il) {
auto & layer = model.layers[il];
//layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"), false);
//layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"), false);
//layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"), false);
// try combined qkv weight first; if absent, require separate k/q/v weights
layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
if (!layer.qkv_w) {
// combined not present => require separate tensors (no 'false' argument because tensors always required)
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
}
// other attention tensors (output / norms / ln) left as-is
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"), false);
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"), false);
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"), false);
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
//layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
layer.qkv_w = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "weight"), false);
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
@@ -2966,18 +3104,11 @@ struct clip_model_loader {
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
// try combined qkv bias first; if absent, require separate k/q/v biases
layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
if (!layer.qkv_b) {
// combined not present => require separate biases ('false' because tensors not required)
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
}
// keep other optional biases as before
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
//layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
layer.qkv_b = get_tensor(string_format(TN_ATTN_QKV, prefix, il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
@@ -3154,6 +3285,7 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3302,12 +3434,30 @@ struct clip_model_loader {
};
static void warmup(clip_ctx & ctx_clip) {
// create a fake batch
const auto & hparams = ctx_clip.model.hparams;
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
} else {
img->nx = hparams.warmup_audio_size;
img->ny = hparams.n_mel_bins;
LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
}
batch.entries.push_back(std::move(img));
warmup(ctx_clip, batch);
}
static void warmup(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
support_info_graph info;
if (ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_AUTO) {
// try to enable flash attention to see if it's supported
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_ENABLED;
info = alloc_compute_meta(ctx_clip);
info = alloc_compute_meta(ctx_clip, batch);
if (!info.fattn && info.fattn_op) {
auto op = info.fattn_op;
LOG_WRN("%s: *****************************************************************\n", __func__);
@@ -3326,10 +3476,10 @@ struct clip_model_loader {
LOG_WRN("%s: please report this on github as an issue\n", __func__);
LOG_WRN("%s: *****************************************************************\n", __func__);
ctx_clip.flash_attn_type = CLIP_FLASH_ATTN_TYPE_DISABLED;
alloc_compute_meta(ctx_clip);
alloc_compute_meta(ctx_clip, batch);
}
} else {
info = alloc_compute_meta(ctx_clip);
info = alloc_compute_meta(ctx_clip, batch);
if (!info.fattn && ctx_clip.flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
LOG_WRN("%s: flash attention is not supported by the current backend; falling back to CPU (performance will be degraded)\n", __func__);
}
@@ -3366,24 +3516,9 @@ struct clip_model_loader {
}
}
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip) {
const auto & hparams = ctx_clip.model.hparams;
static support_info_graph alloc_compute_meta(clip_ctx & ctx_clip, const clip_image_f32_batch & batch) {
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
} else {
img->nx = hparams.warmup_audio_size;
img->ny = hparams.n_mel_bins;
LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
}
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
@@ -3744,12 +3879,13 @@ struct img_tool {
const int width = inp_size.width;
const int height = inp_size.height;
auto round_by_factor = [f = align_size](float x) { return static_cast<int>(std::round(x / static_cast<float>(f))) * f; };
auto ceil_by_factor = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
auto floor_by_factor = [f = align_size](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
// always align up first
int h_bar = std::max(align_size, ceil_by_factor(height));
int w_bar = std::max(align_size, ceil_by_factor(width));
int h_bar = std::max(align_size, round_by_factor(height));
int w_bar = std::max(align_size, round_by_factor(width));
if (h_bar * w_bar > max_pixels) {
const auto beta = std::sqrt(static_cast<float>(height * width) / max_pixels);
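The hunk above swaps ceil_by_factor for round_by_factor in the initial alignment, and uses beta = sqrt(w*h / max_pixels) to shrink both sides when the aligned area exceeds the budget. A compact sketch of that flow; the floor_by_factor downscale path and the omission of a min_pixels branch are assumptions filling in what the hunk truncates:

```cpp
#include <algorithm>
#include <cmath>

// sketch: aspect-preserving resize clamped to a pixel budget and aligned to
// multiples of align_size (e.g. patch_size * n_merge)
static void fit_to_pixel_budget(int width, int height, int align_size,
                                int max_pixels, int & w_out, int & h_out) {
    auto round_by_factor = [align_size](float x) {
        return static_cast<int>(std::round(x / static_cast<float>(align_size))) * align_size;
    };
    auto floor_by_factor = [align_size](float x) {
        return static_cast<int>(std::floor(x / static_cast<float>(align_size))) * align_size;
    };

    int w_bar = std::max(align_size, round_by_factor((float) width));
    int h_bar = std::max(align_size, round_by_factor((float) height));

    if ((long long) w_bar * h_bar > max_pixels) {
        // shrink both sides by the same factor so the area fits the budget,
        // then align downwards so the result never exceeds it again
        const float beta = std::sqrt(static_cast<float>(width * height) / max_pixels);
        w_bar = std::max(align_size, floor_by_factor(width  / beta));
        h_bar = std::max(align_size, floor_by_factor(height / beta));
    }
    w_out = w_bar;
    h_out = h_bar;
}
```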
@@ -3932,7 +4068,14 @@ struct llava_uhd {
clip_image_size refined_size; // size of image right before slicing (must be multiple of slice size)
clip_image_size grid_size; // grid_size.width * grid_size.height = number of slices
std::vector<slice_coordinates> slices;
img_tool::resize_algo interpolation_overview = img_tool::RESIZE_ALGO_BILINEAR;
bool padding_overview = false; // if true, the overview image will be padded to the grid size (e.g. llava-1.6)
std::array<uint8_t, 3> pad_color_overview = {0, 0, 0};
img_tool::resize_algo interpolation_refined = img_tool::RESIZE_ALGO_BICUBIC;
bool padding_refined = false; // if true, the refined image will be padded to the grid size (e.g. llava-1.6)
std::array<uint8_t, 3> pad_color_refined = {0, 0, 0};
};
static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
@@ -3959,10 +4102,11 @@ struct llava_uhd {
auto refine_size = llava_uhd::select_best_resolution(
original_size,
ctx->model.hparams.image_res_candidates);
res.overview_size = clip_image_size{slice_size, slice_size};
res.refined_size = refine_size;
res.grid_size = clip_image_size{0, 0};
res.padding_refined = true;
res.overview_size = clip_image_size{slice_size, slice_size};
res.refined_size = refine_size;
res.grid_size = clip_image_size{0, 0};
res.padding_refined = true;
res.interpolation_refined = img_tool::RESIZE_ALGO_BILINEAR; // preserve old behavior when padding
LOG_DBG("%s: using pinpoints for slicing\n", __func__);
LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d\n",
@@ -4041,12 +4185,13 @@ struct llava_uhd {
static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
std::vector<clip_image_u8_ptr> output;
img_tool::resize_algo interpolation = img_tool::RESIZE_ALGO_BILINEAR; // TODO: make it configurable
// resize to overview size
clip_image_u8_ptr resized_img(clip_image_u8_init());
img_tool::resize(*img, *resized_img, inst.overview_size, interpolation);
img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview,
inst.padding_overview, inst.pad_color_overview);
output.push_back(std::move(resized_img));
if (inst.slices.empty()) {
// no slices, just return the resized image
return output;
@@ -4054,13 +4199,8 @@ struct llava_uhd {
// resize to refined size
clip_image_u8_ptr refined_img(clip_image_u8_init());
if (inst.padding_refined) {
img_tool::resize(*img, *refined_img, inst.refined_size, interpolation);
} else {
// only algo bicubic preserves the ratio; old models rely on this behavior
// TODO: do we need to support other algos here?
img_tool::resize(*img, *refined_img, inst.refined_size, img_tool::RESIZE_ALGO_BICUBIC, false);
}
img_tool::resize(*img, *refined_img, inst.refined_size, inst.interpolation_refined,
inst.padding_refined, inst.pad_color_refined);
// create slices
for (const auto & slice : inst.slices) {
@@ -4370,6 +4510,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_KIMIK25:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
params.patch_size * params.n_merge,
params.image_min_pixels,
params.image_max_pixels);
const std::array<uint8_t, 3> pad_color = {0, 0, 0};
clip_image_u8 resized_img;
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
} break;
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_LDP:
@@ -4551,6 +4708,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
// dynamic size
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
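For the dynamic-size projectors in clip_n_output_tokens (LFM2, KimiVL, and now KimiK25), the token count follows from the preprocessed image size and the merged patch size. A quick numeric sketch with illustrative values (patch_size = 14 and n_merge = 2 are assumptions, as is the truncated division-by-grid step):

```cpp
#include <cstdio>

int main() {
    // illustrative values; the real ones come from the clip hparams and the image
    const int patch_size = 14;
    const int n_merge    = 2;
    const int nx = 896, ny = 672;   // preprocessed size, aligned to patch_size * n_merge

    const int out_patch_size = patch_size * n_merge;            // 28
    const int n_tokens = (nx / out_patch_size) * (ny / out_patch_size);
    std::printf("output tokens: %d\n", n_tokens);               // 32 * 24 = 768
    return 0;
}
```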
@@ -4709,7 +4867,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
// build the inference graph
ctx->debug_print_tensors.clear();
ggml_backend_sched_reset(ctx->sched.get());
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
@@ -4951,6 +5108,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
@@ -5050,22 +5208,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
return false;
}
// print debug nodes
if (ctx->debug_graph) {
LOG_INF("\n\n---\n\n");
LOG_INF("\n\nDebug graph:\n\n");
for (ggml_tensor * t : ctx->debug_print_tensors) {
std::vector<uint8_t> data(ggml_nbytes(t));
ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
print_tensor_shape(t);
print_tensor_data(t, data.data(), 3);
}
}
// the last node is the embedding tensor
//ggml_tensor * embeddings = ggml_graph_node(gf, -1);
GGML_ASSERT(gf->n_nodes > 0);
ggml_tensor * embeddings = gf->nodes[gf->n_nodes-1];
ggml_tensor * embeddings = ggml_graph_node(gf, -1);
// sanity check (only support batch size of 1 for now)
const int n_tokens_out = embeddings->ne[1];
@@ -5076,7 +5220,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
// copy the embeddings to the location passed by the user
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
if (vec != nullptr) {
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
}
return true;
}
@@ -5119,6 +5265,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];

View File

@@ -2279,6 +2279,7 @@ extern "C" {
enum ggml_scale_mode {
GGML_SCALE_MODE_NEAREST = 0,
GGML_SCALE_MODE_BILINEAR = 1,
GGML_SCALE_MODE_BICUBIC = 2,
GGML_SCALE_MODE_COUNT
};
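clip.cpp's new DEFAULT_INTERPOLATION_MODE ORs a scale mode with GGML_SCALE_FLAG_ALIGN_CORNERS, and both the CPU and CUDA upscale paths decode the mode with mode_flags & 0xFF. A tiny sketch of that packing; the assumption here is that the align-corners flag lives above the low byte (the literal 1 << 8 is illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// mirrors ggml_scale_mode from the enum above
enum scale_mode : uint32_t {
    SCALE_MODE_NEAREST  = 0,
    SCALE_MODE_BILINEAR = 1,
    SCALE_MODE_BICUBIC  = 2,
};
// assumption: the flag occupies a bit outside the low byte, e.g. (1 << 8)
constexpr uint32_t SCALE_FLAG_ALIGN_CORNERS = 1u << 8;

int main() {
    const uint32_t mode_flags    = SCALE_MODE_BICUBIC | SCALE_FLAG_ALIGN_CORNERS;
    const uint32_t mode          = mode_flags & 0xFF;   // -> 2 (bicubic)
    const bool     align_corners = (mode_flags & SCALE_FLAG_ALIGN_CORNERS) != 0;
    std::printf("mode=%u align_corners=%d\n", mode, align_corners);
    return 0;
}
```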

View File

@@ -22,6 +22,70 @@ static __global__ void upscale_f32(const float * x, float * dst,
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
}
namespace bicubic_interpolation {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
const float w0 = weight2(x + 1);
const float w1 = weight1(x + 0);
const float w2 = weight1(1 - x);
const float w3 = weight2(2 - x);
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
};
} // namespace bicubic_interpolation
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
using bicubic_interpolation::bicubic;
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
const int y0_src = (int)floorf(y_src_f);
const float dy = y_src_f - (float)y0_src;
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
const int x0_src = (int)floorf(x_src_f);
const float dx = x_src_f - (float)x0_src;
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
auto load = [=](int x_off, int y_off) -> float {
int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
};
const float result = bicubic(
bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
dst[index] = result;
}
static void upscale_f32_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne10, const int ne11, const int ne12, const int ne13,
@@ -33,6 +97,18 @@ static void upscale_f32_cuda(const float * x, float * dst,
upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
}
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -42,10 +118,26 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const float sf0 = (float)dst->ne[0]/src0->ne[0];
const float sf1 = (float)dst->ne[1]/src0->ne[1];
const float sf2 = (float)dst->ne[2]/src0->ne[2];
const int mode_flags = dst->op_params[0];
const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
float sf0 = (float)dst->ne[0]/src0->ne[0];
float sf1 = (float)dst->ne[1]/src0->ne[1];
float sf2 = (float)dst->ne[2]/src0->ne[2];
const float sf3 = (float)dst->ne[3]/src0->ne[3];
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
pixel_offset = 0.0f;
}
if (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) {
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
}
}

View File

@@ -21042,6 +21042,26 @@ static void ggml_compute_forward_pool_2d(
// ggml_compute_forward_upscale
#ifndef GGML_CLAMP
#define GGML_CLAMP(x, min, max) ((x) < (min) ? (min) : ((x) > (max) ? (max) : (x)))
#endif
static inline float ggml_bicubic_weight1(float x, float a) {
return ((a + 2.0f) * x - (a + 3.0f)) * x * x + 1.0f;
}
static inline float ggml_bicubic_weight2(float x, float a) {
return ((a * x - 5.0f * a) * x + 8.0f * a) * x - 4.0f * a;
}
static inline float ggml_bicubic_interp(float p0, float p1, float p2, float p3, float x, float a) {
const float w0 = ggml_bicubic_weight2(x + 1.0f, a);
const float w1 = ggml_bicubic_weight1(x + 0.0f, a);
const float w2 = ggml_bicubic_weight1(1.0f - x, a);
const float w3 = ggml_bicubic_weight2(2.0f - x, a);
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
}
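The two helpers above implement the Keys bicubic convolution kernel with a = -0.75: weight1 covers offsets with |x| <= 1 and weight2 covers 1 < |x| <= 2. A standalone check of the 1-D interpolation step built from the same formulas (the helper names here are local to the example):

```cpp
#include <cstdio>

static float weight1(float x, float a) { return ((a + 2.0f) * x - (a + 3.0f)) * x * x + 1.0f; }
static float weight2(float x, float a) { return ((a * x - 5.0f * a) * x + 8.0f * a) * x - 4.0f * a; }

// interpolate between p1 and p2 at fractional position x in [0, 1),
// using the outer neighbours p0 and p3 for the cubic fit
static float cubic1d(float p0, float p1, float p2, float p3, float x, float a) {
    return p0 * weight2(x + 1.0f, a) + p1 * weight1(x, a) +
           p2 * weight1(1.0f - x, a) + p3 * weight2(2.0f - x, a);
}

int main() {
    const float a = -0.75f;   // same alpha as PyTorch and the kernels above
    // at x = 0 the weights are (0, 1, 0, 0), so the result reproduces p1 exactly
    std::printf("x=0.0 -> %.3f (expect 1.000)\n", cubic1d(0.f, 1.f, 2.f, 3.f, 0.0f, a));
    // on a linear ramp the kernel is exact, so x = 0.5 lands on the midpoint
    std::printf("x=0.5 -> %.3f (expect 1.500)\n", cubic1d(0.f, 1.f, 2.f, 3.f, 0.5f, a));
    return 0;
}
```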
static void ggml_compute_forward_upscale_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -21055,29 +21075,78 @@ static void ggml_compute_forward_upscale_f32(
GGML_TENSOR_UNARY_OP_LOCALS
const float sf0 = (float)ne0/src0->ne[0];
const float sf1 = (float)ne1/src0->ne[1];
const float sf2 = (float)ne2/src0->ne[2];
const float sf3 = (float)ne3/src0->ne[3];
float sf0 = (float)ne0/src0->ne[0];
float sf1 = (float)ne1/src0->ne[1];
float sf2 = (float)ne2/src0->ne[2];
float sf3 = (float)ne3/src0->ne[3];
float pixel_offset = 0.5f;
// TODO: optimize
const int32_t mode_flags = ((int32_t *)dst->op_params)[0];
const int32_t mode = (mode_flags & 0xFF);
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / sf1;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / sf0;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
pixel_offset = 0.0f;
sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
}
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
if (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / sf1;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / sf0;
*y = *x;
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y = *x;
}
}
}
}
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3 / sf3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2 / sf2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
const int64_t y0 = (int64_t)floorf(y);
const float dy = y - (float)y0;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
const int64_t x0 = (int64_t)floorf(x);
const float dx = x - (float)x0;
float p_vals[4];
for (int iy = -1; iy <= 2; iy++) {
float row_vals[4];
for (int ix = -1; ix <= 2; ix++) {
int64_t idx_x = GGML_CLAMP(x0 + ix, 0, ne00 - 1);
int64_t idx_y = GGML_CLAMP(y0 + iy, 0, ne01 - 1);
row_vals[ix + 1] = *(const float *)((const char *)src0->data + idx_x*nb00 + idx_y*nb01 + i02*nb02 + i03*nb03);
}
p_vals[iy + 1] = ggml_bicubic_interp(row_vals[0], row_vals[1], row_vals[2], row_vals[3], dx, a);
}
const float val = ggml_bicubic_interp(p_vals[0], p_vals[1], p_vals[2], p_vals[3], dy, a);
float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y_dst = val;
}
}
}
}
} else {
GGML_ABORT("unsupported upscale mode");
}
}
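To close out: with GGML_SCALE_FLAG_ALIGN_CORNERS set, both backends switch the scale factors to (ne_dst - 1) / (ne_src - 1) and drop the half-pixel offset, so the first and last output samples map exactly onto the source corners. A small numeric check of that coordinate mapping, using the same src = (dst + offset) / sf - offset formula as the kernels (the 64 -> 96 resize is illustrative):

```cpp
#include <cstdio>

int main() {
    const int ne_src = 64, ne_dst = 96;   // illustrative resize

    // default (half-pixel) mapping: endpoints can fall slightly outside the
    // source range, which is why the kernels clamp the sample indices
    float sf  = (float) ne_dst / ne_src;
    float off = 0.5f;
    std::printf("half-pixel:    first -> %.3f, last -> %.3f\n",
                (0.0f + off) / sf - off, ((float)(ne_dst - 1) + off) / sf - off);

    // align-corners mapping: endpoints land exactly on 0 and ne_src - 1
    sf  = (float)(ne_dst - 1) / (float)(ne_src - 1);
    off = 0.0f;
    std::printf("align-corners: first -> %.3f, last -> %.3f\n",
                (0.0f + off) / sf - off, ((float)(ne_dst - 1) + off) / sf - off);
    return 0;
}
```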