From 3c72a7b47cefb974331cde246bc83620c1191f58 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 3 Mar 2025 17:59:42 +0200
Subject: [PATCH] WIP

---
 ggml/src/ggml-cuda/fattn-wmma-f16.cuh         | 107 ++++++++++--------
 ggml/src/ggml-cuda/fattn.cu                   |  83 +++++++++-----
 .../fattn-wmma-f16-instance-kqfloat-cpb16.cu  |   2 +
 .../fattn-wmma-f16-instance-kqfloat-cpb32.cu  |   2 +
 .../fattn-wmma-f16-instance-kqhalf-cpb16.cu   |   2 +
 .../fattn-wmma-f16-instance-kqhalf-cpb32.cu   |   2 +
 .../fattn-wmma-f16-instance-kqhalf-cpb8.cu    |   2 +
 7 files changed, 125 insertions(+), 75 deletions(-)

diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index 4804a9ef..c5ffc7d1 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE
 
 // Dk == K head size, Dv = V head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
+template<int Dk, int Dv, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -47,8 +47,9 @@ static __global__ void flash_attn_ext_f16(
     const int ne2,
     const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+    static_assert(Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512));
     // Skip unused kernel variants for faster compilation:
-    if (use_softcap && !(D == 128 || D == 256)) {
+    if (use_softcap && !(Dk == 128 || Dk == 256)) {
         NO_DEVICE_CODE;
         return;
     }
@@ -58,11 +59,11 @@ static __global__ void flash_attn_ext_f16(
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
     const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
 
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(Dk <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
-    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    static_assert(Dk % frag_m == 0, "If ncols == 8 then Dk % frag_m must be 0.");
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
@@ -74,30 +75,32 @@ static __global__ void flash_attn_ext_f16(
     static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
 
     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
-    constexpr int D_padded = D + 8;
+    constexpr int Dk_padded = Dk + 8;
+    constexpr int Dv_padded = Dv + 8;
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * V_h   = (const half  *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
     const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
 
-    const int stride_Q  = nb01 / sizeof(float);
-    const int stride_KV = nb11 / sizeof(half);
+    const int stride_Q = nb01 / sizeof(float);
+    const int stride_K = nb11 / sizeof(half);
+    const int stride_V = nb21 / sizeof(half);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
     const half2 softcap_2 = make_half2(softcap, softcap);
 
-    frag_b Q_b[D/16][ncols/frag_n];
+    frag_b Q_b[Dk/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
     constexpr int mem_KQ = ncols*kqs_padded*kqar;
-    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded;
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     float * KQ_f = (float *) KQ;
     half2 * KQ2 = (half2 *) KQ;
@@ -120,18 +123,18 @@ static __global__ void flash_attn_ext_f16(
         KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
     }
 
-    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    __shared__ half VKQ[ncols*Dv_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
                 break;
             }
-            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+            VKQ2[j*(Dv_padded/2) + i] = make_half2(0.0f, 0.0f);
         }
     }
 
@@ -140,12 +143,12 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < Dk; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + WARP_SIZE > Dk && i >= Dk) {
                 break;
             }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+            KQ[j*Dk_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
         }
     }
 
@@ -153,10 +156,10 @@ static __global__ void flash_attn_ext_f16(
 
     // Load Q into tensor core fragments/registers since it will be used frequently:
 #pragma unroll
-    for (int i0 = 0; i0 < D; i0 += 16) {
+    for (int i0 = 0; i0 < Dk; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*Dk_padded + i0, Dk_padded);
         }
     }
 
@@ -173,9 +176,9 @@ static __global__ void flash_attn_ext_f16(
             nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
         }
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+        for (int k_KQ_0 = 0; k_KQ_0 < Dk; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_K + k_KQ_0, stride_K);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -309,9 +312,9 @@ static __global__ void flash_attn_ext_f16(
             }
         }
 
-        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+        frag_c_VKQ VKQ_c[Dv/VKQ_stride][ncols/frag_n];
 #pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+        for (int i_VKQ_0 = 0; i_VKQ_0 < Dv; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
@@ -322,7 +325,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_V + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_V);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -332,15 +335,15 @@ static __global__ void flash_attn_ext_f16(
 
         __syncthreads();
 
-        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*Dk_padded);
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+        for (int i_KQ_0 = 0; i_KQ_0 < Dk; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                 nvcuda::wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    KQ + offset_k + j0*Dk_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    Dk_padded, nvcuda::wmma::mem_col_major);
             }
         }
 
@@ -358,18 +361,18 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < Dv/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + WARP_SIZE > Dv/2 && i >= Dv/2) {
                 break;
             }
 
             half2 VKQ_add = make_half2(0.0f, 0.0f);
 #pragma unroll
             for (int l = 0; l < VKQ_ratio; ++l) {
-                VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                VKQ_add += KQ2[l*(ncols*Dk_padded/2) + j*(Dk_padded/2) + i];
             }
-            VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            VKQ2[j*(Dv_padded/2) + i] = VKQ_scale*VKQ2[j*(Dv_padded/2) + i] + VKQ_add;
         }
     }
 
@@ -392,16 +395,16 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < Dv; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + WARP_SIZE > Dv && i >= Dv) {
                 break;
            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
+            float dst_val = VKQ[j_VKQ*Dv_padded + i];
             if (parallel_blocks == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+            dst[j_dst*gridDim.y*Dv + blockIdx.y*Dv + i] = dst_val;
         }
 
         if (parallel_blocks == 1 || threadIdx.x != 0) {
@@ -446,13 +449,13 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
-template <int D, int cols_per_block, typename KQ_acc_t>
+template <int Dk, int Dv, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
     constexpr int nwarps = 4;
 
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    constexpr int frag_m = cols_per_block == 8 && Dk % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
@@ -462,29 +465,33 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel = softcap == 0.0f ?
-            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
-            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+            flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
+        launch_fattn<Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
         fattn_kernel_t fattn_kernel = softcap == 0.0f ?
-            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
-            flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+            flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
+        launch_fattn<Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel = softcap == 0.0f ?
-        flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
-        flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, false> :
+        flash_attn_ext_f16<Dk, Dv, cols_per_block, nwarps, get_VKQ_stride(Dv, nwarps, frag_m), parallel_blocks, KQ_acc_t, true>;
+    launch_fattn<Dv, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
 #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                             \
     template void ggml_cuda_flash_attn_ext_wmma_f16_case                                  \
-    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)     \
+    <D, D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)  \
+
+#define DECL_FATTN_WMMA_F16_CASE_DKDV(Dk, Dv, cols_per_block, KQ_acc_t)                    \
+    template void ggml_cuda_flash_attn_ext_wmma_f16_case                                   \
+    <Dk, Dv, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
 
 extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
 extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
@@ -518,3 +525,7 @@ extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
 extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
 extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
 extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index c93da895..bbfc1ccb 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -12,6 +12,14 @@
 static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * V   = dst->src[2];
+
+    if (Q->ne[0] != V->ne[0]) {
+        if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) {
+            fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]);
+            GGML_ABORT("fatal error");
+        }
+    }
 
     const int32_t precision = KQV->op_params[3];
 
@@ -20,22 +28,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         constexpr int cols_per_block = 16;
         switch (Q->ne[0]) {
             case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
                 break;
             case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
                 break;
             case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
                 break;
             case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
                 break;
             case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
                 break;
             case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst);
+                break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
                 break;
             default:
                 fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -46,19 +57,22 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         constexpr int cols_per_block = 32;
         switch (Q->ne[0]) {
             case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst);
                 break;
             case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst);
                 break;
             case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst);
                 break;
             case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst);
                 break;
             case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst);
+                break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst);
                 break;
             // case 256:
             //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
@@ -76,16 +90,19 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
                 break;
             case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
                 break;
             case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
+                break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
                 break;
             case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
                 break;
             default:
                 fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -99,22 +116,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         constexpr int cols_per_block = 16;
         switch (Q->ne[0]) {
             case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
                 break;
            case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
                 break;
             case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
                 break;
             case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
                 break;
             case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
+                break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
                 break;
             case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
                 break;
             default:
                 fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -127,22 +147,25 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         constexpr int cols_per_block = 32;
         switch (Q->ne[0]) {
             case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst);
                 break;
             case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst);
                 break;
             case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst);
                 break;
             case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst);
                 break;
             case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst);
+                break;
+            case 192:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst);
                 break;
             case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst);
                 break;
             default:
                 fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]);
@@ -156,6 +179,12 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
         return;                                                                                  \
     }                                                                                            \
 
+#define FATTN_VEC_F16_CASE_DKDV(Dk, Dv, type_K, type_V)                                          \
+    if (Q->ne[0] == (Dk) && V->ne[0] == Dv && K->type == (type_K) && V->type == (type_V)) {      \
+        ggml_cuda_flash_attn_ext_vec_f16_case<Dk, Dv, type_K, type_V>(ctx, dst);                 \
+        return;                                                                                  \
+    }                                                                                            \
+
 static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
index 2d94e65c..334e1deb 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, float);
 DECL_FATTN_WMMA_F16_CASE(112, 16, float);
 DECL_FATTN_WMMA_F16_CASE(128, 16, float);
 DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
index c3d9df3c..1faf3c9b 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
@@ -7,3 +7,5 @@ DECL_FATTN_WMMA_F16_CASE(80, 32, float);
 DECL_FATTN_WMMA_F16_CASE(96, 32, float);
 DECL_FATTN_WMMA_F16_CASE(112, 32, float);
 DECL_FATTN_WMMA_F16_CASE(128, 32, float);
+
+DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, float);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
index bb680e40..48973618 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 16, half);
 DECL_FATTN_WMMA_F16_CASE(112, 16, half);
 DECL_FATTN_WMMA_F16_CASE(128, 16, half);
 DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 16, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
index 073f71b1..ed92963e 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
@@ -8,3 +8,5 @@ DECL_FATTN_WMMA_F16_CASE(96, 32, half);
 DECL_FATTN_WMMA_F16_CASE(112, 32, half);
 DECL_FATTN_WMMA_F16_CASE(128, 32, half);
 DECL_FATTN_WMMA_F16_CASE(256, 32, half);
+
+DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 32, half);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
index d30710c5..4e221003 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
@@ -6,3 +6,5 @@ DECL_FATTN_WMMA_F16_CASE(64, 8, half);
 DECL_FATTN_WMMA_F16_CASE(96, 8, half);
 DECL_FATTN_WMMA_F16_CASE(128, 8, half);
 DECL_FATTN_WMMA_F16_CASE(256, 8, half);
+
+DECL_FATTN_WMMA_F16_CASE_DKDV(192, 128, 8, half);
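
The two Dk != Dv combinations whitelisted above, 192/128 and 576/512, correspond to DeepSeek-style MLA heads, where the K head carries 64 extra RoPE dimensions on top of the V head size (128 + 64 at the attention level, 512 + 64 on the absorbed-MLA path). Below is a standalone sketch of the acceptance condition shared by the kernel's static_assert and the host-side guard; the helper name and the test driver are illustrative, not part of the patch:

#include <cstdio>

// Hypothetical helper mirroring the patch's static_assert / host-side guard.
static bool fattn_wmma_supports(int Dk, int Dv) {
    return Dk == Dv || (Dk == 192 && Dv == 128) || (Dk == 576 && Dv == 512);
}

int main() {
    const int cases[][2] = {{128, 128}, {192, 128}, {576, 512}, {192, 192}, {256, 128}};
    for (const auto & c : cases) {
        std::printf("Dk=%3d Dv=%3d -> %s\n", c[0], c[1],
                    fattn_wmma_supports(c[0], c[1]) ? "supported" : "unsupported");
    }
    return 0;
}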
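On the kernel side, the single D_padded = D + 8 becomes Dk_padded/Dv_padded, and the shared buffers change meaning: the KQ scratch is still sized as max(mem_KQ, mem_VKQ_parts), but the VKQ accumulator and VKQ partials are now sized by Dv, while the scaled Q rows staged into the same scratch use Dk_padded. A back-of-the-envelope check of that arithmetic, assuming ncols = 16, half accumulators (kqar = 1), VKQ_ratio = 1 and FATTN_KQ_STRIDE = 256; treat the numbers as an illustration, not the kernel's exact occupancy math:

#include <cassert>
#include <cstdio>

constexpr int FATTN_KQ_STRIDE = 256; // as in ggml's flash-attention code

int shmem_halves(int Dk, int Dv, int ncols, int kqar, int VKQ_ratio) {
    const int Dk_padded  = Dk + 8;                        // +8 padding against bank conflicts
    const int Dv_padded  = Dv + 8;
    const int kqs_padded = FATTN_KQ_STRIDE + 8;
    const int mem_KQ        = ncols*kqs_padded*kqar;      // KQ tile
    const int mem_VKQ_parts = VKQ_ratio*ncols*Dv_padded;  // VKQ partial results
    const int KQ_buf = mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts;
    assert(ncols*Dk_padded <= KQ_buf);                    // scaled Q is staged in the same buffer
    const int VKQ_buf = ncols*Dv_padded;                  // final VKQ accumulator
    return KQ_buf + VKQ_buf;
}

int main() {
    // The totals match because the scratch sizing depends on Dv, not Dk:
    std::printf("128/128: %d bytes\n", 2*shmem_halves(128, 128, 16, 1, 1));
    std::printf("192/128: %d bytes\n", 2*shmem_halves(192, 128, 16, 1, 1));
    return 0;
}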
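Splitting stride_KV into stride_K = nb11/sizeof(half) and stride_V = nb21/sizeof(half) (and addressing V via nb22 instead of nb12) is what actually lets K and V rows have different lengths. A toy host-side illustration of the byte-stride to element-stride conversion; the layout values are made up and unsigned short stands in for half on the host:

#include <cstddef>
#include <cstdio>

int main() {
    const int Dk = 192, Dv = 128;
    // ggml's nb11/nb21 are per-row byte strides; contiguous rows assumed here.
    const size_t nb11 = Dk * sizeof(unsigned short); // one K row: 192 halfs
    const size_t nb21 = Dv * sizeof(unsigned short); // one V row: 128 halfs
    const int stride_K = (int)(nb11 / sizeof(unsigned short));
    const int stride_V = (int)(nb21 / sizeof(unsigned short));
    // Row i of K starts at element i*stride_K, row i of V at i*stride_V.
    std::printf("stride_K = %d elements, stride_V = %d elements\n", stride_K, stride_V);
    return 0;
}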
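Finally, DECL_FATTN_WMMA_F16_CASE_DKDV just parameterizes the existing explicit-instantiation pattern over separate Dk/Dv, with the old macro forwarding D into both positions. A self-contained miniature of that pattern, using stub types and hypothetical names:

#include <cstdio>

struct ctx_t {};

// Stub standing in for ggml_cuda_flash_attn_ext_wmma_f16_case.
template <int Dk, int Dv, int cols_per_block, typename KQ_acc_t>
void wmma_f16_case_stub(ctx_t &) {
    std::printf("Dk=%d Dv=%d cols_per_block=%d acc=%zu bytes\n",
                Dk, Dv, cols_per_block, sizeof(KQ_acc_t));
}

// Old-style macro: same head size for K and V (D forwarded twice).
#define DECL_CASE(D, cpb, acc)           template void wmma_f16_case_stub<D, D, cpb, acc>(ctx_t &)
// New-style macro: separate K/V head sizes.
#define DECL_CASE_DKDV(Dk, Dv, cpb, acc) template void wmma_f16_case_stub<Dk, Dv, cpb, acc>(ctx_t &)

DECL_CASE(128, 16, float);
DECL_CASE_DKDV(192, 128, 16, float);

int main() {
    ctx_t ctx;
    wmma_f16_case_stub<192, 128, 16, float>(ctx);
    return 0;
}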