From d7008ad52daf9da54694fc74f5a34be5c5bed389 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 11 May 2025 11:21:51 +0300 Subject: [PATCH] constexpr and minor changes --- ggml/src/ggml-cuda/fattn-new-mma.cu | 65 ++++++++++++++++++----------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 5cfaee8a..630baf33 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -1,3 +1,11 @@ +// Adapted from https://github.com/ggml-org/llama.cpp/pull/13435 +// +// Copyright (C) 2025 The ggml authors +// Copyright (C) 2025 Iwan Kawrakow +// MIT license +// SPDX-License-Identifier: MIT +// + #include "common.cuh" #include "cp-async.cuh" #include "mma_new.cuh" @@ -28,6 +36,13 @@ typedef tile<16, 8, half2> tile_C_VKQ_16; template struct fattn_mma_f16_config; +// +// The previous MMA version is better (faster) +// I'm keeping these around commented out for now, +// and only using the 576, 512 case. +// Perhaps the 256 head size needs a closer look +// to see if this implementation is better. +// //template <> //struct fattn_mma_f16_config< 64, 64> { // static constexpr int nbatch_fa = 64; @@ -212,19 +227,19 @@ struct fattn_mma_f16_config; // } // // static int get_nbatch_combine_host(const int cc, const int ncols) { -// if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { +// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) { // return ncols <= 16 ? 128 : 64; // } // return 64; // } // // static constexpr __device__ int get_nbatch_combine_device(int ncols) { -//#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING +//#if __CUDA_ARCH__ == CC_TURING // return ncols <= 16 ? 128 : 64; //#else // GGML_UNUSED(ncols); // return 128; -//#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING +//#endif // __CUDA_ARCH__ == CC_TURING // } //}; @@ -302,7 +317,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - if (use_cp_async) { + if constexpr (use_cp_async) { constexpr int preload = 64; constexpr int h2_per_chunk = 16/sizeof(half2); const int chunks_per_row = D2 / h2_per_chunk; @@ -373,7 +388,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - if (use_cp_async) { + if constexpr (use_cp_async) { constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; @@ -478,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; - if (ncols2 > 1 || mask_h2) { + if constexpr (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } @@ -488,11 +503,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; - if (nstages <= 1) { + if constexpr (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); - if (use_cp_async) { + if constexpr (use_cp_async) { cp_async_wait_all(); } __syncthreads(); @@ -507,7 +522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { tile_A K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - if (ntiles == 1) { + if constexpr (ntiles == 1) { mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); } else { #pragma unroll @@ -537,12 +552,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } - if (use_logit_softcap) { + if constexpr (use_logit_softcap) { static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { @@ -560,8 +575,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } float KQ_rowsum_add[cols_per_thread] = {0.0f}; - if (ntiles == 1) { - if (ncols2 > 1 || mask_h2) { + if constexpr (ntiles == 1) { + if constexpr (ncols2 > 1 || mask_h2) { #pragma unroll for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; @@ -679,7 +694,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } - if (ntiles == 1) { + if constexpr (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { @@ -707,7 +722,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); - if (ntiles == 1) { + if constexpr (ntiles == 1) { #pragma unroll for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); @@ -721,7 +736,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages > 1) { + if constexpr (nstages > 1) { // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -750,7 +765,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); - if (use_cp_async) { + if constexpr (use_cp_async) { cp_async_wait_all(); } __syncthreads(); @@ -767,7 +782,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_A A; load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - if (ntiles == 1) { + if constexpr (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { #pragma unroll @@ -779,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - if (nstages <= 1) { + if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. } } @@ -909,12 +924,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - if (c::Q_in_reg) { + if constexpr (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { - if (ntiles == 1) { + if constexpr (ntiles == 1) { load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll @@ -956,7 +971,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. - if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { + if constexpr (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } @@ -1101,7 +1116,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { - if (ntiles == 1) { + if constexpr (ntiles == 1) { const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data #pragma unroll for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { @@ -1250,12 +1265,12 @@ static __global__ void flash_attn_ext_f16( #if defined(INT8_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if constexpr (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } #if __CUDA_ARCH__ == CC_TURING - if (ncols1*ncols2 > 32) { + if constexpr (ncols1*ncols2 > 32) { NO_DEVICE_CODE; return; }