Use generic lambda to init traits objects

This commit is contained in:
PoYen, Chen
2024-08-08 16:38:17 +00:00
parent 2f42e4460f
commit 677d9b28dd

View File

@@ -862,6 +862,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
return ck_tile::identity{};
}();
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_group_mode = (mode == mode_enum::group);
traits.is_v_rowmajor = is_v_rowmajor;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
{
traits.has_dropout = (p_drop > 0.0f);
}
};
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -988,29 +1005,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || 0 < page_block_size)
{
auto fmha_splitkv_traits = fmha_fwd_splitkv_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
bias.type,
lse,
squant};
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_args, stream_config);
}
#endif
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
bias.type,
lse,
p_drop > 0.0f,
squant};
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();