From ba0e88a5e3b16a950c1ade267124e1c4e539ae0f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 28 Dec 2025 08:57:24 +0000 Subject: [PATCH] Minor --- ggml/src/ggml-cuda/fattn-new-mma.cu | 4 ++-- src/llama-build-context.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 63a9ca57..8a5bd1b1 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1102,8 +1102,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); - const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col); - //const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + //const int jc = cols_per_warp == 8 ? tile_C_VKQ::get_j(col) : tile_C_VKQ_16::get_i(2*col); + const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 3c6e552f..937540a2 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -1407,7 +1407,7 @@ static ggml_tensor * llm_build_kqv( //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); if (use_f32_precision || model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || - model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE) { + model.arch == LLM_ARCH_COHERE2 || model.arch == LLM_ARCH_GLM4 || model.arch == LLM_ARCH_GLM4_MOE || model.arch == LLM_ARCH_MIMO2) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32);