[CK_TILE] MX Flatmm Split kernel instances (#3207)

* [CK_TILE] MX Flatmm Split kernel instances

* Fix flatmm example compile
This commit is contained in:
Yi DING
2025-11-18 13:46:30 +08:00
committed by GitHub
parent 92498464f6
commit b6720531de
12 changed files with 371 additions and 276 deletions

View File

@@ -23,22 +23,28 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename RunFunction>
// Invokes `run_func` with a compile-time hot-loop flag and a compile-time tail tag.
//
// Template parameters:
//   DispatchHotloop - when false, the hot-loop flag passed to `run_func` is always
//                     bool_constant<true>; when true, the flag mirrors `has_hot_loop`.
//   tail_num        - tail kind, forwarded as integral_constant<TailNumber, tail_num>.
// Parameters:
//   run_func     - callable taking (bool_constant<...>, integral_constant<TailNumber, ...>).
//   has_hot_loop - runtime hot-loop indicator; only read when DispatchHotloop is true.
// Returns whatever `run_func` returns.
template <bool DispatchHotloop = false, TailNumber tail_num, typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
{
    using TailTag = integral_constant<TailNumber, tail_num>;

    if constexpr(DispatchHotloop)
    {
        // Turn the runtime flag into a compile-time tag for the callee.
        if(has_hot_loop)
        {
            return run_func(bool_constant<true>{}, TailTag{});
        }
        return run_func(bool_constant<false>{}, TailTag{});
    }
    else
    {
        // No hot-loop dispatch requested: always use the `true` tag.
        return run_func(bool_constant<true>{}, TailTag{});
    }
}
// Runtime-tail overload of TailHandler: picks the TailNumber tag that matches the runtime
// `tail_num` value and hands it to `run_func` as an integral_constant.
// NOTE(review): this listing appears to be a rendered diff in which pre-change and
// post-change lines are interleaved without +/- markers (two signature lines and two
// return statements per branch are visible) — confirm against the actual header.
template <bool DispatchHotloop = false, typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_num)
{
if(TailNumber::Even == tail_num)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
// Even tail: forward to the TailHandler overload that takes the tail as a template
// argument, passing DispatchHotloop and has_hot_loop through.
return TailHandler<DispatchHotloop, TailNumber::Even>(run_func, has_hot_loop);
else if(TailNumber::Odd == tail_num)
{
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
// TailNumber::Empty>{});
// Odd tail: same forwarding as the Even case, with TailNumber::Odd.
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
// NOTE(review): the branch below only asserts and contains no return statement, while
// the other branches return the result of run_func — TODO confirm this is intended for
// a function with a deduced return type.
else
assert(("Wrong TailNumber!", false));
}
};

View File

@@ -216,17 +216,14 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
// Returns sizeof(Problem::ADataType) multiplied by the element space size reported by
// MakeALdsBlockDescriptor<Problem>().
// NOTE(review): both the form using a local `smem_size_a` constant and the direct-return
// form appear below; this looks like pre-/post-change diff lines shown together — confirm
// against the actual header.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
return sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
}
// Returns the value of GetSmemSizeA<Problem>(); no other term is added in the code shown.
// NOTE(review): the variant with a local `smem_size_a` constant and the direct-return
// variant both appear below; this looks like pre-/post-change diff lines shown together —
// confirm against the actual header.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
return smem_size_a;
return GetSmemSizeA<Problem>();
}
template <typename Problem>