mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Using is_using_trload_v to check the kUseTrLoad from pipeline
This commit is contained in:
@@ -78,6 +78,23 @@ struct has_naive_hdim_load_flag<
|
||||
template <typename T>
|
||||
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
|
||||
|
||||
// A helper struct for detecting kUseTrLoad
|
||||
// Compile-time probe for a pipeline's kUseTrLoad flag.
//
// Primary template: the fallback chosen whenever the specialization below is
// discarded by SFINAE, i.e. when T has no member named kUseTrLoad, when that
// member is not convertible to bool, or when it is not a constant expression
// that evaluates to true. Reports false.
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{
};

// Specialization selected only when T::kUseTrLoad is well-formed, convertible
// to bool, and true. Because the flag's value is part of the enable_if
// condition, a type that declares kUseTrLoad = false is reported exactly like
// a type that does not declare the flag at all.
template <typename T>
struct has_use_trload_flag<
    T,
    std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
    : std::true_type
{
};

// Shorthand for has_use_trload_flag<T>::value: true iff T exposes a kUseTrLoad
// flag that is set; false for every other type (no hard error when the member
// is missing).
template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
@@ -121,13 +138,8 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
|
||||
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
|
||||
static constexpr bool kUseTrLoad = detail::is_using_trload_v<FmhaPipeline>;
|
||||
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvailable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvailable = !kUseTrLoad;
|
||||
#endif
|
||||
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
|
||||
|
||||
// clang-format off
|
||||
@@ -1177,11 +1189,7 @@ struct FmhaFwdKernel
|
||||
return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if constexpr(kIsAvailable)
|
||||
run_(std::move(kargs));
|
||||
}
|
||||
// Call operator (declared with CK_TILE_DEVICE): takes the kernel arguments by
// value and moves them straight into run_(); this wrapper does nothing else.
CK_TILE_DEVICE void operator()(Kargs kargs) const { run_(std::move(kargs)); }
|
||||
|
||||
CK_TILE_DEVICE void run_(Kargs kargs) const
|
||||
{
|
||||
@@ -1346,10 +1354,10 @@ struct FmhaFwdKernel
|
||||
number<1>{});
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1371,10 +1379,10 @@ struct FmhaFwdKernel
|
||||
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{},
|
||||
number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1439,8 +1447,7 @@ struct FmhaFwdKernel
|
||||
q_dram,
|
||||
[&]() {
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
return make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<kQKHeaddimToUse>{});
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<kQKHeaddimToUse>{});
|
||||
else
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
|
||||
}(),
|
||||
@@ -1449,10 +1456,10 @@ struct FmhaFwdKernel
|
||||
auto k_dram_window = [&]() {
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return make_tile_window(k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{},
|
||||
number<kQKHeaddimToUse>{}),
|
||||
{0, 0});
|
||||
return make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{}, number<kQKHeaddimToUse>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user