tempsave, fmha_decode

This commit is contained in:
aska-0096
2025-07-08 08:37:20 +00:00
parent 47565f21a5
commit 18686cfe5b
13 changed files with 2562 additions and 71 deletions

View File

@@ -823,7 +823,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< (is_rotary_interleaved ? "inter" : "half") << ")";
}
#endif
-#if CK_TILE_FMHA_FWD_SPLITKV_API
+#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_DECODE_API
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
@@ -1048,7 +1048,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
}
-else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
+else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>> || std::is_same_v<fmha_fwd_decode_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
@@ -1103,7 +1103,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
}
#endif
#elif CK_TILE_FMHA_FWD_DECODE_API
fmha_fwd_decode_traits fmha_decode_traits;
init_traits(fmha_decode_traits);
fmha_fwd_decode_args fmha_decode_args;
init_args(fmha_decode_args);
return fmha_fwd_decode(fmha_decode_traits, fmha_decode_args, stream_config);
#else
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
@@ -1111,6 +1119,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
init_args(fmha_args);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
#endif
}();
if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)