diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu index 72a5f2d6..8c22b076 100644 --- a/ggml/src/ggml-cuda/fattn-new-mma.cu +++ b/ggml/src/ggml-cuda/fattn-new-mma.cu @@ -566,10 +566,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int nstages = 0; #endif // CP_ASYNC_AVAILABLE + constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int ncols = ncols1 * ncols2; + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); @@ -936,6 +936,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, + const int zt, const int kb0_start, const int kb0_stop) { #ifdef INT8_MMA_AVAILABLE @@ -952,7 +953,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); @@ -1008,7 +1009,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < ne01) { + if ((ncols1 == 1 || jt*ncols1 + j < ne01) && (ncols2 == 1 || zt*ncols2 + c < ne02)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -1336,7 +1337,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= ne01) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) { continue; } @@ -1401,29 +1402,13 @@ static __global__ void flash_attn_ext_f16( const float m1, const float logit_softcap, const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int ne00, const int ne01, const int ne02, const int ne03, + const int ne10, const int ne11, const int ne12, const int ne13, + const int ne31, const int nb31, + const int nb01, const int nb02, const int nb03, + const int nb11, const int nb12, const int nb13, + const int nb21, const int nb22, const int nb23, + const int ne0, const int ne1, const int ne2, const int ne3) { #if defined(INT8_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -1455,12 +1440,13 @@ static __global__ void flash_attn_ext_f16( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + int kbc = int64_t((blockIdx.x + 0)*iter_k*iter_j*iter_z) / gridDim.x; + const int kbc_stop = int64_t((blockIdx.x + 1)*iter_k*iter_j*iter_z) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1470,18 +1456,19 @@ static __global__ void flash_attn_ext_f16( int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb02*zt*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); - const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; + float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr; - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f; int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1494,12 +1481,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -1513,18 +1500,20 @@ static __global__ void flash_attn_ext_f16( return; } - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const int sequence = kbc / (iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + + const float2 * Q_f2 = (const float2 *) (Q + nb02* zt*ncols2); + const half2 * K_h2 = (const half2 *) (K + nb12*(zt*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); - const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr; + float2 * dstk = ((float2 *) dst) + zt*(ncols2 * DV/2); + const float * sinks_f = sinks ? (const float *) sinks + zt*ncols2 : nullptr; - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(zt*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt, n_head_log2, m0, m1) : 1.0f; int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1536,7 +1525,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -1558,6 +1547,7 @@ __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { constexpr int ncols = ncols1*ncols2; + constexpr int ne03 = 1; const int bidx0 = blockIdx.x; const int j = blockIdx.y; @@ -1569,9 +1559,10 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0 = int64_t((bidx0 + 0)*iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int kbc0_stop = int64_t((bidx0 + 1)*iter_k*iter_j*iter_z*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -1580,14 +1571,15 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int channel = kbc0 / (iter_k*iter_j); - const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + const int sequence = kbc0 / (iter_k*iter_j*iter_z); + const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. - if (jt*ncols1 + j >= ne01) { + if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) { return; } - dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -1606,7 +1598,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -1846,7 +1838,8 @@ static void launch_fattn_new_mma( int parallel_blocks = 1; const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -1921,7 +1914,7 @@ static void launch_fattn_new_mma( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; + blocks_num.z = ntiles_z*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -2161,6 +2154,10 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens return; } GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); + if (gqa_ratio == 20 && Q->ne[1] <= 4 && K->ne[1] >= 2048) { + ggml_cuda_flash_attn_ext_mma_f16_case<576, 512, 1, 32>(ctx, dst); + return; + } if (gqa_ratio % 16 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); } else if (gqa_ratio % 4 == 0) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b2c744ac..5396c7ae 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -107,7 +107,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512, // so no other implementation works. // - if (new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 && + if (false && new_mma_available(cc) && K->ne[0] == 576 && V->ne[0] == 512 && Q->ne[1] == 1 && Q->ne[2]/K->ne[2] == 20 && K->ne[1] > 8192) { // GLM-4.7-Flash TG hack: split 20 heads into 16+4 heads auto local_Q = *Q;