mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
constexpr and minor changes
This commit is contained in:
@@ -1,3 +1,11 @@
|
|||||||
|
// Adapted from https://github.com/ggml-org/llama.cpp/pull/13435
|
||||||
|
//
|
||||||
|
// Copyright (C) 2025 The ggml authors
|
||||||
|
// Copyright (C) 2025 Iwan Kawrakow
|
||||||
|
// MIT license
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "cp-async.cuh"
|
#include "cp-async.cuh"
|
||||||
#include "mma_new.cuh"
|
#include "mma_new.cuh"
|
||||||
@@ -28,6 +36,13 @@ typedef tile<16, 8, half2> tile_C_VKQ_16;
|
|||||||
template <int DKQ, int DV>
|
template <int DKQ, int DV>
|
||||||
struct fattn_mma_f16_config;
|
struct fattn_mma_f16_config;
|
||||||
|
|
||||||
|
//
|
||||||
|
// The previous MMA version is better (faster)
|
||||||
|
// I'm keeping these around commented out for now,
|
||||||
|
// and only using the 576, 512 case.
|
||||||
|
// Perhaps the 256 head size needs a closer look
|
||||||
|
// to see if this implementation is better.
|
||||||
|
//
|
||||||
//template <>
|
//template <>
|
||||||
//struct fattn_mma_f16_config< 64, 64> {
|
//struct fattn_mma_f16_config< 64, 64> {
|
||||||
// static constexpr int nbatch_fa = 64;
|
// static constexpr int nbatch_fa = 64;
|
||||||
@@ -212,19 +227,19 @@ struct fattn_mma_f16_config;
|
|||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// static int get_nbatch_combine_host(const int cc, const int ncols) {
|
// static int get_nbatch_combine_host(const int cc, const int ncols) {
|
||||||
// if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
|
// if (ggml_cuda_highest_compiled_arch(cc) == CC_TURING) {
|
||||||
// return ncols <= 16 ? 128 : 64;
|
// return ncols <= 16 ? 128 : 64;
|
||||||
// }
|
// }
|
||||||
// return 64;
|
// return 64;
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
// static constexpr __device__ int get_nbatch_combine_device(int ncols) {
|
||||||
//#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
//#if __CUDA_ARCH__ == CC_TURING
|
||||||
// return ncols <= 16 ? 128 : 64;
|
// return ncols <= 16 ? 128 : 64;
|
||||||
//#else
|
//#else
|
||||||
// GGML_UNUSED(ncols);
|
// GGML_UNUSED(ncols);
|
||||||
// return 128;
|
// return 128;
|
||||||
//#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
//#endif // __CUDA_ARCH__ == CC_TURING
|
||||||
// }
|
// }
|
||||||
//};
|
//};
|
||||||
|
|
||||||
@@ -302,7 +317,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|||||||
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
||||||
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
|
||||||
|
|
||||||
if (use_cp_async) {
|
if constexpr (use_cp_async) {
|
||||||
constexpr int preload = 64;
|
constexpr int preload = 64;
|
||||||
constexpr int h2_per_chunk = 16/sizeof(half2);
|
constexpr int h2_per_chunk = 16/sizeof(half2);
|
||||||
const int chunks_per_row = D2 / h2_per_chunk;
|
const int chunks_per_row = D2 / h2_per_chunk;
|
||||||
@@ -373,7 +388,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
|||||||
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
|
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
|
||||||
static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
|
static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
|
||||||
|
|
||||||
if (use_cp_async) {
|
if constexpr (use_cp_async) {
|
||||||
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
|
||||||
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
|
||||||
constexpr int stride_j = nwarps * cols_per_warp;
|
constexpr int stride_j = nwarps * cols_per_warp;
|
||||||
@@ -478,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
|
||||||
} else {
|
} else {
|
||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
if (ncols2 > 1 || mask_h2) {
|
if constexpr (ncols2 > 1 || mask_h2) {
|
||||||
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -488,11 +503,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||||
const int k0_diff = k0_stop - k0_start;
|
const int k0_diff = k0_stop - k0_start;
|
||||||
|
|
||||||
if (nstages <= 1) {
|
if constexpr (nstages <= 1) {
|
||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
(K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
|
||||||
if (use_cp_async) {
|
if constexpr (use_cp_async) {
|
||||||
cp_async_wait_all();
|
cp_async_wait_all();
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -507,7 +522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
|
||||||
tile_A K_A;
|
tile_A K_A;
|
||||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
|
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -537,12 +552,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nstages <= 1) {
|
if constexpr (nstages <= 1) {
|
||||||
__syncthreads(); // Only needed if tile_K == tile_V.
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (use_logit_softcap) {
|
if constexpr (use_logit_softcap) {
|
||||||
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
|
for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
|
||||||
@@ -560,8 +575,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
||||||
|
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
if (ncols2 > 1 || mask_h2) {
|
if constexpr (ncols2 > 1 || mask_h2) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
|
||||||
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
||||||
@@ -679,7 +694,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
||||||
@@ -707,7 +722,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
|
tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
|
||||||
tile_B_16 * B_16 = (tile_B_16 *) B;
|
tile_B_16 * B_16 = (tile_B_16 *) B;
|
||||||
static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
|
static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
|
for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
|
||||||
B[k] = get_transposed(get_half2(KQ_C[k]));
|
B[k] = get_transposed(get_half2(KQ_C[k]));
|
||||||
@@ -721,7 +736,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nstages > 1) {
|
if constexpr (nstages > 1) {
|
||||||
// Preload K tile for next iteration:
|
// Preload K tile for next iteration:
|
||||||
constexpr bool use_cp_async = true;
|
constexpr bool use_cp_async = true;
|
||||||
cp_async_wait_all();
|
cp_async_wait_all();
|
||||||
@@ -750,7 +765,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
constexpr bool use_cp_async = nstages == 1;
|
constexpr bool use_cp_async = nstages == 1;
|
||||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
|
||||||
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
|
||||||
if (use_cp_async) {
|
if constexpr (use_cp_async) {
|
||||||
cp_async_wait_all();
|
cp_async_wait_all();
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -767,7 +782,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
|
|
||||||
tile_A A;
|
tile_A A;
|
||||||
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -779,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nstages <= 1) {
|
if constexpr (nstages <= 1) {
|
||||||
__syncthreads(); // Only needed if tile_K == tile_V.
|
__syncthreads(); // Only needed if tile_K == tile_V.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -909,12 +924,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
if (c::Q_in_reg) {
|
if constexpr (c::Q_in_reg) {
|
||||||
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
|
for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
|
load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -956,7 +971,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
|
|
||||||
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
// With multi-stage loading there is no __syncthreads at the end of the iter,
|
||||||
// there can be a race condition on shared memory access for combining/writing back results.
|
// there can be a race condition on shared memory access for combining/writing back results.
|
||||||
if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
|
if constexpr (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1101,7 +1116,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
|
||||||
if (ntiles == 1) {
|
if constexpr (ntiles == 1) {
|
||||||
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
|
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
|
for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
|
||||||
@@ -1250,12 +1265,12 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
#if defined(INT8_MMA_AVAILABLE)
|
#if defined(INT8_MMA_AVAILABLE)
|
||||||
|
|
||||||
// Skip unused kernel variants for faster compilation:
|
// Skip unused kernel variants for faster compilation:
|
||||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
if constexpr (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#if __CUDA_ARCH__ == CC_TURING
|
#if __CUDA_ARCH__ == CC_TURING
|
||||||
if (ncols1*ncols2 > 32) {
|
if constexpr (ncols1*ncols2 > 32) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user