Use is_using_trload_v to check kUseTrLoad from the pipeline

This commit is contained in:
Qianfeng Zhang
2025-12-20 10:23:30 +00:00
parent eb598a9d1e
commit 57abd10b95

View File

@@ -78,6 +78,23 @@ struct has_naive_hdim_load_flag<
// Convenience variable template: true iff T satisfies has_naive_hdim_load_flag
// (trait defined above; detects the pipeline's naive-hdim-load flag).
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
// A helper struct for detecting kUseTrLoad.
// Primary template: yields std::false_type when T has no usable kUseTrLoad
// member (or when the SFINAE condition of the specialization below fails).
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{
};
// SFINAE specialization: selected only when T::kUseTrLoad exists, is
// convertible to bool, AND evaluates to true at compile time. Consequently
// "flag absent" and "flag == false" both report false — callers cannot
// distinguish the two, which is intentional here.
template <typename T>
struct has_use_trload_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
: std::true_type
{
};
// Convenience variable template: true iff T declares a truthy kUseTrLoad flag.
template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
}; // namespace detail
template <typename FmhaPipeline_, typename EpiloguePipeline_>
@@ -121,13 +138,8 @@ struct FmhaFwdKernel
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
#if defined(__gfx950__)
static constexpr bool kIsAvailable = true;
#else
static constexpr bool kIsAvailable = !kUseTrLoad;
#endif
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
// clang-format off
@@ -1177,11 +1189,7 @@ struct FmhaFwdKernel
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(kIsAvailable)
run_(std::move(kargs));
}
CK_TILE_DEVICE void operator()(Kargs kargs) const { run_(std::move(kargs)); }
CK_TILE_DEVICE void run_(Kargs kargs) const
{
@@ -1346,10 +1354,10 @@ struct FmhaFwdKernel
number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
@@ -1371,10 +1379,10 @@ struct FmhaFwdKernel
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return pad_tensor_view(k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{},
number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
{
@@ -1439,8 +1447,7 @@ struct FmhaFwdKernel
q_dram,
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<kQKHeaddimToUse>{});
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
@@ -1449,10 +1456,10 @@ struct FmhaFwdKernel
auto k_dram_window = [&]() {
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(k_dram,
make_tuple(number<FmhaPipeline::kN0Sub>{},
number<kQKHeaddimToUse>{}),
{0, 0});
return make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
{0, 0});
}
else
{