diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 8e44a54133..f47d7d79d3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -127,39 +127,25 @@ struct BlockFmhaPipelineQSKSVS typename OAccElementFunction, typename PositionEncoding> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp & q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction & - q_element_func, - const KDramBlockWindowTmp & - k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction & - k_element_func, - const VDramBlockWindowTmp & - v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction & - v_element_func, - const BiasDramBlockWindowTmp & - bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction & - bias_element_func, - RandValDramBlockWindowTmp & - randval_dram_block_window_tmp, - LSEDramBlockWindowTmp & - lse_dram_window_tmp, // M0*1 tile - const LSEElementFunction & - lse_element_func, - const SAccElementFunction & - s_acc_element_func, - const PComputeElementFunction & - p_compute_element_func, - const OAccElementFunction & - o_acc_element_func, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void* smem_ptr, - DropoutType & - dropout) const + DropoutType& dropout) const { static_assert( std::is_same_v> &&