diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 26b605b4..48a13168 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -39,6 +39,8 @@ typedef void (* fattn_new_mma_kernel_t)( const int ne13, const int ne31, const int nb31, + const int ne33, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -211,37 +213,37 @@ struct fattn_mma_f16_config; // } //}; // -//template <> -//struct fattn_mma_f16_config<128, 128> { -// static constexpr int nbatch_fa = 64; -// static constexpr int nwarps_max = 4; -// static constexpr bool Q_in_reg = true; -// static constexpr int nstages_target = 2; -// -// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { -// return 64; -// } -// -// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { -// return 64; -// } -// -// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { -// return 64; -// } -// -// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { -// return 64; -// } -// -// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { -// return 64; -// } -// -// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { -// return 64; -// } -//}; +template <> +struct fattn_mma_f16_config<128, 128> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } +}; // //template <> //struct fattn_mma_f16_config<256, 256> { @@ -930,6 +932,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const int ne01, const int ne02, + const int gqa_ratio, const int stride_Q1, const int stride_Q2, const int stride_K, @@ -1009,7 +1012,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < ne02)) { + if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -1337,7 +1340,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= gqa_ratio))) { continue; } @@ -1381,6 +1384,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); + GGML_UNUSED(gqa_ratio); NO_DEVICE_CODE; #endif // INT8_MMA_AVAILABLE } @@ -1404,7 +1408,7 @@ static __global__ void flash_attn_ext_f16( const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, const int ne03, const int ne10, const int ne11, const int ne12, const int ne13, - const int ne31, const int nb31, + const int ne31, const int nb31, const int ne33, const int nb33, const int nb01, const int nb02, const int nb03, const int nb11, const int nb12, const int nb13, const int nb21, const int nb22, const int nb23, @@ -1440,13 +1444,13 @@ static __global__ void flash_attn_ext_f16( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; + const int iter_z = (gqa_ratio + (ncols2 - 1)) / ncols2; constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z) / gridDim.x; - const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z) / gridDim.x; + int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1456,19 +1460,22 @@ static __global__ void flash_attn_ext_f16( int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*iter_z); - const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); - const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; + const int sequence = kbc /(iter_k*iter_j*iter_z*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt) / iter_k; + const int zt_Q = z_KV*gqa_ratio + zt*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb02*zt*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2); - const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr; + //const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1481,12 +1488,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -1500,20 +1507,23 @@ static __global__ void flash_attn_ext_f16( return; } + const int sequence = kbc /(iter_k*iter_j*iter_z*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt) / iter_k; - const int sequence = kbc / (iter_k*iter_j*iter_z); - const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + const int zt_Q = z_KV*gqa_ratio + zt*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb02* zt*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2); - const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr; + //const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1525,7 +1535,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -1545,7 +1555,8 @@ static __global__ void flash_attn_ext_f16( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, + const int ne11, const int ne12) { constexpr int ncols = ncols1*ncols2; constexpr int ne03 = 1; @@ -1557,12 +1568,14 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = ne11 / FATTN_KQ_STRIDE; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; + const int gqa_ratio = ne02/ne12; - const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne03) / gridDim.x; - const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int iter_k = ne11 / FATTN_KQ_STRIDE; // In our implementation ne11 is always a multiple of FATTN_KQ_STRIDE + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x; + const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -1571,15 +1584,18 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int sequence = kbc0 / (iter_k*iter_j*iter_z); - const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + const int sequence = kbc0 /(iter_k*iter_j*iter_z*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence)/(iter_k*iter_j*iter_z); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z*ne12 * sequence - iter_k*iter_j*iter_z * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) { + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -1598,7 +1614,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne12*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -1837,9 +1853,10 @@ static void launch_fattn_new_mma( int parallel_blocks = 1; - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2); - const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -1914,7 +1931,7 @@ static void launch_fattn_new_mma( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = ntiles_z*Q->ne[3]; + blocks_num.z = ntiles_z*K->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -1951,7 +1968,8 @@ static void launch_fattn_new_mma( scale, max_bias, m0, m1, logit_softcap, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[3] : 0, mask ? mask->nb[3] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, @@ -1966,7 +1984,7 @@ static void launch_fattn_new_mma( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1], K->ne[2]); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); @@ -2135,6 +2153,13 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (K->ne[0] == 128 && gqa_ratio == 12) { + GGML_ASSERT(Q->ne[0] == 128 && V->ne[0] == 128); + //GGML_ASSERT(Q->ne[1] <= 4); + //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, 128, 16>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case<128, 128, 1, 16>(ctx, dst); + return; + } //if (K->ne[0] == 64 && V->ne[0] == 64) { // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); // return; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index c6babe90..6f1c766c 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -89,6 +89,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } + if (new_mma_available(cc) && K->ne[0] == 128 && V->ne[0] == 128 && Q->ne[0] == 128 && Q->ne[1] == 1 && + Q->ne[2] / K->ne[2] == 12) { + ggml_cuda_flash_attn_ext_mma_new(ctx, dst); + return; + } + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations // So, not sure why in mainline they thought that for CC_ADA_LOVELACE or when KV cache is not f16 the vector kernels are faster. // On my GPU (RTX-4080) MMA is efinitely faster for GQA, both for f16 and for quantized KV cache.