diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp index 3be5acedf2..1d62bc80ed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_decode_kernel.hpp @@ -597,8 +597,8 @@ struct FmhaFwdDecodeKernel long_index_t batch_offset_bias = 0; long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_o_acc = 0; - index_t kv_l2p_offset = - 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache + // index_t kv_l2p_offset = + // 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache if constexpr(kIsGroupMode) { @@ -648,7 +648,7 @@ struct FmhaFwdDecodeKernel if(kargs.is_gappy) { // seqstart_k_ptr has different meaning in this case - kv_l2p_offset = kargs.seqstart_k_ptr[i_batch]; + // kv_l2p_offset = kargs.seqstart_k_ptr[i_batch]; } } } @@ -809,66 +809,6 @@ struct FmhaFwdDecodeKernel } }(); - auto k_page_block_navigator = [&, i_batch_ = i_batch]() { - if constexpr(kIsPagedKV) - { - const auto* block_indices = - reinterpret_cast(kargs.block_table_ptr) + - i_batch_ * kargs.batch_stride_block_table; - const index_t num_blocks = - integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - - const long_index_t fixed_offset = - static_cast(i_nhead_k) * kargs.nhead_stride_k; - - return make_page_block_navigator( - kargs.k_ptr, - kargs.batch_stride_k, // kcache page-block stride/size - fixed_offset, - block_indices, - num_blocks, - kargs.page_block_size, - k_dram, - make_k_dram(nullptr, - (kv_l2p_offset + kargs.seqlen_k) - - (num_blocks - 1) * kargs.page_block_size)); - } - else - { - return make_page_block_navigator(k_dram); - } - }(); - - auto v_page_block_navigator = [&, i_batch_ = i_batch]() { - if constexpr(kIsPagedKV) - { - const auto* block_indices = - reinterpret_cast(kargs.block_table_ptr) + - i_batch_ * kargs.batch_stride_block_table; - const index_t num_blocks = - integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - - const long_index_t fixed_offset = - static_cast(i_nhead_k) * kargs.nhead_stride_v; - - return make_page_block_navigator( - kargs.v_ptr, - kargs.batch_stride_v, // vcache page-block stride/size - fixed_offset, - block_indices, - num_blocks, - kargs.page_block_size, - v_dram, - make_v_dram(nullptr, - (kv_l2p_offset + kargs.seqlen_k) - - (num_blocks - 1) * kargs.page_block_size)); - } - else - { - return make_page_block_navigator(v_dram); - } - }(); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -880,10 +820,11 @@ struct FmhaFwdDecodeKernel }(), {i_m0, 0}); - auto k_dram_window_lengths = - make_tuple(number{}, number{}); - auto v_dram_window_lengths = - make_tuple(number{}, number{}); + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, make_tuple(number{}, number{}), {0, 0}); /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// following copy capture of the 'i_nhead' if in C++20 @@ -1006,70 +947,24 @@ struct FmhaFwdDecodeKernel } }(); - AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); - - BlockIndices block_indices{i_batch, i_nhead, i_nhead_k}; - auto o_acc_tile = [&, i_split_ = i_split]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}(q_dram_window, - identity{}, // q_element_func - k_dram_window_lengths, - k_page_block_navigator, - identity{}, // k_element_func - v_dram_window_lengths, - v_page_block_navigator, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - lse_acc_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - identity{}, // o_acc_element_func - kargs.num_splits, - i_split_, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - kv_l2p_offset, - smem_ptr); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window_lengths, - k_page_block_navigator, // Remove it - v_dram_window_lengths, - v_page_block_navigator, // Remove it - bias_dram_window, - lse_acc_dram_window, - kargs.num_splits, // Remove it - i_split_, // Remove it - mask, - position_encoding, - kargs.scale_s, - variant, // Remove it - variant_params, // Remove it - block_indices, // Remove it - kv_l2p_offset, // Remove it - smem_ptr); - } + return FmhaPipeline{}(q_dram_window, + k_dram_window, + // k_page_block_navigator, // Remove it + v_dram_window, + // v_page_block_navigator, // Remove it + bias_dram_window, + lse_acc_dram_window, + kargs.num_splits, // Remove it + i_split_, // Remove it + mask, + position_encoding, + kargs.scale_s, + // variant, // Remove it + // variant_params, // Remove it + // block_indices, // Remove it + // kv_l2p_offset, // Remove it + smem_ptr); }(); // Oacc DRAM and Oacc DRAM window diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp index 94af351128..d4f66d236f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs.hpp @@ -125,21 +125,15 @@ struct BlockFmhaFwdDecodePipelineQRKSVS } template + typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile - const KPageBlockNavigator& k_page_block_navigator, - const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile - const VPageBlockNavigator& v_page_block_navigator, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile index_t num_splits, @@ -147,29 +141,26 @@ struct BlockFmhaFwdDecodePipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - const BlockIndices& block_indices, - index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { static_assert( std::is_same_v> && - std::is_same_v> && - std::is_same_v>, + std::is_same_v> && + std::is_same_v>, "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kN0 == KDramBlockWindowLengths{}[number<0>{}] && - kK0 == KDramBlockWindowLengths{}[number<1>{}] && - kN1 == VDramBlockWindowLengths{}[number<0>{}] && - kK1 == VDramBlockWindowLengths{}[number<1>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - + ignore = bias_dram_block_window_tmp; + ignore = position_encoding; // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPVBlockGemm(); @@ -248,29 +239,16 @@ struct BlockFmhaFwdDecodePipelineQRKSVS async_load_tile(q_lds_store_window, q_dram_window); // K tile in LDS - const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; - const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; + const index_t physical_seqlen_k_start = logical_seqlen_k_start; + const index_t physical_seqlen_k_end = logical_seqlen_k_end; // make sure the first tile is completely located in page-block (page-block size should be // divisible by kN0) // relationship between each *_start variables: aligned_physical_seqlen_k_start <= // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start - const index_t aligned_physical_seqlen_k_start = - [&, physical_seqlen_k_start_ = physical_seqlen_k_start] { - if constexpr(kIsPagedKV) - { - return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0); - } - else - { - return physical_seqlen_k_start_; - } - }(); - - auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start; auto k_dram_window = make_tile_window( - k_dram_block_window, Policy::template MakeKDramTileDistribution()); + k_dram_block_window_tmp, Policy::template MakeKDramTileDistribution()); auto k_lds = make_tensor_view( static_cast(smem_ptr), Policy::template MakeKLdsBlockDescriptor()); @@ -297,11 +275,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS Policy::template MakeSRegTileDistribution()); // V tile in LDS - auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window( - v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start}); - auto v_dram_window = make_tile_window( - v_dram_block_window, Policy::template MakeVDramTileDistribution()); + v_dram_block_window_tmp, Policy::template MakeVDramTileDistribution()); auto v_lds = make_tensor_view( reinterpret_cast(static_cast(smem_ptr) + @@ -319,14 +294,14 @@ struct BlockFmhaFwdDecodePipelineQRKSVS Policy::template MakeVRegTileDistribution()); // Bias tile in LDS - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N - Policy::template MakeBiasDramTileDistribution()); + // const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + // auto bias_dram_window = + // make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + // bias_dram_block_window_tmp.get_window_lengths(), + // {bias_origin.at(number<0>{}), + // logical_seqlen_k_start - (physical_seqlen_k_start - + // aligned_physical_seqlen_k_start)}, // M/N + // Policy::template MakeBiasDramTileDistribution()); block_sync_lds_direct_load<0>(); auto q_tile = load_tile(q_lds_read_window); @@ -352,17 +327,17 @@ struct BlockFmhaFwdDecodePipelineQRKSVS // STAGE 1, QK gemm clear_tile(s_acc); // initialize C - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - __builtin_amdgcn_sched_barrier( - 0); // prevent from messing up the order of global loads - } + // if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + // { + // __builtin_amdgcn_sched_barrier( + // 0); // prevent from messing up the order of global loads + // } + // const auto bias_tile = load_tile(bias_dram_window); // load bias tile + // if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + // { + // __builtin_amdgcn_sched_barrier( + // 0); // prevent from messing up the order of global loads + // } block_sync_lds(); async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile @@ -379,105 +354,74 @@ struct BlockFmhaFwdDecodePipelineQRKSVS sequence{}), k_tile); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + // // STAGE 2, scale_s, add bias, mask, softmax + // if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + // { + // tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + // tile_elementwise_inout( + // [&](auto& x, const auto& y) { + // x += log2e_v * + // type_convert(bias_element_func(y)); + // }, + // s_acc, + // bias_tile); + // } + // else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + // { + // const auto k_origin = k_page_block_navigator.to_global_window_origin( + // i_page_block_k, k_dram_block_window.get_window_origin()); + // constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + // sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + // const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + // const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); - s_acc(i_j_idx) *= scale_s; - // position_encoding accept only logical coordinates, do conversion here - position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset); - }); - }); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { - x = variant.LogitsTransform(variant_params, - variant.QueryTransform(variant_params, x), - block_indices.batch_idx, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }; -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); + // s_acc(i_j_idx) *= scale_s; + // // position_encoding accept only logical coordinates, do conversion here + // position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset); + // }); + // }); + // } + // move_tile_window(bias_dram_window, {0, kN0}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); - set_tile_if( - s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - if constexpr(kIsPagedKV) - { - return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; - } - else - { - return physical_seqlen_k_end_ <= col; - } - }); + if(i_total_loops == (num_total_loop - 1)) + { + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + if constexpr(kIsPagedKV) + { + return col < physical_seqlen_k_start_ || + physical_seqlen_k_end_ <= col; + } + else + { + return physical_seqlen_k_end_ <= col; + } + }); + } } if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); + const auto k_origin = make_tuple(kN0 * i_total_loops, 0); + // const auto k_origin = k_page_block_navigator.to_global_window_origin( + // i_page_block_k, k_dram_block_window.get_window_origin()); // mask accept only logical coordinates, do conversion here bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}) - kv_l2p_offset, + k_origin.at(number<0>{}), number{}, number{}); if(need_perpixel_check) @@ -486,17 +430,13 @@ struct BlockFmhaFwdDecodePipelineQRKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); + return mask.IsOutOfBound(row, col); }); } } // move K tile windows after current status checked - i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); - - k_dram_window = make_tile_window(k_dram_block_window, - Policy::template MakeKDramTileDistribution()); + move_tile_window(k_dram_window, {kN0, 0}); block_sync_lds(); async_load_tile(k_lds_write_window, k_dram_window); @@ -550,12 +490,9 @@ struct BlockFmhaFwdDecodePipelineQRKSVS constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif + auto row_max = scale_s * get_validated_m(m[i_idx]); sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -572,9 +509,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); } } -#else - p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx])); -#endif }); }); @@ -591,8 +525,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVS constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { + const auto tmp = [&]() { if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -611,9 +544,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS } } }(); -#else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); -#endif l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); @@ -644,7 +574,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans(); sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) { @@ -661,9 +590,6 @@ struct BlockFmhaFwdDecodePipelineQRKSVS lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); } } -#else - lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif }); if(get_thread_local_1d_id() < kM0) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp index a27ce70f9a..ed1dec6cc1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_decode_pipeline_qr_ks_vs_policy.hpp @@ -271,17 +271,22 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1WarpTile>>; - using WarpGemm = - WarpGemmMfmaDispatcher{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), - Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), - true, - false, - false, - WGAttrNumAccessEnum::Double>; + using WarpGemm = WarpGemmMfmaDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + ((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) || + (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 && + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16)) + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy