This commit is contained in:
Iwan Kawrakow
2025-03-03 17:59:42 +02:00
parent 4c673e7ace
commit 3c72a7b47c
7 changed files with 125 additions and 75 deletions

View File

@@ -6,7 +6,7 @@
#endif // FP16_MMA_AVAILABLE
// Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
template<int Dk, int Dv, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -47,8 +47,9 @@ static __global__ void flash_attn_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
// Skip unused kernel variants for faster compilation:
if (use_softcap && !(D == 128 || D == 256)) {
if (use_softcap && !(Dk == 128 || Dk == 256)) {
NO_DEVICE_CODE;
return;
}
@@ -58,11 +59,11 @@ static __global__ void flash_attn_ext_f16(
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
@@ -74,30 +75,32 @@ static __global__ void flash_attn_ext_f16(
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int Dk_padded = Dk + 8;
constexpr int Dv_padded = Dv + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const int stride_Q = nb01 / sizeof(float);
const int stride_K = nb11 / sizeof(half);
const int stride_V = nb21 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 softcap_2 = make_half2(softcap, softcap);
frag_b Q_b[D/16][ncols/frag_n];
frag_b Q_b[Dk/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
@@ -120,18 +123,18 @@ static __global__ void flash_attn_ext_f16(
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
__shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
@@ -140,12 +143,12 @@ static __global__ void flash_attn_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
if (i0 + WARP_SIZE > Dk && i >= Dk) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
@@ -153,10 +156,10 @@ static __global__ void flash_attn_ext_f16(
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
for (int i0 = 0; i0 < Dk; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded);
}
}
@@ -173,9 +176,9 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -309,9 +312,9 @@ static __global__ void flash_attn_ext_f16(
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
@@ -322,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -332,15 +335,15 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
Dk_padded, nvcuda::wmma::mem_col_major);
}
}
@@ -358,18 +361,18 @@ static __global__ void flash_attn_ext_f16(
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add;
}
}
@@ -392,16 +395,16 @@ static __global__ void flash_attn_ext_f16(
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
if (i0 + WARP_SIZE > Dv && i >= Dv) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
float dst_val = VKQ[j_VKQ*Dv_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
@@ -446,13 +449,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
template <int Dk, int Dv, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
@@ -462,29 +465,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel = softcap == 0.0f ?
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<D, D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
launch_fattn<Dk, Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
<D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
<D, D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t) \
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
<Dk, Dv, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
@@ -518,3 +525,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half);
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);
extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);

View File

@@ -12,6 +12,14 @@
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * V = dst->src[2];
if (Q->ne[0] != V->ne[0]) {
if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) {
fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]);
GGML_ABORT("fatal error");
}
}
const int32_t precision = KQV->op_params[3];
@@ -20,22 +28,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
break;
default:
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -46,19 +57,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
@@ -76,16 +90,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
break;
default:
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -99,22 +116,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
break;
default:
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -127,22 +147,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
break;
case 192:
ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
break;
default:
fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -156,6 +179,12 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
return; \
} \
#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V) \
if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f16_case<Dk, Dv, type_K, type_V>(ctx, dst); \
return; \
} \
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];

View File

@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float);
DECL_FATTN_WMMA_F16_CASE(112, 16, float);
DECL_FATTN_WMMA_F16_CASE(128, 16, float);
DECL_FATTN_WMMA_F16_CASE(256, 16, float);
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float);

View File

@@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float);
DECL_FATTN_WMMA_F16_CASE(96, 32, float);
DECL_FATTN_WMMA_F16_CASE(112, 32, float);
DECL_FATTN_WMMA_F16_CASE(128, 32, float);
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float);

View File

@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half);
DECL_FATTN_WMMA_F16_CASE(112, 16, half);
DECL_FATTN_WMMA_F16_CASE(128, 16, half);
DECL_FATTN_WMMA_F16_CASE(256, 16, half);
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);

View File

@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half);
DECL_FATTN_WMMA_F16_CASE(112, 32, half);
DECL_FATTN_WMMA_F16_CASE(128, 32, half);
DECL_FATTN_WMMA_F16_CASE(256, 32, half);
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);

View File

@@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half);
DECL_FATTN_WMMA_F16_CASE(96, 8, half);
DECL_FATTN_WMMA_F16_CASE(128, 8, half);
DECL_FATTN_WMMA_F16_CASE(256, 8, half);
DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half);