From 02ae22388fcbc666a85c22754ee0438892a59b74 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Thu, 29 Jan 2026 07:27:53 +0200
Subject: [PATCH] Apply offset to KQ_max in CUDA flash attention (#1196)

* Apply offset to KQ_max in CUDA flash attention

* Forgot to add to fattn-common.h
---
 ggml/src/ggml-cuda/fattn-common.cuh  | 2 ++
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 4 ++--
 ggml/src/ggml-cuda/fattn-new-mma.cu  | 4 ++--
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index b4ff8913..fbff9e7f 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -17,6 +17,8 @@
 #define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+#define FATTN_KQ_MAX_OFFSET 1.386294361f
+
 typedef void (* fattn_kernel_t)(
         const char * __restrict__ Q,
         const char * __restrict__ K,
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 27328319..050186ce 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -264,7 +264,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
             }
         }
 
@@ -319,7 +319,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
                 const int KQ_index = 2*t + (l/2) % 2;
-                KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
             }
         }
     }
diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index 52a14639..89ee392e 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -702,7 +702,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
-                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
             }
         }
 
@@ -756,7 +756,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
                 const int KQ_index = 2*t + (l/2) % 2;
-                KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
             }
         }
     }
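
Note on the constant: FATTN_KQ_MAX_OFFSET = 1.386294361f is ln(4). What follows is a minimal, standalone host-side sketch, not part of the patch and with names chosen purely for illustration, of the assumed effect of biasing the running KQ maximum this way: every softmax exponential, and therefore their running sum, shrinks by a factor of exp(-ln 4) = 1/4, while the normalized attention weights are unchanged because numerator and denominator shrink by the same factor. That the intent is to gain headroom in the half-precision accumulation is an assumption, not something stated in the patch.

// softmax_offset_sketch.cpp -- illustrative only; build with: g++ -O2 softmax_offset_sketch.cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Same value the patch adds to the KQ_max candidates: ln(4).
static const float KQ_MAX_OFFSET = 1.386294361f;

// Numerically stable softmax with the offset folded into the max update,
// mirroring fmaxf(KQ_max_new, KQ_C + FATTN_KQ_MAX_OFFSET) in the patch.
static std::vector<float> softmax_with_offset(const std::vector<float> & kq, float offset) {
    float kq_max = -INFINITY;
    for (float v : kq) {
        kq_max = std::fmaxf(kq_max, v + offset);   // biased running maximum
    }
    float sum = 0.0f;
    std::vector<float> p(kq.size());
    for (size_t i = 0; i < kq.size(); ++i) {
        p[i] = std::expf(kq[i] - kq_max);          // each term is exp(-offset) times smaller
        sum  += p[i];
    }
    for (float & v : p) {
        v /= sum;                                  // normalization cancels the offset exactly
    }
    return p;
}

int main() {
    const std::vector<float> kq = {1.5f, 3.0f, 0.25f, 2.0f};
    const std::vector<float> p0 = softmax_with_offset(kq, 0.0f);
    const std::vector<float> p1 = softmax_with_offset(kq, KQ_MAX_OFFSET);
    for (size_t i = 0; i < kq.size(); ++i) {
        // The two columns agree: the offset does not change the attention weights,
        // only the scale of the intermediate exponentials and their sum.
        std::printf("%zu: no offset %.6f  with offset %.6f\n", i, p0[i], p1[i]);
    }
    return 0;
}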