diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 263ea6d48b..c7f1a6d15e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -862,6 +862,23 @@ bool run(const ck_tile::ArgParser& arg_parser) return ck_tile::identity{}; }(); + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_group_mode = (mode == mode_enum::group); + traits.is_v_rowmajor = is_v_rowmajor; + traits.mask_type = mask.type; + traits.bias_type = bias.type; + traits.has_lse = lse; + traits.do_fp8_static_quant = squant; + + if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + } + }; + auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, @@ -988,29 +1005,14 @@ bool run(const ck_tile::ArgParser& arg_parser) #if CK_TILE_FMHA_FWD_SPLITKV_API if(1 < num_splits || 0 < page_block_size) { - auto fmha_splitkv_traits = fmha_fwd_splitkv_traits{hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - is_v_rowmajor, - mask.type, - bias.type, - lse, - squant}; + fmha_fwd_splitkv_traits fmha_splitkv_traits; + init_traits(fmha_splitkv_traits); return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_args, stream_config); } #endif - auto fmha_traits = fmha_fwd_traits{hdim_q, - hdim_v, - data_type, - mode == mode_enum::group, - is_v_rowmajor, - mask.type, - bias.type, - lse, - p_drop > 0.0f, - squant}; + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); return fmha_fwd(fmha_traits, fmha_args, stream_config); }();