constexpr and minor changes

This commit is contained in:
Iwan Kawrakow
2025-05-11 11:21:51 +03:00
parent d1601d463b
commit d7008ad52d

View File

@@ -1,3 +1,11 @@
// Adapted from https://github.com/ggml-org/llama.cpp/pull/13435
//
// Copyright (C) 2025 The ggml authors
// Copyright (C) 2025 Iwan Kawrakow
// MIT license
// SPDX-License-Identifier: MIT
//
#include "common.cuh"
#include "cp-async.cuh"
#include "mma_new.cuh"
@@ -28,6 +36,13 @@ typedef tile<16, 8, half2> tile_C_VKQ_16;
template <int DKQ, int DV>
struct fattn_mma_f16_config;
//
// The previous MMA implementation is better (faster) for these head sizes,
// so the specializations below are kept around commented out for now and
// only the DKQ = 576, DV = 512 case is used.
// Perhaps the 256 head size needs a closer look
// to see if this implementation is better there.
//
//template <>
//struct fattn_mma_f16_config< 64, 64> {
// static constexpr int nbatch_fa = 64;
@@ -212,19 +227,19 @@ struct fattn_mma_f16_config;
// }
//
// static int get_nbatch_combine_host(const int cc, const int ncols) {
// if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
// return ncols <= 16 ? 128 : 64;
// }
// return 64;
// }
//
// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
//#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
//#if __CUDA_ARCH__ == CC_TURING
// return ncols <= 16 ? 128 : 64;
//#else
// GGML_UNUSED(ncols);
// return 128;
//#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
//#endif // __CUDA_ARCH__ == CC_TURING
// }
//};
@@ -302,7 +317,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
if (use_cp_async) {
if constexpr (use_cp_async) {
constexpr int preload = 64;
constexpr int h2_per_chunk = 16/sizeof(half2);
const int chunks_per_row = D2 / h2_per_chunk;
@@ -373,7 +388,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
if (use_cp_async) {
if constexpr (use_cp_async) {
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;
@@ -478,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
if constexpr (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
}
}
@@ -488,11 +503,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
if (nstages <= 1) {
if constexpr (nstages <= 1) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
if (use_cp_async) {
if constexpr (use_cp_async) {
cp_async_wait_all();
}
__syncthreads();
@@ -507,7 +522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
tile_A K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
if (ntiles == 1) {
if constexpr (ntiles == 1) {
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
} else {
#pragma unroll
@@ -537,12 +552,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
if (nstages <= 1) {
if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
if (use_logit_softcap) {
if constexpr (use_logit_softcap) {
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
@@ -560,8 +575,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
float KQ_rowsum_add[cols_per_thread] = {0.0f};
if (ntiles == 1) {
if (ncols2 > 1 || mask_h2) {
if constexpr (ntiles == 1) {
if constexpr (ncols2 > 1 || mask_h2) {
#pragma unroll
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
@@ -679,7 +694,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
}
if (ntiles == 1) {
if constexpr (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
@@ -707,7 +722,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
tile_B_16 * B_16 = (tile_B_16 *) B;
static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
if (ntiles == 1) {
if constexpr (ntiles == 1) {
#pragma unroll
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
B[k] = get_transposed(get_half2(KQ_C[k]));
@@ -721,7 +736,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
if (nstages > 1) {
if constexpr (nstages > 1) {
// Preload K tile for next iteration:
constexpr bool use_cp_async = true;
cp_async_wait_all();
@@ -750,7 +765,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
if (use_cp_async) {
if constexpr (use_cp_async) {
cp_async_wait_all();
}
__syncthreads();
@@ -767,7 +782,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
tile_A A;
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
if (ntiles == 1) {
if constexpr (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else {
#pragma unroll
@@ -779,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
if (nstages <= 1) {
if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
}
}
@@ -909,12 +924,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
if (c::Q_in_reg) {
if constexpr (c::Q_in_reg) {
const int j0 = (threadIdx.y / np) * cols_per_warp;
#pragma unroll
for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
if (ntiles == 1) {
if constexpr (ntiles == 1) {
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
} else {
#pragma unroll
@@ -956,7 +971,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
// With multi-stage loading there is no __syncthreads at the end of the iter,
// so there can be a race condition on shared memory access for combining/writing back results.
if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
if constexpr (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
__syncthreads();
}
@@ -1101,7 +1116,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
#pragma unroll
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
if (ntiles == 1) {
if constexpr (ntiles == 1) {
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
#pragma unroll
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
@@ -1250,12 +1265,12 @@ static __global__ void flash_attn_ext_f16(
#if defined(INT8_MMA_AVAILABLE)
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
if constexpr (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
NO_DEVICE_CODE;
return;
}
#if __CUDA_ARCH__ == CC_TURING
if (ncols1*ncols2 > 32) {
if constexpr (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
return;
}