mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add interwave scheduler for gemm mem pipeline (#1647)
* add interwave scheduler for gemm mem pipeline * Fix merge artifacts. * Refactor unit tests. * Switch to interwave scheduler for mem example --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
@@ -322,6 +322,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
@@ -374,6 +375,229 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Interwave>
|
||||
{
|
||||
template <typename DstBlockTile, typename SrcTileWindow>
|
||||
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
|
||||
SrcTileWindow& dram_tile_window) const
|
||||
{
|
||||
load_tile(dst_block_tile, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, {0, KPerBlock});
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile,
|
||||
const ElementFunction& element_func) const
|
||||
{
|
||||
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
|
||||
store_tile(lds_tile_window, block_tile_tmp);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
NPerBlock ==
|
||||
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
|
||||
" or KPerBlock!");
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
|
||||
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
|
||||
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
|
||||
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
|
||||
});
|
||||
|
||||
// main body
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
|
||||
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window);
|
||||
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
} while(i < (num_loop - PrefetchStages));
|
||||
}
|
||||
|
||||
auto HotLoopTail = [&](auto tail_num) {
|
||||
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
{
|
||||
HotLoopTail(number<2>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Three)
|
||||
{
|
||||
HotLoopTail(number<3>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Four)
|
||||
{
|
||||
HotLoopTail(number<4>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Five)
|
||||
{
|
||||
HotLoopTail(number<5>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Six)
|
||||
{
|
||||
HotLoopTail(number<6>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Seven)
|
||||
{
|
||||
HotLoopTail(number<7>{});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
HotLoopTail(number<PrefetchStages>{});
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
|
||||
Reference in New Issue
Block a user