mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Use generic lambda to init traits objects
This commit is contained in:
@@ -862,6 +862,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
traits.do_fp8_static_quant = squant;
|
||||
|
||||
if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
}
|
||||
};
|
||||
|
||||
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -988,29 +1005,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || 0 < page_block_size)
|
||||
{
|
||||
auto fmha_splitkv_traits = fmha_fwd_splitkv_traits{hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
is_v_rowmajor,
|
||||
mask.type,
|
||||
bias.type,
|
||||
lse,
|
||||
squant};
|
||||
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
||||
init_traits(fmha_splitkv_traits);
|
||||
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_args, stream_config);
|
||||
}
|
||||
#endif
|
||||
auto fmha_traits = fmha_fwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
is_v_rowmajor,
|
||||
mask.type,
|
||||
bias.type,
|
||||
lse,
|
||||
p_drop > 0.0f,
|
||||
squant};
|
||||
fmha_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
}();
|
||||
|
||||
Reference in New Issue
Block a user