diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 367c569769..a7ade264f5 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -78,6 +78,23 @@ struct has_naive_hdim_load_flag< template static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag::value; +// A helper struct for detecting kUseTrLoad +template +struct has_use_trload_flag : std::false_type +{ +}; + +template +struct has_use_trload_flag< + T, + std::enable_if_t && T::kUseTrLoad>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_using_trload_v = has_use_trload_flag::value; + }; // namespace detail template @@ -121,13 +138,8 @@ struct FmhaFwdKernel static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + static constexpr bool kUseTrLoad = detail::is_using_trload_v; - static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad; -#if defined(__gfx950__) - static constexpr bool kIsAvailable = true; -#else - static constexpr bool kIsAvailable = !kUseTrLoad; -#endif static constexpr std::string_view kPipelineName = FmhaPipeline::name; // clang-format off @@ -1177,11 +1189,7 @@ struct FmhaFwdKernel return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - if constexpr(kIsAvailable) - run_(std::move(kargs)); - } + CK_TILE_DEVICE void operator()(Kargs kargs) const { run_(std::move(kargs)); } CK_TILE_DEVICE void run_(Kargs kargs) const { @@ -1346,10 +1354,10 @@ struct FmhaFwdKernel number<1>{}); if constexpr(FmhaPipeline::kQLoadOnce) { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); } else { @@ -1371,10 +1379,10 @@ struct FmhaFwdKernel if
constexpr(detail::is_n0loop_pipeline_v) { - return pad_tensor_view(k_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); } else { @@ -1439,8 +1447,7 @@ struct FmhaFwdKernel q_dram, [&]() { if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); + return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(), @@ -1449,10 +1456,10 @@ struct FmhaFwdKernel auto k_dram_window = [&]() { if constexpr(detail::is_n0loop_pipeline_v) { - return make_tile_window(k_dram, - make_tuple(number{}, - number{}), - {0, 0}); + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); } else {