mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Move GEMM pipeline tail handling logic to pipelines (#2222)
* Add TailHandler for V3, V4 and Mem pipelines
* Adapt examples and tests to use TailHandler
* Move tail-handling logic to pipeline in persistent grouped gemm
* Fix Mem pipeline dispatching, add CompV4 dispatching
* Use a macro for handling the many tails of Mem pipeline
* Fix formatting again
* Use const-ref RunFunction, remove unnecessary try_run
This commit is contained in:
@@ -50,6 +50,50 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Full)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
// This path should be unreachable in device code if tail_number is valid.
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
// If execution reaches here, it's an invalid combination of arguments.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
|
||||
"be TailNumber::Full.");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
|
||||
"be TailNumber::Odd or TailNumber::Even.");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Compute optimized pipeline
|
||||
@@ -556,6 +600,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline by wrapping it with the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
|
||||
@@ -34,6 +34,46 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
return TailNumber::Two;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
|
||||
"PrefetchStages are supported.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -572,5 +612,30 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -52,13 +52,14 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
|
||||
static constexpr index_t LocalPrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = PrefetchStages;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
/// Returns true when num_loop is strictly greater than PrefetchStages.
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return PrefetchStages < num_loop;
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
@@ -93,6 +94,56 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
return TailNumber::Full;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Wrap the hot_loop dispatch first.
|
||||
auto tail_dispatch = [&](auto tail_num_constant) {
|
||||
if(has_hot_loop)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, tail_num_constant);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_func(bool_constant<false>{}, tail_num_constant);
|
||||
}
|
||||
};
|
||||
|
||||
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
|
||||
else if(tail_number == TailNumber::TAIL_NUMBER) \
|
||||
{ \
|
||||
if constexpr(PrefetchStages > PREFETCH_VALUE) \
|
||||
{ \
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::TAIL_NUMBER>{}); \
|
||||
} \
|
||||
}
|
||||
// Handle all the valid cases.
|
||||
if(tail_number == TailNumber::One)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::One>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Full)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
CHECK_TAIL_NUMBER(Two, 2)
|
||||
CHECK_TAIL_NUMBER(Three, 3)
|
||||
CHECK_TAIL_NUMBER(Four, 4)
|
||||
CHECK_TAIL_NUMBER(Five, 5)
|
||||
CHECK_TAIL_NUMBER(Six, 6)
|
||||
CHECK_TAIL_NUMBER(Seven, 7)
|
||||
#undef CHECK_TAIL_NUMBER
|
||||
|
||||
// We shouldn't get here unless we have a tail number larger than the prefetch stages.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
|
||||
"PrefetchStages are supported.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Maximum Global Memory throughput pipeline with >=32KB data in fly
|
||||
@@ -749,6 +800,29 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
|
||||
Reference in New Issue
Block a user