mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-21 13:44:10 +00:00
Apply offfset to KQ_max in CUDA flash attention (#1196)
* Apply offfset to KQ_max in CUDA flash attention * Forgot to add to fattn-common.h
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
|
||||
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
|
||||
|
||||
#define FATTN_KQ_MAX_OFFSET 1.386294361f
|
||||
|
||||
typedef void (* fattn_kernel_t)(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
|
||||
@@ -264,7 +264,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,7 +319,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
||||
const int KQ_index = 2*t + (l/2) % 2;
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -702,7 +702,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
|
||||
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -756,7 +756,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
||||
const int KQ_index = 2*t + (l/2) % 2;
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
|
||||
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l] + FATTN_KQ_MAX_OFFSET);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user