mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
add vectorloads on non-k dim for memory pipelines (#1856)
This commit is contained in:
@@ -90,7 +90,7 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
|
||||
@@ -165,11 +165,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
|
||||
" or KPerBlock!");
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
@@ -213,25 +224,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
|
||||
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -247,19 +292,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window);
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window);
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -275,12 +346,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
}
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -352,11 +443,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
|
||||
" or KPerBlock!");
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
@@ -400,25 +502,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
|
||||
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
|
||||
Base::GlobalPrefetch(
|
||||
a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
|
||||
}
|
||||
|
||||
// Global prefetch [1, PrefetchStages]
|
||||
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
// main body
|
||||
@@ -432,19 +567,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(
|
||||
a_shuffle_tmp,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
a_copy_lds_window,
|
||||
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(
|
||||
b_shuffle_tmp,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window,
|
||||
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
|
||||
b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_copy_dram_window);
|
||||
a_copy_dram_window,
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_copy_dram_window);
|
||||
b_copy_dram_window,
|
||||
b_dram_tile_window_step);
|
||||
});
|
||||
|
||||
i += PrefetchStages;
|
||||
@@ -457,12 +618,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
// no second block_sync_lds because it's interwave
|
||||
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tiles.get(number<prefetch_idx>{}),
|
||||
a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tiles.get(number<prefetch_idx>{}),
|
||||
b_element_func);
|
||||
}
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -519,7 +519,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user