mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
tempsave, fmha_decode
This commit is contained in:
@@ -823,7 +823,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< (is_rotary_interleaved ? "inter" : "half") << ")";
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_DECODE_API
|
||||
if(1 < num_splits)
|
||||
{
|
||||
std::cout << ", num_splits:" << num_splits;
|
||||
@@ -1048,7 +1048,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> || std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
|
||||
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
|
||||
@@ -1103,7 +1103,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
|
||||
}
|
||||
#endif
|
||||
#elif CK_TILE_FMHA_FWD_DECODE_API
|
||||
fmha_fwd_decode_traits fmha_decode_traits;
|
||||
init_traits(fmha_decode_traits);
|
||||
|
||||
fmha_fwd_decode_args fmha_decode_args;
|
||||
init_args(fmha_decode_args);
|
||||
|
||||
return fmha_fwd_decode(fmha_decode_traits, fmha_decode_args, stream_config);
|
||||
#else
|
||||
fmha_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
@@ -1111,6 +1119,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
init_args(fmha_args);
|
||||
|
||||
return fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
#endif
|
||||
}();
|
||||
|
||||
if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
|
||||
|
||||
Reference in New Issue
Block a user