From 505e2c57f92fc233a0cef0b338cb215109144db6 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Mon, 2 Mar 2026 17:54:56 +0100 Subject: [PATCH] Reduce memory use when processing large images (#1349) --- examples/mtmd/clip.cpp | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/mtmd/clip.cpp b/examples/mtmd/clip.cpp index 81644040..90d7f91c 100644 --- a/examples/mtmd/clip.cpp +++ b/examples/mtmd/clip.cpp @@ -2425,18 +2425,36 @@ private: ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); v = ggml_cont(ctx0, v); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; + if (q->ne[3] == 1 && q->ne[2] > 1 && q->ne[2] == k->ne[2] && q->ne[2] == v->ne[2] && q->ne[1]*k->ne[1]*q->ne[2]/1024./1024. >= 256.) { + cur = nullptr; + for (int64_t i2 = 0; i2 < q->ne[2]; ++i2) { + auto qi = ggml_view_2d(ctx0, q, q->ne[0], q->ne[1], q->nb[1], q->nb[2]*i2); + auto ki = ggml_view_2d(ctx0, k, k->ne[0], k->ne[1], k->nb[1], k->nb[2]*i2); + auto vi = ggml_view_2d(ctx0, v, v->ne[0], v->ne[1], v->nb[1], v->nb[2]*i2); + auto kq = ggml_mul_mat(ctx0, ki, qi); + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + auto kqv = ggml_mul_mat(ctx0, vi, kq); + if (cur) { + cur = ggml_concat(ctx0, cur, kqv, 0); + } else { + cur = kqv; + } + } + } else { - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - // F32 may not needed for vision encoders? - // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + const auto n_tokens = q->ne[1]; + const auto n_head = q->ne[2]; - kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + // F32 may not needed for vision encoders? + // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + } } cb(cur, "kqv_out", il);