mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] MX Flatmm Split kernel instances (#3207)
* [CK_TILE] MX Flatmm Split kernel instances
* Fix flatmm example compile
This commit is contained in:
@@ -23,22 +23,28 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
template <typename RunFunction>
|
||||
|
||||
// TailHandler overload whose tail number is a template argument.
//
// Invokes run_func with two tag arguments: a bool_constant for the hot-loop
// flag and integral_constant<TailNumber, tail_num> for the tail.
//
// Template parameters:
//   DispatchHotloop - when false, has_hot_loop is ignored and
//                     bool_constant<true> is always passed; when true, the
//                     runtime has_hot_loop value picks bool_constant<true> or
//                     bool_constant<false>.
//   tail_num        - tail number forwarded to run_func as a compile-time tag.
//   RunFunction     - callable accepting the two tag arguments.
//
// NOTE(review): tail_num cannot be deduced from the function arguments, so a
// caller has to spell out DispatchHotloop as well; the `= false` default looks
// unusable for this overload - confirm.
// NOTE(review): the return type is deduced (`auto`); with DispatchHotloop ==
// true both runtime branches must return the same type from run_func.
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
// No hot-loop dispatch requested: always take the bool_constant<true> path.
if constexpr(!DispatchHotloop)
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, tail_num>{});
|
||||
// Hot-loop dispatch requested: lift the runtime flag into a compile-time tag.
else if(has_hot_loop)
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, tail_num>{});
|
||||
else
|
||||
return run_func(bool_constant<false>{}, integral_constant<TailNumber, tail_num>{});
|
||||
}
|
||||
|
||||
// TailHandler overload taking the tail number as a runtime value.
//
// Compares tail_num against TailNumber::Even and TailNumber::Odd and invokes
// run_func with the matching compile-time tag.
//
// NOTE(review): this span appears to be a diff hunk rendered without +/-
// markers, so lines from before and after the change are interleaved. The two
// declarator lines (one with an unnamed bool, one naming it has_hot_loop) and
// the paired branch bodies (a braced direct run_func call next to a one-line
// forward to TailHandler<DispatchHotloop, ...>) look like the old and new
// versions of the same code - confirm against the repository before treating
// any single line as live.
template <bool DispatchHotloop = false, typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
// Declarator variant with the hot-loop flag left unnamed (presumably pre-change).
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
// Declarator variant that names has_hot_loop so it can be forwarded (presumably post-change).
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_num)
|
||||
{
|
||||
if(TailNumber::Even == tail_num)
|
||||
// Braced form: calls run_func directly, always with bool_constant<true>.
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
// One-line form: forwards to the TailHandler call that takes the tail as a template argument.
return TailHandler<DispatchHotloop, TailNumber::Even>(run_func, has_hot_loop);
|
||||
else if(TailNumber::Odd == tail_num)
|
||||
// Braced form: calls run_func directly, always with bool_constant<true>.
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
// Commented-out call using TailNumber::Empty, kept as found.
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
|
||||
// TailNumber::Empty>{});
|
||||
// One-line form: forwards to the TailHandler call that takes the tail as a template argument.
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
|
||||
// Any other tail number is reported via assert only.
// NOTE(review): this branch has no return statement; if assert is compiled
// out and run_func returns non-void, control would fall off the end of a
// value-returning function - confirm whether that path is reachable.
else
|
||||
assert(("Wrong TailNumber!", false));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -216,17 +216,14 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
// Returns sizeof(Problem::ADataType) multiplied by the element-space size of
// the descriptor produced by MakeALdsBlockDescriptor<Problem>().
// Presumably this is the byte footprint of the A tile in LDS/shared memory -
// inferred from the names only; confirm against the descriptor definition.
//
// NOTE(review): this span appears to be a diff hunk rendered without +/-
// markers. The named-constant form and the direct-return form below compute
// the same expression and look like the old and new versions of the body.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
// Form 1: compute into a named constexpr local, then return it.
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
// Form 2: return the same product directly.
return sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
}
|
||||
|
||||
// Returns the value of GetSmemSizeA<Problem>(); no other contribution is added
// in the code shown here.
//
// NOTE(review): this span appears to be a diff hunk rendered without +/-
// markers. The named-constant form and the direct-return form below are
// equivalent and look like the old and new versions of the body.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// Form 1: store the A size in a named constexpr local, then return it.
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
|
||||
return smem_size_a;
|
||||
// Form 2: return the A size directly.
return GetSmemSizeA<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user