diff --git a/ggml/src/ggml-cuda/fattn-new-mma.cu b/ggml/src/ggml-cuda/fattn-new-mma.cu
index c37a618f..b2285fdd 100644
--- a/ggml/src/ggml-cuda/fattn-new-mma.cu
+++ b/ggml/src/ggml-cuda/fattn-new-mma.cu
@@ -14,6 +14,46 @@ using namespace ggml_cuda_mma;
 
+typedef void (* fattn_new_mma_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int * __restrict__ KV_max,
+        float * __restrict__ dst,
+        float2 * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const float softcap,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
+
 typedef tile<16, 8, half2> tile_A;
 typedef tile< 8, 8, half2> tile_B;
 typedef tile<16, 8, half2> tile_B_16;
@@ -43,37 +83,37 @@ struct fattn_mma_f16_config;
 // Perhaps the 256 head size needs a closer look
 // to see if this implementation is better.
 //
-template <>
-struct fattn_mma_f16_config< 64, 64> {
-    static constexpr int nbatch_fa = 64;
-    static constexpr int nwarps_max = 4;
-    static constexpr bool Q_in_reg = true;
-    static constexpr int nstages_target = 2;
-
-    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
-        return 32;
-    }
-
-    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
-        return 32;
-    }
-
-    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
-        return 32;
-    }
-};
+//template <>
+//struct fattn_mma_f16_config< 64, 64> {
+//    static constexpr int nbatch_fa = 64;
+//    static constexpr int nwarps_max = 4;
+//    static constexpr bool Q_in_reg = true;
+//    static constexpr int nstages_target = 2;
+//
+//    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+//        return 32;
+//    }
+//
+//    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+//        return 32;
+//    }
+//};
 //
 //template <>
 //struct fattn_mma_f16_config< 80, 80> {
@@ -243,6 +283,38 @@ struct fattn_mma_f16_config< 64, 64> {
 //    }
 //};
 
+template <>
+struct fattn_mma_f16_config<192, 128> {
+    static constexpr int nbatch_fa = 64;
+    static constexpr int nwarps_max = 4;
+    static constexpr bool Q_in_reg = true;
+    static constexpr int nstages_target = 1;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64;
+    }
+};
+
 template <>
 struct fattn_mma_f16_config<576, 512> {
     static constexpr int nbatch_fa = 32;
@@ -1287,6 +1359,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ V,
         const char * __restrict__ mask,
         const char * __restrict__ sinks,
+        const int * __restrict__ KV_max,
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
@@ -1377,8 +1450,11 @@ static __global__ void flash_attn_ext_f16(
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel = kb0_stop * kb_niter;
+        int kb0_start_kernel = kb0_start * kb_niter;
+        int kb0_stop_kernel = kb0_stop * kb_niter;
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -1417,8 +1493,11 @@ static __global__ void flash_attn_ext_f16(
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
-        const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel = kb0_stop * kb_niter;
+        int kb0_start_kernel = kb0_start * kb_niter;
+        int kb0_stop_kernel = kb0_stop * kb_niter;
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
         constexpr bool needs_fixup = false;
@@ -1574,9 +1653,68 @@ static __global__ void flash_attn_combine_results_new(
     dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
 }
 
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+    if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) {
+    const int ne31 = gridDim.x;
+    const int tid = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        KV_min_max[sequence*ne31 + jt] = KV_max_sj + FATTN_KQ_STRIDE;
+    }
+}
+
 template <int DV, int ncols1, int ncols2>
 static void launch_fattn_new_mma(
-    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_new_mma_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
     const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
@@ -1605,10 +1743,15 @@ static void launch_fattn_new_mma(
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
     const int cc  = ggml_cuda_info().devices[id].cc;
-    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int nsm_actual = ggml_cuda_info().devices[id].nsm;
+    int nsm = 1; while (nsm*2 <= nsm_actual) nsm *= 2;
+
+    if (Q->ne[1] == 1 && K->ne[1] <= 4096 && nsm > 32) nsm /= 2;
+    if (Q->ne[1] >= 32 && K->ne[1] >= 4096) nsm *= 2;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -1675,6 +1818,25 @@ static void launch_fattn_new_mma(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // multiple sequences of possibly different lengths.
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -1761,6 +1923,7 @@ static void launch_fattn_new_mma(
             V_data,
             mask ? ((const char *) mask->data) : nullptr,
             sinks ? ((const char *)sinks->data) : nullptr,
+            KV_max.get(),
             !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
             scale, max_bias, m0, m1, logit_softcap, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1837,7 +2000,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    fattn_kernel_t fattn_kernel;
+    fattn_new_mma_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap>;
@@ -1944,8 +2107,16 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (K->ne[0] == 64 && V->ne[0] == 64) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
+    //if (K->ne[0] == 64 && V->ne[0] == 64) {
+    //    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
+    //    return;
+    //}
+    if (K->ne[0] == 192 && V->ne[0] == 128) {
+        GGML_ASSERT(Q->ne[0] == 192);
+        GGML_ASSERT(gqa_ratio == 1);
+        //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst);
+        // Reduce compile time
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst);
         return;
     }
     GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index a1544e41..1ed10bd9 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -102,7 +102,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // Hence, we use it only for DeepSeek with MLA enabled, where head sizes are 576, 512,
     // so no other implementation works.
     //
-    if (new_mma_available(cc) && Q->ne[0] == 576) {
+    if (new_mma_available(cc) && ((K->ne[0] == 576 && V->ne[0] == 512) || (K->ne[0] == 192 && V->ne[0] == 128))) {
         ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
         return;
     }
@@ -172,8 +172,8 @@ bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_te
         return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst);
     }
 
-    if (new_mma_available(cc) && Q->ne[0] == 576) {
-        return V->ne[0] == 512;
+    if (new_mma_available(cc) && (Q->ne[0] == 576 || (K->ne[0] == 192 && V->ne[0] == 128))) {
+        return true;
     }
 
     if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) {
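Note on the kernel-side clamp in `flash_attn_ext_f16`: `kb0` counts blocks of `nbatch_fa` KV positions, so dividing `KV_max[jt]` by `c::nbatch_fa` turns the mask-derived KV bound into a loop bound. A host-side restatement of just that clamp (the wrapper `pruned_stop` is hypothetical; `nbatch_fa = 64` is taken from the new 192/128 config):

```cpp
#include <algorithm>

constexpr int nbatch_fa = 64; // KV positions consumed per main-loop iteration (192/128 config)

// KV_max[jt] is the number of KV positions that actually need processing for
// Q tile jt (a multiple of FATTN_KQ_STRIDE); a null KV_max means "no pruning".
int pruned_stop(int kb0_stop_kernel, const int * KV_max, int jt) {
    if (KV_max) {
        kb0_stop_kernel = std::min(kb0_stop_kernel, KV_max[jt] / nbatch_fa);
    }
    return kb0_stop_kernel;
}
```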
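`flash_attn_mask_to_KV_max` walks the mask backwards in `FATTN_KQ_STRIDE`-wide tiles and stops at the first tile that is not entirely `-inf`. A scalar host-side reference for what one (sequence, Q tile) block computes (`kv_max_reference` is a hypothetical helper; the kernel itself reads `half2` pairs, and `FATTN_KQ_STRIDE` is 256 in fattn-common.cuh):

```cpp
#include <cmath>

constexpr int FATTN_KQ_STRIDE = 256; // tile granularity, as in ggml's fattn-common.cuh

// Scan mask tiles from the back; the first tile that is not entirely -inf bounds
// the KV range. Assumes n_kv is a multiple of FATTN_KQ_STRIDE, matching the
// "K->ne[1] % FATTN_KQ_STRIDE == 0" guard around the kernel launch.
int kv_max_reference(const float * mask, int ncols1, int n_kv, int stride) {
    for (int sj = n_kv - FATTN_KQ_STRIDE; sj >= 0; sj -= FATTN_KQ_STRIDE) {
        bool all_inf = true;
        for (int j = 0; j < ncols1 && all_inf; ++j) {
            for (int i = 0; i < FATTN_KQ_STRIDE && all_inf; ++i) {
                all_inf = std::isinf(mask[j*stride + sj + i]);
            }
        }
        if (!all_inf) {
            return sj + FATTN_KQ_STRIDE; // KV positions [0, sj + FATTN_KQ_STRIDE) are needed
        }
    }
    return 0; // every tile is masked out for this Q tile
}
```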
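Within that kernel, the `all_inf` predicate is combined in two levels: `__all_sync` within each warp, then a round trip through `buf_iw` to fold the per-warp results. Isolated as a sketch (`block_all` is a hypothetical name; assumes `blockDim.x` is a multiple of `WARP_SIZE`, at most `WARP_SIZE*WARP_SIZE`, and `buf_iw` preset to 1):

```cuda
#define WARP_SIZE 32

__device__ int block_all(int pred, int * buf_iw /* __shared__ int[WARP_SIZE], preset to 1 */) {
    const int tid = threadIdx.x;
    pred = __all_sync(0xffffffff, pred);  // warp level: 1 iff every lane passed non-zero
    if (tid % WARP_SIZE == 0) {
        buf_iw[tid / WARP_SIZE] = pred;   // one slot per warp; unused slots stay 1
    }
    __syncthreads();
    pred = buf_iw[tid % WARP_SIZE];       // each lane reads one warp's result
    __syncthreads();                      // buf_iw is reused on the next loop iteration
    return __all_sync(0xffffffff, pred);  // AND across the per-warp results
}
```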
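The grid-sizing change in `launch_fattn_new_mma` rounds the SM count down to a power of two and then rescales it by workload shape, presumably to get evenly sized stream-k partitions. My reading of the heuristic as a standalone function (`effective_nsm` is a hypothetical name; `n_q`/`n_kv` stand for `Q->ne[1]`/`K->ne[1]`):

```cpp
int effective_nsm(int nsm_actual, long long n_q, long long n_kv) {
    int nsm = 1;
    while (nsm*2 <= nsm_actual) nsm *= 2;               // e.g. 132 SMs -> 128, 84 -> 64
    if (n_q == 1 && n_kv <= 4096 && nsm > 32) nsm /= 2; // single-token decode over a short cache
    if (n_q >= 32 && n_kv >= 4096) nsm *= 2;            // large batch over a long cache
    return nsm;
}
```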