From 38533d6bd445932a9f4a2c75c0e429e45d5eb26d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 12 Aug 2025 13:39:37 +0300 Subject: [PATCH] Move row sums to the write place --- ggml/src/ggml-cuda/fattn-new-mma.cu | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index cf87184a..27a411ea 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -976,20 +976,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } - // Finally, sum up partial KQ rowsums. - // The partial sums are spread across 8/4 threads each, does not need full reduce. - { - constexpr int offset_first = ntiles == 1 ? 16 : 2; - constexpr int offset_last = ntiles == 1 ? 4 : 1; -#pragma unroll - for (int col = 0; col < cols_per_thread; ++col) { -#pragma unroll - for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); - } - } - } - // If attention sinks are used, potentially re-scale if KQ_max is small. // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum // so it's being done unconditionally for every thread. @@ -1036,6 +1022,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + // Finally, sum up partial KQ rowsums. + // The partial sums are spread across 8/4 threads each, does not need full reduce. + { + constexpr int offset_first = ntiles == 1 ? 16 : 2; + constexpr int offset_last = ntiles == 1 ? 4 : 1; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1.