[CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222)

* Add TailHandler for V3, V4 and Mem pipelines

* Adapt examples and tests to use TailHandler

* Move tail-handling logic to pipeline in persistent grouped gemm

* Fix Mem pipeline dispatching, add CompV4 dispatching

* Use a macro for handling the many tails of Mem pipeline

* Fix formatting again

* Use const-ref RunFunction, remove unnecessary try_run
This commit is contained in:
Sami Remes
2025-06-04 11:50:21 +03:00
committed by GitHub
parent ffb52783d0
commit 7ea1508b59
10 changed files with 234 additions and 553 deletions

View File

@@ -50,6 +50,50 @@ struct BaseGemmPipelineAgBgCrCompV3
}
}
}
/**
 * @brief Turn the runtime (has_hot_loop, tail_number) pair into compile-time constants and
 *        invoke @p run_func with them.
 *
 * The combinations dispatched here are:
 *   - has_hot_loop == true  with TailNumber::Full
 *   - has_hot_loop == false with TailNumber::Odd or TailNumber::Even
 *
 * @tparam RunFunction Callable invoked as
 *         run_func(bool_constant<HasHotLoop>{}, integral_constant<TailNumber, Tail>{}).
 * @param run_func     Callable that receives the selected compile-time constants.
 * @param has_hot_loop Runtime hot-loop flag.
 * @param tail_number  Runtime tail number.
 * @return The value returned by @p run_func for the matching combination.
 * @throws std::logic_error When compiled without __HIP_DEVICE_COMPILE__ and the combination is
 *         not one of those listed above. When __HIP_DEVICE_COMPILE__ is defined that path is
 *         marked unreachable instead.
 */
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
    // With a hot loop, the full tail is the only combination that is dispatched.
    if(has_hot_loop && tail_number == TailNumber::Full)
    {
        return run_func(bool_constant<true>{},
                        integral_constant<TailNumber, TailNumber::Full>{});
    }

    // Without a hot loop, the tail is either odd or even.
    if(!has_hot_loop && tail_number == TailNumber::Odd)
    {
        return run_func(bool_constant<false>{},
                        integral_constant<TailNumber, TailNumber::Odd>{});
    }
    if(!has_hot_loop && tail_number == TailNumber::Even)
    {
        return run_func(bool_constant<false>{},
                        integral_constant<TailNumber, TailNumber::Even>{});
    }

#if defined(__HIP_DEVICE_COMPILE__)
    // No valid combination matched; device code treats this as unreachable.
    __builtin_unreachable();
#else
    // Nothing matched: report which constraint the caller violated.
    throw std::logic_error(
        has_hot_loop ? "Invalid TailNumber: If has_hot_loop is true, tail_number must "
                       "be TailNumber::Full."
                     : "Invalid TailNumber: If has_hot_loop is false, tail_number must "
                       "be TailNumber::Odd or TailNumber::Even.");
#endif
}
};
// Compute optimized pipeline
@@ -556,6 +600,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
p_smem);
}
/**
 * @brief Run the pipeline with a hot-loop flag and tail number that are only known at runtime.
 *
 * The pipeline call is wrapped in a generic lambda and handed to Base::TailHandler, which
 * selects the compile-time <hot loop, tail number> instantiation to execute.
 *
 * @note This is used by the persistent gemm kernel variants that don't determine
 *       hot loop and tail number on the host side, e.g. grouped gemm kernel.
 */
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               bool has_hot_loop,
                               TailNumber tail_number,
                               void* p_smem) const
{
    const auto run_specialized = [&](auto has_hot_loop_c, auto tail_number_c) {
        // Recover the compile-time values carried by the constant wrapper types.
        constexpr bool kHasHotLoop = decltype(has_hot_loop_c)::value;
        constexpr auto kTailNumber = decltype(tail_number_c)::value;

        // Functor that returns its argument as-is; passed for both the A and B operands.
        constexpr auto identity = [](const auto& v) { return v; };

        return PipelineImpl<Scheduler>{}.template operator()<kHasHotLoop, kTailNumber>(
            a_dram_block_window_tmp,
            identity,
            b_dram_block_window_tmp,
            identity,
            num_loop,
            p_smem);
    };

    return Base::TailHandler(run_specialized, has_hot_loop, tail_number);
}
/**
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
* SplitK.
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,

View File

@@ -34,6 +34,46 @@ struct BaseGemmPipelineAgBgCrCompV4
return TailNumber::Two;
}
}
/**
 * @brief Turn the runtime (has_hot_loop, tail_number) pair into compile-time constants and
 *        invoke @p run_func with them.
 *
 * Both hot-loop states are dispatched, but only for TailNumber::Three and TailNumber::Two.
 *
 * @tparam RunFunction Callable invoked as
 *         run_func(bool_constant<HasHotLoop>{}, integral_constant<TailNumber, Tail>{}).
 * @param run_func     Callable that receives the selected compile-time constants.
 * @param has_hot_loop Runtime hot-loop flag.
 * @param tail_number  Runtime tail number; must be TailNumber::Three or TailNumber::Two.
 * @return The value returned by @p run_func for the matching combination.
 * @throws std::logic_error When compiled without __HIP_DEVICE_COMPILE__ and @p tail_number is
 *         neither TailNumber::Three nor TailNumber::Two. When __HIP_DEVICE_COMPILE__ is defined
 *         that path is marked unreachable instead.
 */
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
    // Handle all the valid cases.
    if(has_hot_loop)
    {
        if(tail_number == TailNumber::Three)
        {
            return run_func(bool_constant<true>{},
                            integral_constant<TailNumber, TailNumber::Three>{});
        }
        else if(tail_number == TailNumber::Two)
        {
            return run_func(bool_constant<true>{},
                            integral_constant<TailNumber, TailNumber::Two>{});
        }
    }
    else
    {
        if(tail_number == TailNumber::Three)
        {
            return run_func(bool_constant<false>{},
                            integral_constant<TailNumber, TailNumber::Three>{});
        }
        else if(tail_number == TailNumber::Two)
        {
            return run_func(bool_constant<false>{},
                            integral_constant<TailNumber, TailNumber::Two>{});
        }
    }

    // If execution reaches here, it's an invalid tail_number because it wasn't handled above.
#if defined(__HIP_DEVICE_COMPILE__)
    __builtin_unreachable();
#else
    // The message names the tail numbers this handler actually dispatches.
    throw std::logic_error(
        "Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported.");
#endif
}
};
/**
@@ -572,5 +612,30 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
p_smem_0,
p_smem_1);
}
/**
 * @brief Run the pipeline with a hot-loop flag and tail number that are only known at runtime.
 *
 * The pipeline call is wrapped in a generic lambda and handed to Base::TailHandler, which
 * selects the compile-time <hot loop, tail number> instantiation to execute. Both buffer
 * pointers are forwarded unchanged to the pipeline implementation.
 */
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               bool has_hot_loop,
                               TailNumber tail_number,
                               void* __restrict__ p_smem_0,
                               void* __restrict__ p_smem_1) const
{
    const auto run_specialized = [&](auto has_hot_loop_c, auto tail_number_c) {
        // Recover the compile-time values carried by the constant wrapper types.
        constexpr bool kHasHotLoop = decltype(has_hot_loop_c)::value;
        constexpr auto kTailNumber = decltype(tail_number_c)::value;

        // Functor that returns its argument as-is; passed for both the A and B operands.
        constexpr auto identity = [](const auto& v) { return v; };

        return PipelineImpl<Scheduler>{}.template operator()<kHasHotLoop, kTailNumber>(
            a_dram_block_window_tmp,
            identity,
            b_dram_block_window_tmp,
            identity,
            num_loop,
            p_smem_0,
            p_smem_1);
    };

    return Base::TailHandler(run_specialized, has_hot_loop, tail_number);
}
};
} // namespace ck_tile

View File

@@ -52,13 +52,14 @@ struct BaseGemmPipelineAgBgCrMem
static constexpr index_t LocalPrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
@@ -93,6 +94,56 @@ struct BaseGemmPipelineAgBgCrMem
return TailNumber::Full;
}
}
/**
 * @brief Turn the runtime (has_hot_loop, tail_number) pair into compile-time constants and
 *        invoke @p run_func with them.
 *
 * TailNumber::One and TailNumber::Full are always dispatched. TailNumber::Two through
 * TailNumber::Seven are dispatched only when PrefetchStages is strictly greater than the
 * corresponding count (2 through 7). Either hot-loop state is accepted for every tail.
 *
 * @tparam RunFunction Callable invoked as
 *         run_func(bool_constant<HasHotLoop>{}, integral_constant<TailNumber, Tail>{}).
 * @param run_func     Callable that receives the selected compile-time constants.
 * @param has_hot_loop Runtime hot-loop flag.
 * @param tail_number  Runtime tail number.
 * @return The value returned by @p run_func for the matching combination.
 * @throws std::logic_error When compiled without __HIP_DEVICE_COMPILE__ and no case matched.
 *         When __HIP_DEVICE_COMPILE__ is defined that path is marked unreachable instead.
 */
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
{
    // Given a tail constant that is already fixed at compile time, resolve the runtime
    // hot-loop flag into the matching bool_constant and forward both to run_func.
    const auto with_hot_loop = [&](auto tail_constant) {
        if(has_hot_loop)
        {
            return run_func(bool_constant<true>{}, tail_constant);
        }
        return run_func(bool_constant<false>{}, tail_constant);
    };

    // Tails that are dispatched regardless of PrefetchStages.
    if(tail_number == TailNumber::One)
    {
        return with_hot_loop(integral_constant<TailNumber, TailNumber::One>{});
    }
    if(tail_number == TailNumber::Full)
    {
        return with_hot_loop(integral_constant<TailNumber, TailNumber::Full>{});
    }

    // Each use expands to a guard for one tail number. The dispatch is only instantiated when
    // PrefetchStages exceeds STAGE_COUNT; otherwise a match falls through to the error path.
#define CHECK_TAIL_NUMBER(TAIL_NAME, STAGE_COUNT)                                        \
    if(tail_number == TailNumber::TAIL_NAME)                                             \
    {                                                                                    \
        if constexpr(PrefetchStages > STAGE_COUNT)                                       \
        {                                                                                \
            return with_hot_loop(integral_constant<TailNumber, TailNumber::TAIL_NAME>{}); \
        }                                                                                \
    }

    CHECK_TAIL_NUMBER(Two, 2)
    CHECK_TAIL_NUMBER(Three, 3)
    CHECK_TAIL_NUMBER(Four, 4)
    CHECK_TAIL_NUMBER(Five, 5)
    CHECK_TAIL_NUMBER(Six, 6)
    CHECK_TAIL_NUMBER(Seven, 7)
#undef CHECK_TAIL_NUMBER

    // Reaching this point means the tail number was not dispatched for this PrefetchStages.
#if defined(__HIP_DEVICE_COMPILE__)
    __builtin_unreachable();
#else
    throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
                           "PrefetchStages are supported.");
#endif
}
};
// Maximum Global Memory throughput pipeline with >=32KB data in fly
@@ -749,6 +800,29 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
p_smem);
}
/**
 * @brief Run the pipeline with a hot-loop flag and tail number that are only known at runtime.
 *
 * The pipeline call is wrapped in a generic lambda and handed to Base::TailHandler, which
 * selects the compile-time <hot loop, tail number> instantiation to execute.
 */
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                               const BDramBlockWindowTmp& b_dram_block_window_tmp,
                               index_t num_loop,
                               bool has_hot_loop,
                               TailNumber tail_number,
                               void* p_smem) const
{
    const auto run_specialized = [&](auto has_hot_loop_c, auto tail_number_c) {
        // Recover the compile-time values carried by the constant wrapper types.
        constexpr bool kHasHotLoop = decltype(has_hot_loop_c)::value;
        constexpr auto kTailNumber = decltype(tail_number_c)::value;

        // Functor that returns its argument as-is; passed for both the A and B operands.
        constexpr auto identity = [](const auto& v) { return v; };

        return PipelineImpl<Scheduler>{}.template operator()<kHasHotLoop, kTailNumber>(
            a_dram_block_window_tmp,
            identity,
            b_dram_block_window_tmp,
            identity,
            num_loop,
            p_smem);
    };

    return Base::TailHandler(run_specialized, has_hot_loop, tail_number);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,