From 4a6a6f17ee9e72b527ca4d2bbb3b83de7c672fbc Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Thu, 4 Sep 2025 08:42:18 +0200 Subject: [PATCH] Alternative CUDA FA for SWA models (#754) * Bounds for flash attention * Add n_swa to FA parameters * Fix it * This seems very slightly better * Using vec kernel when we have SWA * Need also this * f32 vec kernel * This is slightly better --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 186 +++- ggml/src/ggml-cuda/fattn-vec-common.cuh | 1074 +++++++++++++++++++++++ ggml/src/ggml-cuda/fattn-vec-f16.cuh | 213 +++-- ggml/src/ggml-cuda/fattn-vec-f32.cuh | 207 +++-- ggml/src/ggml-cuda/fattn.cu | 31 +- ggml/src/ggml.c | 2 + src/llama.cpp | 24 +- 7 files changed, 1528 insertions(+), 209 deletions(-) create mode 100644 ggml/src/ggml-cuda/fattn-vec-common.cuh diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 16c8c24f..0c1aaf7d 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,6 +13,45 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; +typedef void (* fattn_kernel_mma_t)( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int2 * __restrict__ bounds, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const float softcap, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3); + template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { @@ -871,6 +910,7 @@ static __global__ void flash_attn_mma_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, const char * __restrict__ sinks, + const int2 * __restrict__ bounds, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, @@ -948,8 +988,13 @@ static __global__ void flash_attn_mma_ext_f16( const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_start_kernel = kb0_start * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + + if (bounds) { + kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter); + kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter); + } constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { @@ -987,8 +1032,12 @@ static __global__ void flash_attn_mma_ext_f16( const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; - const int kb0_start_kernel = kb0_start * kb_niter; - const int kb0_stop_kernel = kb0_stop * kb_niter; + int kb0_start_kernel = kb0_start * kb_niter; + int kb0_stop_kernel = kb0_stop * kb_niter; + if (bounds) { + kb0_start_kernel = max(kb0_start_kernel, bounds[jt].x / KQ_per_iter); + kb0_stop_kernel = min(kb0_stop_kernel, bounds[jt].y / KQ_per_iter); + } constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; @@ -1144,9 +1193,109 @@ static __global__ void flash_attn_mma_combine_results( dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; } +template +static __device__ __forceinline__ int warp_reduce_all(int x) { + if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; + } +} + +template +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) +static __global__ void flash_attn_mask_to_KV_min_max( + const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) { + const int ne31 = gridDim.x; + const int tid = threadIdx.x; + const int sequence = blockIdx.y; + const int jt = blockIdx.x; + + mask += sequence*s33 + jt*ncols1*s31; + + __shared__ int buf_iw[WARP_SIZE]; + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if constexpr (!is_swa) { + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE}; + } + return; + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE; + } + + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_min_sj = 0; + for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].x = KV_min_sj; + } +} + + template void launch_fattn_mma( - ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_mma_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * 
ncols2; @@ -1171,6 +1320,9 @@ void launch_fattn_mma( GGML_ASSERT(Q->ne[3] == 1); + int n_swa; + memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int)); + ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -1179,6 +1331,7 @@ void launch_fattn_mma( ggml_cuda_pool_alloc K_f16(pool); ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc KV_min_max(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); @@ -1225,11 +1378,29 @@ void launch_fattn_mma( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + if (mask && (Q->ne[1] >= 1024 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) { + const int s31 = mask->nb[1] / sizeof(half2); + const int s33 = mask->nb[3] / sizeof(half2); + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + KV_min_max.alloc(ne_KV_max); + if (n_swa > 0) { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } else { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } + CUDA_CHECK(cudaGetLastError()); + } + const dim3 block_dim(warp_size, nwarps, 1); dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; + const int max_blocks = Q->ne[1] > 1 ? 2*nsm : nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); @@ -1313,6 +1484,7 @@ void launch_fattn_mma( V_data, mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *)sinks->data) : nullptr, + KV_min_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], @@ -1372,7 +1544,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - fattn_kernel_t fattn_kernel; + fattn_kernel_mma_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; fattn_kernel = flash_attn_mma_ext_f16; diff --git a/ggml/src/ggml-cuda/fattn-vec-common.cuh b/ggml/src/ggml-cuda/fattn-vec-common.cuh new file mode 100644 index 00000000..c6dda7ea --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-vec-common.cuh @@ -0,0 +1,1074 @@ +// +// Copyright (C) 2023-2024 The ggml authors +// Copyright (C) 2024 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "common.cuh" +#include "convert.cuh" +#include "vecdotq.cuh" + +#include + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. 
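For orientation, the idea behind the new bounds/KV_min_max path is: for every tile of ncols1 query rows, the f16 mask is scanned in FATTN_KQ_STRIDE-wide blocks of KV positions, and the first and last block that is not entirely -inf define a [KV_min, KV_max) window; the attention kernels then clamp their K/V iteration to that window, which is what lets SWA models skip most of a long context. A minimal host-side sketch of the bounds such a scan would yield for a plain causal sliding-window mask, assuming hypothetical names (kv_bounds_for_tile, n_past) and window semantics where row i attends to positions (i - n_swa, i]; the real kernel derives the same information purely from the mask contents:

    // Host-side reference sketch, not part of the patch: what the mask scan amounts to for a
    // causal sliding-window mask. kv_bounds_for_tile and n_past are illustrative names only.
    static int2 kv_bounds_for_tile(int q0, int ncols1, int n_past, int n_swa, int n_kv) {
        const int q_last = n_past + q0 + ncols1 - 1;            // newest query row covered by this tile
        int kv_max = q_last + 1 < n_kv ? q_last + 1 : n_kv;     // causal upper bound
        int kv_min = 0;
        if (n_swa > 0) {
            kv_min = n_past + q0 - n_swa + 1;                   // oldest row of the tile still sees this far back
            if (kv_min < 0) kv_min = 0;
        }
        // round outward to whole FATTN_KQ_STRIDE blocks, matching what the kernels consume
        kv_min = (kv_min / FATTN_KQ_STRIDE) * FATTN_KQ_STRIDE;
        kv_max = ((kv_max + FATTN_KQ_STRIDE - 1) / FATTN_KQ_STRIDE) * FATTN_KQ_STRIDE;
        return make_int2(kv_min, kv_max);
    }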
+ +typedef void (* fattn_kernel_t)( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int2 * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33); + +typedef half (*vec_dot_KQ_f16_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +typedef float (*vec_dot_KQ_f32_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + 
const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid4d8 + m4s8scaled); + } + } + + return sum; +} + +static __device__ __forceinline__ int get_one_int_from_table_16(const int & q4) { + const uint8_t * q0_8 = (const uint8_t *) &q4; + const char4 val0_8 = make_char4(kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]); + return *((const int *) &val0_8); +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_iq4_nl( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_iq4_nl * K_iq4_nl = (const block_iq4_nl *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_NL; + const int shift = k_KQ & (QI8_1/2); + + const int v = get_one_int_from_table_16((get_int_b2(K_iq4_nl[ib].qs, iqs4) >> shift) & 0x0F0F0F0F); + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + sum += (T) (((half)sumi) * K_iq4_nl[ib].d * Q_ds[k_KQ_0/WARP_SIZE].x); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + sum += (T) ((float)sumi * __half2float(K_iq4_nl[ib].d) * Q_ds[k_KQ_0/WARP_SIZE].x); + } + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + 
threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid5d8 + m5s8scaled); + } + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q6_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q6_0 * K_q6_0 = (const block_q6_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI6_0; // 0...3 + const int shift = k_KQ & (QI8_1/2); + + const int vh = (get_int_b2(K_q6_0[ib].qh, iqs4%2) >> (4*(iqs4/2) + shift/2)) & 0x03030303; + const int vl = (get_int_b2(K_q6_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int v = vl | (vh << 4); + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q6_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(4.0f)) /* *32/QI8_1 == 4 */; + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q6_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (32/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + const int v = get_int_b2(K_q8_0[ib].qs, iqs); + + T Q_d; + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + } else { + const float2 * Q_ds = (const float2 *) Q_ds_v; + Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + } + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( + const char * __restrict__ K_c, const 
void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const half2 * K_h2 = (const half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_h2 = (const half2 *) Q_v; + + half2 sum2 = make_half2(0.0f, 0.0f); + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + return __low2half(sum2) + __high2half(sum2); + } +#endif // FP16_AVAILABLE + + const float2 * Q_f2 = (const float2 *) Q_v; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < Dk/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + } + + return sum; +} + +template +static __device__ __forceinline__ void quantize_q8_1_to_shared( + const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { + + float vals[sizeof(int)] = {0.0f}; +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + vals[l] = scale * x[4*threadIdx.x + l]; + } + + float amax = fabsf(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < sizeof(int); ++l) { + amax = fmaxf(amax, fabsf(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + q8[l] = roundf(vals[l] / d); + } + } + + yq32[threadIdx.x] = q32; + if (threadIdx.x % QI8_1 == 0) { + if (std::is_same::value) { + ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); + } else { + ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); + } + } +} + +typedef half (*dequantize_1_f16_t)(const void *, const int64_t); +typedef float (*dequantize_1_f32_t)(const void *, const int64_t); + +template +static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i / QK4_0; + const int iqs = i % (QK4_0/2); + const int shift = (i % QK4_0) / (QK4_0/2); + + const T d = x[ib].d; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_iq4_nl(const void * __restrict__ vx, const int64_t i) { + const block_iq4_nl * x = (const block_iq4_nl *) vx; + + const int64_t ib = i / QK4_NL; + const int iqs = i % (QK4_NL/2); + const int shift = (i % QK4_NL) / (QK4_NL/2); + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same::value) { + return x[ib].d * ((half) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]); + } else { + return (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]); + } +#endif + T result = (float)x[ib].d * ((float) kvalues_iq4nl[(x[ib].qs[iqs] >> 4*(shift)) & 0xf]); + return result; +} + +template +static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { + const block_q4_1 * x = (const 
block_q4_1 *) vx; + + const int64_t ib = i / QK4_1; + const int iqs = i % (QK4_1/2); + const int shift = (i % QK4_1) / (QK4_1/2); + + const half2 dm = x[ib].dm; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i / QK5_0; + const int idq = i % QK5_0; + const int iqs = i % (QK5_0/2); + const int shift = (i % QK5_0) / (QK5_0/2); + + const T d = x[ib].d; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_b2(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh) - 16; + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i / QK5_1; + const int idq = i % QK5_1; + const int iqs = i % (QK5_1/2); + const int shift = (i % QK5_1) / (QK5_1/2); + + const half2 dm = x[ib].dm; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_b4(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh); + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q6_0(const void * __restrict__ vx, const int64_t i) { + const block_q6_0 * x = (const block_q6_0 *) vx; + + const int64_t ib = i / QK6_0; + const int idq = i % QK6_0; + const int iqs = i % (QK6_0/2); + const int shift = idq / (QK6_0/2); + //const int shift = (i % QK6_0) / (QK6_0/2); + + const T d = x[ib].d; + const int ql = x[ib].qs[iqs] >> 4*shift; + const int qh = x[ib].qh[idq%(QK6_0/4)] >> (4*((idq/(QK6_0/4))%2) + 2*shift); + const int q = ((ql & 0x0f) | ((qh & 0x03) << 4)) - 32; + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i / QK8_0; + const int iqs = i % QK8_0; + + const T d = x[ib].d; + const int q = x[ib].qs[iqs]; + +#ifdef FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { + const half * x = (const half *) vx; + + return x[i]; +} + +template +constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? 
vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; +} + +template +constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_IQ4_NL ? vec_dot_fattn_vec_KQ_iq4_nl : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q6_0 ? vec_dot_fattn_vec_KQ_q6_0 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; +} + +constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { + return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : + type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : + type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : + type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : + type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0 : + type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : + type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl : + type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + nullptr; +} + +constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { + return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : + type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : + type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : + type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : + type_V == GGML_TYPE_Q6_0 ? dequantize_1_q6_0 : + type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : + type_V == GGML_TYPE_IQ4_NL ? dequantize_1_iq4_nl : + type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + nullptr; +} + +template // D == head size +#if !defined(GGML_USE_HIP) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIP) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + const int ne01 = gridDim.x; + const int ne02 = gridDim.y; + + const int col = blockIdx.x; + const int head = blockIdx.y; + const int sequence = blockIdx.z; + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[tid] = VKQ_numerator / VKQ_denominator; +} + +template +static __device__ __forceinline__ int warp_reduce_all(int x) { + if constexpr (width == WARP_SIZE) { 
//ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; + } +} + +template +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) +static __global__ void flash_attn_mask_to_KV_min_max( + const half2 * __restrict__ mask, int2 * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) { + const int ne31 = gridDim.x; + const int tid = threadIdx.x; + const int sequence = blockIdx.y; + const int jt = blockIdx.x; + + mask += sequence*s33 + jt*ncols1*s31; + + __shared__ int buf_iw[WARP_SIZE]; + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if constexpr (!is_swa) { + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt] = {0, KV_max_sj + FATTN_KQ_STRIDE}; + } + return; + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].y = KV_max_sj + FATTN_KQ_STRIDE; + } + + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_min_sj = 0; + for (; KV_min_sj < KV_max_sj; KV_min_sj += FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_min_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + if (threadIdx.x == 0) { + KV_min_max[sequence*ne31 + jt].x = KV_min_sj; + } +} + +static void on_no_fattn_vec_case(const int Dk, const int Dv) { + if (Dk == 64 && Dv == 64) { + fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); + fprintf(stderr, "By default only f16 KV cache is supported.\n"); + fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); + GGML_ABORT("fatal error"); + } else if (Dk == 128 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); + fprintf(stderr, "Supported combinations:\n"); + fprintf(stderr, " - K == q4_0, V == q4_0, 4.5 BPV\n"); + fprintf(stderr, " - K == iq4_nl, V == iq4_nl, 4.5 BPV\n"); + fprintf(stderr, " - K == q6_0, V == q5_0, 6.0 BPV\n"); + fprintf(stderr, " - K == q8_0, V == iq4_nl, 6.5 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q6_0, 7.5 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q8_0, 8.5 BPV\n"); + fprintf(stderr, " - K == f16, V == f16, 16.0 BPV\n"); + fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q8_0, and f16.\n"); + GGML_ABORT("fatal error"); + } + else if (Dk == 192 && Dv == 128) { + fprintf(stderr, "Unsupported KV type combination for head_sizes 192 / 128\n"); + // TODO: add what is supported + } + else if (Dk == 576 && Dv == 512) { + 
fprintf(stderr, "Unsupported KV type combination for head_sizes 576 / 512\n"); + // TODO: add what is supported + } else { + fprintf(stderr, "Unsupported KV type combination for head_sizes %d, %d.\n", Dk, Dv); + fprintf(stderr, "Only f16 is supported.\n"); + GGML_ABORT("fatal error"); + } +} + +template +void launch_fattn( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const int warp_size = WARP_SIZE) { + + const bool is_mla = DV == 512; // TODO better parameterization + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(V || is_mla); + + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + int n_swa; + memcpy(&n_swa, (const int *) KQV->op_params + 4, sizeof(int)); + + auto & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); + ggml_cuda_pool_alloc KV_min_max(pool); + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? 
V->nb[3] : nb13; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, 1, ggml_nelements(K), main_stream); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } else { + GGML_ABORT("Non-contiguous K-cache is not supported bu the vector FA kernels"); + //GGML_ASSERT(K->nb[0] == ts); + //to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type); + //const int64_t s01 = nb11 / ts; + //const int64_t s02 = nb12 / ts; + //const int64_t s03 = nb13 / ts; + //to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + //nb11 = K->ne[0] * sizeof(half); + //nb12 = K->ne[1] * nb11; + //nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; + } + + if (V && need_f16_V && V->type != GGML_TYPE_F16) { + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ABORT("Non-contiguous V-cache is not supported bu the vector FA kernels"); + //GGML_ASSERT(V->nb[0] == ts); + //to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + //const int64_t s01 = nb21 / ts; + //const int64_t s02 = nb22 / ts; + //const int64_t s03 = nb23 / ts; + //to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + //nb21 = V->ne[0] * sizeof(half); + //nb22 = V->ne[1] * nb21; + //nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; + } + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. + // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or + // multiple sequences of possibly different lengths. + if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa))) { + const int s31 = mask->nb[1] / sizeof(half2); + const int s33 = mask->nb[3] / sizeof(half2); + + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); + + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + + KV_min_max.alloc(ne_KV_max); + if (n_swa > 0) { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } else { + flash_attn_mask_to_KV_min_max<<>> + ((const half2 *) mask->data, KV_min_max.ptr, iter_k, s31, s33); + } + CUDA_CHECK(cudaGetLastError()); + } + + int parallel_blocks = 1; + + const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. 
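Note on the gating condition above (Q->ne[1] >= 1024 || Q->ne[3] > 1 || (n_swa > 0 && K->ne[1] >= FATTN_KQ_STRIDE + n_swa)): the scan is only launched when at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square can plausibly be skipped, or when several sequences of possibly different lengths are batched together. A standalone restatement with hypothetical names:

    // Restatement of the heuristic used in this launcher; mask_scan_worthwhile, n_q, n_seq
    // and n_kv are illustrative names, not identifiers from the patch.
    static bool mask_scan_worthwhile(int n_q, int n_seq, int n_swa, int n_kv) {
        if (n_q >= 1024) return true;                         // long prompts: blocks past the causal frontier can be skipped
        if (n_seq > 1)   return true;                         // several sequences, possibly of different lengths
        return n_swa > 0 && n_kv >= FATTN_KQ_STRIDE + n_swa;  // SWA: blocks older than the window can be skipped
    }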
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + + dim3 blocks_num; + + GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); + const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: + parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = Q->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *) sinks->data) : nullptr, + KV_min_max.ptr, + parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, + nb21, nb22, nb23, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? 
mask->nb[3] : 0 + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks > 1) { + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 95dd0e96..2eb92b19 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -6,9 +6,9 @@ // #include "common.cuh" -#include "fattn-common.cuh" +#include "fattn-vec-common.cuh" -template // D == head size +template // Dk, Dv == K, V head sizes #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(Dk, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -18,44 +18,37 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ V, const char * __restrict__ mask, const char * __restrict__ sinks, + const int2 * __restrict__ KV_min_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, const float max_bias, const float m0, const float m1, - const float softcap, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#ifdef FP16_AVAILABLE + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#if defined(FP16_AVAILABLE) + // Skip unused kernel variants for faster compilation: if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { - if (use_softcap && !(Dk == 128 || Dk == 256)) { + if (use_logit_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } + } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. @@ -63,18 +56,19 @@ static __global__ void flash_attn_vec_ext_f16( constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
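To make the new indexing concrete: blockIdx.z now enumerates (sequence, head) pairs, and GQA maps gqa_ratio consecutive Q heads onto one K/V head, which is what the pointer offsets just below implement. A small sketch with hypothetical names and example values:

    // Hypothetical helper restating the decomposition used here.
    static __device__ __forceinline__ int3 decode_block_z(int block_z, int n_head_q, int gqa_ratio) {
        const int sequence = block_z / n_head_q;            // which sequence of the batch
        const int head     = block_z - sequence*n_head_q;   // which Q head within that sequence
        const int kv_head  = head / gqa_ratio;              // shared K/V head under GQA
        return make_int3(sequence, head, kv_head);
    }
    // Example: ne02 = 32 Q heads, ne12 = 8 KV heads (gqa_ratio = 4), blockIdx.z = 45 gives
    // sequence 1, Q head 13, K/V head 3, so Q is offset by nb03*1 + nb02*13 + nb01*ic0 and
    // K/V by nb13*1 + nb12*3.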
- Q += nb02* blockIdx.y + nb01*ic0; - K += nb12*(blockIdx.y / gqa_ratio); - V += nb22*(blockIdx.y / gqa_ratio); + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const float * sinksf = (const float *) (sinks); - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); @@ -87,11 +81,12 @@ static __global__ void flash_attn_vec_ext_f16( half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; + half kqsum[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { kqmax[j] = -HALF_MAX_HALF; + kqsum[j] = 0.0f; } - half kqsum[ncols] = {0.0f}; __shared__ half kqmax_shared[ncols][WARP_SIZE]; __shared__ half kqsum_shared[ncols][WARP_SIZE]; @@ -102,8 +97,21 @@ static __global__ void flash_attn_vec_ext_f16( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ half maskh_shared[ncols*Dk]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*Dk + tid] = 0.0f; + } + __syncthreads(); + half2 VKQ[ncols] = {{0.0f, 0.0f}}; + + const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11; + const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0; + if (first_y + blockIdx.y*Dk < k_VKQ_max) { + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: half2 Q_h2[ncols][Dk/(2*WARP_SIZE)]; int Q_i32[ncols][Dk/(sizeof(int)*QK8_1) == 0 ? 1 : Dk/(sizeof(int)*QK8_1)]; @@ -179,12 +187,19 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { KQ[j*Dk + tid] = -HALF_MAX_HALF; } + __syncthreads(); - half2 VKQ[ncols] = {{0.0f, 0.0f}}; - - const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { - // Calculate KQ tile and keep track of new maximum KQ values: + const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11; + const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0; + //if (KV_min_max && threadIdx.x == 0 && head == 0) { + // printf("gridDims = %u, %u, %u, ncols = %d, head = %d, blockIdx.x = %d, blockIdx.y = %d, bounds = %d, %d, ne11 = %d, nb11 = %d, blockIdx.y*Dk = %d\n", gridDim.x, gridDim.y, gridDim.z, ncols, head, blockIdx.x, blockIdx.y, KV_min_max[sequence*gridDim.x + blockIdx.x].x, KV_min_max[sequence*gridDim.x + blockIdx.x].y, ne11, nb11, blockIdx.y*Dk); + //} + K += (first_y + blockIdx.y*Dk) * nb11; + V += (first_y + blockIdx.y*Dv) * nb21; + maskh += (first_y + blockIdx.y*Dk); + for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, + // Increment pointers after each loop: + K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, // see https://github.com/ggerganov/llama.cpp/pull/7061 . 
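The restructured loop above walks the per-tile KV window in chunks of Dk rows with the gridDim.y blocks interleaved, and the K/V/mask pointers are bumped in the loop header so the body addresses data relative to the current chunk. A worked example with hypothetical values:

    // Sketch of the iteration pattern only; the patch additionally pre-offsets K, V and maskh
    // and advances them in the loop header.
    for (int k0 = first_y + blockIdx.y*Dk; k0 < k_VKQ_max; k0 += gridDim.y*Dk) {
        // process KV rows [k0, k0 + Dk)
    }
    // Example: Dk = 128, gridDim.y = 4 and a mask-derived window [first_y, k_VKQ_max) = [512, 1024):
    // block y=0 handles rows [512,640), y=1 [640,768), y=2 [768,896), y=3 [896,1024).
    // Whenever gridDim.y > 1 the per-block partial results are merged afterwards by
    // flash_attn_combine_results.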
@@ -196,6 +211,14 @@ static __global__ void flash_attn_vec_ext_f16( kqmax_new_arr[j] = kqmax[j]; } + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*Dk + tid] = slopeh*maskh[j*ne11 + tid]; + } + __syncthreads(); + } + #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; @@ -206,12 +229,14 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); - if (use_softcap) { - sum = softcap*tanhf(sum); + half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum((float)sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); } - sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + + sum += maskh_shared[j*Dk + i_KQ]; if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -229,7 +254,6 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } @@ -261,8 +285,8 @@ static __global__ void flash_attn_vec_ext_f16( } half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(Dk/2) + k0/2]; @@ -271,9 +295,10 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); } + } - if (sinksf) { - const half sink = __float2half(sinksf[blockIdx.y]); + if (sinksf && blockIdx.y == 0) { + const half sink = __float2half(sinksf[head]); #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -286,7 +311,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.y]; + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; kqmax_new_j = warp_reduce_max(kqmax_new_j); const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); @@ -307,7 +332,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); + kqsum[j] = warp_reduce_sum((float)kqsum[j]); if (threadIdx.x == 0) { kqsum_shared[j][threadIdx.y] = kqsum[j]; } @@ -322,32 +347,31 @@ static __global__ void flash_attn_vec_ext_f16( } kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); - } + if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = 
make_float2(kqmax[tid], kqsum[tid]); } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = Dk/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = Dk != 128 && Dk != 192; constexpr bool need_f16_V = Dv != 128 && Dv != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V); } template @@ -363,59 +387,54 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); - float softcap; - memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float)); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr int cols_per_block = 4; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } return; } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 1; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr int cols_per_block = 8; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh 
b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index a97b3737..db1dc132 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -6,55 +6,54 @@ // #include "common.cuh" -#include "fattn-common.cuh" +#include "fattn-vec-common.cuh" -template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// Currently llvm with the amdgcn target does not support unrolling loops +// that contain a break that cannot be resolved at compile time. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template // Dk, Dv == K-, V-head size +#ifndef GGML_USE_HIP __launch_bounds__(Dk, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // GGML_USE_HIP static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, const char * __restrict__ sinks, + const int2 * __restrict__ KV_min_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, const float max_bias, const float m0, const float m1, - const float softcap, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { + // Skip unused kernel variants for faster compilation: if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) { - if (use_softcap && !(Dk == 128 || Dk == 256)) { + if (use_logit_softcap && !(Dk == 128 || Dk == 256)) { NO_DEVICE_CODE; return; } + } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. @@ -62,17 +61,19 @@ static __global__ void flash_attn_vec_ext_f32( constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); - const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- Q += nb02* blockIdx.y + nb01*ic0; - K += nb12*(blockIdx.y / gqa_ratio); - V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const float * sinksf = (const float *) (sinks); - const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64."); static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64."); @@ -87,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32( } float kqmax[ncols]; + float kqsum[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { kqmax[j] = -FLT_MAX/2.0f; + kqsum[j] = 0.0f; } - float kqsum[ncols] = {0.0f}; __shared__ float kqmax_shared[ncols][WARP_SIZE]; __shared__ float kqsum_shared[ncols][WARP_SIZE]; @@ -102,6 +104,13 @@ static __global__ void flash_attn_vec_ext_f32( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ float maskf_shared[ncols*Dk]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*Dk + tid] = 0.0f; + } + __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: @@ -176,10 +185,26 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; - const int k_start = parallel_blocks == 1 ? 0 : ip*Dk; - for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) { + const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11; + const int first_y = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0; + + K += (first_y + blockIdx.y*Dk) * nb11; + V += (first_y + blockIdx.y*Dv) * nb21; + maskh += (first_y + blockIdx.y*Dk); + for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk, + // Increment pointers after each loop: + K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) { + // Calculate KQ tile and keep track of new maximum KQ values: + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*Dk + tid] = slope*__half2float(maskh[j*ne11 + tid]); + } + __syncthreads(); + } + float kqmax_new_arr[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -196,12 +221,14 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); - if (use_softcap) { - sum = softcap*tanhf(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); } - sum += mask ? 
slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + sum += maskf_shared[j*Dk + i_KQ]; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -215,7 +242,6 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float kqmax_new_j = kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } @@ -246,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32( break; } - const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); + const float V_ki = dequantize_1_v(V + k*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*Dk + k]; @@ -256,8 +282,8 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); } - if (sinksf) { - const float sink = sinksf[blockIdx.y]; + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -270,7 +296,7 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_shared[j][threadIdx.y]; + float kqmax_new_j = kqmax_shared[j][threadIdx.x]; kqmax_new_j = warp_reduce_max(kqmax_new_j); const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); @@ -309,26 +335,28 @@ static __global__ void flash_attn_vec_ext_f32( kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); float dst_val = VKQ[j_VKQ]; - if (parallel_blocks == 1) { + if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val; } - if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); - } + if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } } +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = Dk/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - constexpr bool need_f16_K = Dk != 128 && Dk != 192; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + constexpr bool need_f16_K = Dk != 128; constexpr bool need_f16_V = Dv != 128 && Dv != 64; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V); } template @@ -341,59 +369,54 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); - float softcap; - memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float)); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { - constexpr int cols_per_block = 1; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool 
use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr int cols_per_block = 4; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } return; } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 1; - if (softcap == 0.0f) { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr int cols_per_block = 8; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } else { - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); } } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index eb4684d2..90a369d2 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -16,6 +16,8 @@ #include +#define FATTN_KQ_STRIDE 256 + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; @@ -26,6 +28,25 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[3]; + const int32_t n_swa = KQV->op_params[4]; + + ggml_tensor local_dst, Kl, Vl, Ml; + if (n_swa > 0) { + int ntokens = std::max(FATTN_KQ_STRIDE, int(Q->ne[1])); + int nton = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE); + int first = K->ne[1] - nton; + if (first > 0) { + local_dst = *dst; + Kl = *K; Kl.ne[1] = nton; Kl.data = (char *)K->data + K->nb[1]*first; + Vl = *V; Vl.ne[1] = nton; Vl.data = (char *)V->data + V->nb[1]*first; + Ml = *mask; Ml.ne[0] = nton; Ml.data = (char *)mask->data + mask->nb[0]*first; + local_dst.src[1] = &Kl; + local_dst.src[2] = &Vl; + local_dst.src[3] = &Ml; + local_dst.op_params[4] = 0; + dst = &local_dst; + } + } // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { @@ -68,14 +89,14 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst 
// On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache. //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { - if (precision == GGML_PREC_DEFAULT) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { + //if (precision == GGML_PREC_DEFAULT) { + // ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + //} else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } + //} return; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c6912301..c4227b35 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -9008,6 +9008,8 @@ struct ggml_tensor * ggml_flash_attn_ext( float params[] = { scale, max_bias, softcap }; ggml_set_op_params(result, params, sizeof(params)); + ggml_set_op_params_i32(result, 4, 0); + result->op = GGML_OP_FLASH_ATTN_EXT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = q; diff --git a/src/llama.cpp b/src/llama.cpp index d793d006..49edbb8d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7985,7 +7985,7 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, int n_swa = 0) { const llama_model & model = lctx.model; const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -8033,6 +8033,9 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); ggml_flash_attn_ext_add_sinks(cur, sinks); + if (n_swa > 0) { + ((int32_t *)cur->op_params)[4] = n_swa; + } // Some models produced NaNs/gibberish when FA is computed with f16 precision on CUDA // For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when uding fp16 for TG. @@ -8190,7 +8193,7 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il, - ggml_tensor * sinks = nullptr) { + ggml_tensor * sinks = nullptr, int n_swa = 0) { const llama_hparams & hparams = lctx.model.hparams; const llama_cparams & cparams = lctx.cparams; @@ -8205,7 +8208,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, - q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il, sinks, n_swa); cb(cur, "kqv_out", il); return cur; @@ -8766,7 +8769,8 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + Kcur, Vcur, Qcur, this_KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il, nullptr, + this_KQ_mask == KQ_mask_swa ? 
hparams.n_swa : 0); } if (il == n_layer - 1) { @@ -12198,7 +12202,8 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il, nullptr, + KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0); } cur = llm_build_norm(ctx0, cur, hparams, @@ -12335,7 +12340,8 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il, nullptr, + KQ_mask_l == KQ_mask_swa ? hparams.n_swa : 0); } cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, cb, il); @@ -14400,7 +14406,8 @@ struct llm_build_context { } cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, - KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il); + KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il, nullptr, + is_sliding ? hparams.n_swa : 0); } if (il == n_layer - 1) { @@ -15490,7 +15497,8 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks); + Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, kq_scale, cb, il, model.layers[il].attn_sinks, + is_sliding ? hparams.n_swa : 0); cb(cur, "attn_out", il); }
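
A note on the host-side fast path (not part of the diff above): the n_swa > 0 branch added to ggml_cuda_flash_attn_ext trims the K, V and mask views so the kernels only scan the tail of the KV cache that a sliding-window mask can leave unmasked. Below is a minimal standalone C++ sketch of that window arithmetic; the function name kv_window_first and the main() driver are illustrative only, while FATTN_KQ_STRIDE = 256, the std::max over Q->ne[1] and the round-up to a multiple of FATTN_KQ_STRIDE are taken from the fattn.cu hunk.

#include <algorithm>
#include <cstdio>

constexpr int FATTN_KQ_STRIDE = 256;  // same stride the patch defines in fattn.cu

// First K/V row that still has to be scanned for a sliding window of width n_swa,
// mirroring the ntokens/nton/first computation in ggml_cuda_flash_attn_ext.
static int kv_window_first(int n_kv, int n_swa, int n_q) {
    const int ntokens = std::max(FATTN_KQ_STRIDE, n_q);
    // round (query tokens + window width) up to a multiple of FATTN_KQ_STRIDE
    const int nton  = FATTN_KQ_STRIDE*((ntokens + n_swa + FATTN_KQ_STRIDE - 1)/FATTN_KQ_STRIDE);
    const int first = n_kv - nton;
    return first > 0 ? first : 0;    // cache shorter than the window -> scan everything
}

int main() {
    // Example: 8192 cached tokens, n_swa = 4096, single-token decode (n_q = 1):
    // ntokens = 256, nton = 256*ceil((256 + 4096)/256) = 4352, first = 8192 - 4352 = 3840.
    printf("first = %d\n", kv_window_first(8192, 4096, 1));
    return 0;
}

In the patch itself this offset is applied by pointing the trimmed K and V views at K->data + K->nb[1]*first and V->data + V->nb[1]*first (the mask analogously via mask->nb[0]*first), shrinking their ne[1] (ne[0] for the mask) to nton, and resetting op_params[4] to 0 so the trimmed tensors are then handled as a regular, non-SWA case.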