From bf0b5088e02394590e03f6162bf92d39c43b660b Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 3 Sep 2025 14:52:32 +0300
Subject: [PATCH] f32 vec kernel

---
 ggml/src/ggml-cuda/fattn-vec-f32.cuh | 207 +++++++++++++++------------
 ggml/src/ggml-cuda/fattn.cu          |   4 +-
 2 files changed, 117 insertions(+), 94 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index a97b3737..db1dc132 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -6,55 +6,54 @@
 //

 #include "common.cuh"
-#include "fattn-common.cuh"
+#include "fattn-vec-common.cuh"

-template <int Dk, int Dv, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <int Dk, int Dv, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // Dk, Dv == K-, V-head size
+#ifndef GGML_USE_HIP
 __launch_bounds__(Dk, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int2 * __restrict__ KV_min_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
         const float max_bias,
         const float m0,
         const float m1,
-        const float softcap,
         const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const float logit_softcap,
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+        const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+        const int32_t nb11, const int32_t nb12, const int64_t nb13,
+        const int32_t nb21, const int32_t nb22, const int64_t nb23,
+        const int32_t ne31, const int32_t ne32, const int32_t ne33,
+        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+    // Skip unused kernel variants for faster compilation:
     if constexpr (Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512)) {
-    if (use_softcap && !(Dk == 128 || Dk == 256)) {
+    if (use_logit_softcap && !(Dk == 128 || Dk == 256)) {
         NO_DEVICE_CODE;
         return;
     }
+    }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

@@ -62,17 +61,19 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);

-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
-    const half * maskh = (const half *) mask + ne11*ic0;
+    Q += nb03*sequence + nb02* head + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
+
+    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);

     const float * sinksf = (const float *) (sinks);

-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);

     static_assert(Dk % (2*WARP_SIZE) == 0, "Dk not divisible by 2*WARP_SIZE == 64.");
     static_assert(Dv % (2*WARP_SIZE) == 0, "Dv not divisible by 2*WARP_SIZE == 64.");
@@ -87,11 +88,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }

     float kqmax[ncols];
+    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
+        kqsum[j] = 0.0f;
     }
-    float kqsum[ncols] = {0.0f};

     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -102,6 +104,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*Dk];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*Dk + tid] = 0.0f;
+    }
+
     __syncthreads();

     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -176,10 +185,26 @@ static __global__ void flash_attn_vec_ext_f32(

     float VKQ[ncols] = {0.0f};

-    const int k_start = parallel_blocks == 1 ? 0 : ip*Dk;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*Dk) {
+    const int k_VKQ_max = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].y : ne11;
+    const int first_y   = KV_min_max ? KV_min_max[sequence*gridDim.x + blockIdx.x].x : 0;
+
+    K     += (first_y + blockIdx.y*Dk) * nb11;
+    V     += (first_y + blockIdx.y*Dv) * nb21;
+    maskh += (first_y + blockIdx.y*Dk);
+    for (int k_VKQ_0 = first_y + blockIdx.y*Dk; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*Dk,
+         // Increment pointers after each loop:
+         K += gridDim.y*Dk*nb11, V += gridDim.y*Dv*nb21, maskh += gridDim.y*Dk) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*Dk + tid] = slope*__half2float(maskh[j*ne11 + tid]);
+            }
+            __syncthreads();
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -196,12 +221,14 @@ static __global__ void flash_attn_vec_ext_f32(

 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+            float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
-            if (use_softcap) {
-                sum = softcap*tanhf(sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
             }
-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+            sum += maskf_shared[j*Dk + i_KQ];

             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);

@@ -215,7 +242,6 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float kqmax_new_j = kqmax_new_arr[j];

-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
             if (threadIdx.x == 0) {
                 kqmax_shared[j][threadIdx.y] = kqmax_new_j;
             }
@@ -246,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }

-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*Dk + k];
             }
@@ -256,8 +282,8 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }

-    if (sinksf) {
-        const float sink = sinksf[blockIdx.y];
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];

 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -270,7 +296,7 @@ static __global__ void flash_attn_vec_ext_f32(

 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
-            float kqmax_new_j = kqmax_shared[j][threadIdx.y];
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
             kqmax_new_j = warp_reduce_max(kqmax_new_j);

             const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
@@ -309,26 +335,28 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);

         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*Dv*gridDim.y + Dv*blockIdx.y + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*Dv + tid] = dst_val;
     }

-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
-    }
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__

-template <int Dk, int Dv, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_softcap>
+template <int Dk, int Dv, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = Dk/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, use_softcap>;
-    constexpr bool need_f16_K = Dk != 128 && Dk != 192;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = Dk != 128;
     constexpr bool need_f16_V = Dv != 128 && Dv != 64;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, Dv, need_f16_K, need_f16_V);
 }

 template <int Dk, int Dv, ggml_type type_K, ggml_type type_V>
@@ -341,59 +369,54 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);

-    float softcap;
-    memcpy(&softcap, (const float *) KQV->op_params + 2, sizeof(float));
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

-    if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 2;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+        constexpr int cols_per_block = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }

-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (softcap == 0.0f) {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
-    if (softcap == 0.0f) {
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, false>(ctx, dst);
+    constexpr int cols_per_block = 8;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, parallel_blocks, type_K, type_V, true>(ctx, dst);
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<Dk, Dv, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 47152e3c..5d04f8b5 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -73,9 +73,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         //if (precision == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            // ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         //} else {
-        //    ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         //}
         return;
     }
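
Note on the gridDim.y split used above: when gridDim.y != 1 the kernel skips the final division by kqsum and instead emits, per y-block, an unnormalized V*softmax(KQ) slice into dst plus a (kqmax, kqsum) pair into dst_meta. A separate combine pass (part of the launch_fattn machinery, not shown in this patch) merges the slices. Below is a minimal host-side sketch of that numerically stable merge under those assumptions; the names combine_splits, Meta, and dst_partial are illustrative, not the actual ggml API.

    #include <cfloat> // FLT_MAX
    #include <cmath>  // expf, fmaxf

    struct Meta { float kqmax, kqsum; }; // mirrors the float2 stored in dst_meta

    // Merge n_splits partial results for one (sequence, head, column) row of size Dv.
    // dst_partial[s*Dv + i] holds split s's unnormalized V*softmax(KQ) accumulator,
    // scaled relative to that split's local maximum meta[s].kqmax.
    void combine_splits(const float * dst_partial, const Meta * meta,
                        float * dst, int n_splits, int Dv) {
        float kqmax = -FLT_MAX/2.0f; // global KQ maximum over all splits
        for (int s = 0; s < n_splits; ++s) {
            kqmax = fmaxf(kqmax, meta[s].kqmax);
        }
        float kqsum = 0.0f; // global softmax denominator
        for (int i = 0; i < Dv; ++i) {
            dst[i] = 0.0f;
        }
        for (int s = 0; s < n_splits; ++s) {
            // Each split is scaled relative to its local maximum; rebase onto the global one:
            const float scale = expf(meta[s].kqmax - kqmax);
            kqsum += scale*meta[s].kqsum;
            for (int i = 0; i < Dv; ++i) {
                dst[i] += scale*dst_partial[s*Dv + i];
            }
        }
        for (int i = 0; i < Dv; ++i) {
            dst[i] /= kqsum; // the normalization the kernel skips when gridDim.y != 1
        }
    }

Because expf(x - kqmax) never exceeds 1 once rebased onto the global maximum, the rescaled sums cannot overflow, which is why one (max, sum) pair per split is all the metadata each y-block has to write.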