From 1ca612375e9053f458a68007647dcc0ac562135b Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sun, 17 Aug 2025 14:31:03 +0300
Subject: [PATCH] Fix GLM-4.5 attention (#700)

Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 2 +-
 ggml/src/ggml-cuda/fattn.cu           | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index e1e2ec6e..db144a02 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -96,7 +96,7 @@ static __global__ void flash_attn_ext_f16(
     const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
     const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
-    const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
+    [[maybe_unused]] const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
 
     const int stride_Q = nb01 / sizeof(float);
     const int stride_K = nb11 / sizeof(half);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index ffcaf219..243211f7 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -78,7 +78,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const float use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
@@ -88,12 +88,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 4) {
+    if (use_gqa_opt && gqa_ratio % 4 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 2) {
+    if (use_gqa_opt && gqa_ratio % 2 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
@@ -508,7 +508,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies;
-    const bool can_use_vector_kernel = Q->ne[0] % (2*WARP_SIZE) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (precision == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);