From e7ddefc9c5265a493a2c796008d02583c656ddfc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 26 Sep 2025 08:39:18 +0300 Subject: [PATCH] Add mtmd: fix swiglu --- common/common.cpp | 4 ++ ggml/src/ggml.c | 176 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 179 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 6027b192..1625e87b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -907,6 +907,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.mmproj.url = argv[i]; return true; } + if (arg == "--no-mmproj-offload") { + params.mmproj_use_gpu = false; + return true; + } if (arg == "--image") { CHECK_ARG params.image.emplace_back(argv[i]); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 72520928..a29db2a4 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3660,6 +3660,38 @@ static void ggml_vec_swiglu_f32(const int n, float * y, const float * x) { } } +static void ggml_vec_swiglu_f32_glu(const int n, float * y, const float * x, const float * g) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * g[i]; + } +} + +inline static void ggml_vec_swiglu_f16_glu(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float xi = GGML_FP16_TO_FP32(x[i]); + float gi = GGML_FP16_TO_FP32(g[i]); + y[i] = GGML_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi); + } +} + static void ggml_vec_tanh_f32(const int n, float * y, const float * x) { int i = 0; #if defined(__AVX512F__) && defined(__AVX512DQ__) @@ -14435,6 +14467,148 @@ static void ggml_compute_forward_swiglu( } } +static void ggml_compute_forward_swiglu_f32_glu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f32_glu(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_f16_glu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f16_glu(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_glu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32_glu(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_swiglu_f16_glu(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + // ggml_compute_forward_swiglu_oai static void ggml_compute_forward_swiglu_oai_f32( @@ -21322,7 +21496,7 @@ static void ggml_compute_forward_glu( } break; case GGML_GLU_OP_SWIGLU: { - ggml_compute_forward_swiglu(params, dst); + ggml_compute_forward_swiglu_glu(params, dst); } break; case GGML_GLU_OP_SWIGLU_OAI: {