mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] Refactor base pipeline usage (#3251)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder
This commit is contained in:
@@ -19,12 +19,12 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
@@ -158,9 +158,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
@@ -539,14 +537,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -557,14 +562,21 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -154,10 +154,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop =
|
||||
Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop);
|
||||
static constexpr auto TailNum =
|
||||
Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop);
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
@@ -641,13 +637,20 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -700,13 +703,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
return operator()(a_dram_block_window_tmp,
|
||||
b_dram_block_window_tmp,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
|
||||
@@ -167,9 +167,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
@@ -685,14 +683,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
void* p_smem_0,
|
||||
void* p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
@@ -706,14 +711,21 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
|
||||
@@ -92,9 +92,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
@@ -404,13 +402,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
@@ -423,13 +428,20 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
|
||||
@@ -22,12 +22,12 @@ struct BaseGemmPipelineAgBgCrCompV6
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % HotloopUnroll == 1)
|
||||
{
|
||||
@@ -153,9 +153,7 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
|
||||
@@ -173,11 +171,9 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
return concat('_', "pipeline_AgBgCrCompV6", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
concat('x', TailNum),
|
||||
concat('_', KRepeat),
|
||||
concat('_', DoubleSmemBuffer),
|
||||
concat('_', Preshuffle),
|
||||
concat('_', HasHotLoop));
|
||||
concat('_', Preshuffle));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -725,13 +721,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
@@ -744,13 +747,20 @@ struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
|
||||
@@ -206,10 +206,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
// Where is the right place for HasHotLoop and TailNum ???
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
@@ -887,13 +884,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
@@ -933,13 +937,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
|
||||
@@ -224,8 +224,6 @@ template <typename AsDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
@@ -296,8 +294,6 @@ struct UniversalGemmPipelineProblem
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
|
||||
@@ -148,7 +148,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
#ifdef __gfx942__
|
||||
static constexpr index_t mfma_per_wg = 2;
|
||||
@@ -1042,13 +1041,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const ADataType& a) { return a; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp[number<0>{}],
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
@@ -1063,13 +1069,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const ADataType& a) { return a; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
|
||||
Reference in New Issue
Block a user