mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-23 22:54:10 +00:00
Move row sums to the write place
This commit is contained in:
@@ -976,20 +976,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Finally, sum up partial KQ rowsums.
|
||||
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
||||
{
|
||||
constexpr int offset_first = ntiles == 1 ? 16 : 2;
|
||||
constexpr int offset_last = ntiles == 1 ? 4 : 1;
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
|
||||
// so it's being done unconditionally for every thread.
|
||||
@@ -1036,6 +1022,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, sum up partial KQ rowsums.
|
||||
// The partial sums are spread across 8/4 threads each, does not need full reduce.
|
||||
{
|
||||
constexpr int offset_first = ntiles == 1 ? 16 : 2;
|
||||
constexpr int offset_last = ntiles == 1 ? 4 : 1;
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
#pragma unroll
|
||||
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
||||
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine VKQ accumulator values if np > 1.
|
||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||
|
||||
Reference in New Issue
Block a user