[CK TILE] Block universal gemm lds<->vgpr optimizations (#1906)

* [CK TILE] Block universal gemm lds<->vgpr optimizations

* Rebase

* Fixes
This commit is contained in:
Bartłomiej Kocot
2025-02-27 10:36:28 +01:00
committed by GitHub
parent e9ee568681
commit bf1e17007e
7 changed files with 305 additions and 406 deletions

View File

@@ -68,9 +68,10 @@ struct GemmPipelineAgBgCrImplBase
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView>
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
@@ -88,17 +89,21 @@ struct GemmPipelineAgBgCrImplBase
auto a_copy_lds_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_lds_gemm_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ALdsLoadTileDistr{});
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView>
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view,
const BLdsLoadTileDistr&) const
{
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
@@ -117,8 +122,11 @@ struct GemmPipelineAgBgCrImplBase
auto b_copy_lds_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_lds_gemm_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BLdsLoadTileDistr{});
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),

View File

@@ -346,17 +346,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// Block GEMM
auto block_gemm = BlockGemm();

View File

@@ -215,10 +215,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
auto a_windows =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto& a_copy_dram_window = a_windows.at(I0{});
auto& a_copy_lds_window = a_windows.at(I1{});
auto& a_lds_gemm_window = a_windows.at(I2{});
@@ -226,7 +233,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
auto b_windows =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
@@ -493,10 +501,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
auto& a_lds_block = ab_lds_blocks.at(I0{});
auto& b_lds_block = ab_lds_blocks.at(I1{});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeABlockDistributionEncode())){};
constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution(
BlockGemm::MakeBBlockDistributionEncode())){};
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
auto a_windows =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
auto& a_copy_dram_window = a_windows.at(I0{});
auto& a_copy_lds_window = a_windows.at(I1{});
auto& a_lds_gemm_window = a_windows.at(I2{});
@@ -504,7 +519,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
auto b_windows =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});

View File

@@ -125,13 +125,25 @@ struct GemmPipelineAGmemBGmemCRegV1
auto b_copy_lds_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_lds_gemm_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_lds_load_tile_distr);
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto b_lds_gemm_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_lds_load_tile_distr);
// Block GEMM
auto block_gemm = BlockGemm();

View File

@@ -122,17 +122,29 @@ struct GemmPipelineAGmemBGmemCRegV2
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode());
// A LDS tile for block GEMM
auto a_lds_gemm_window =
make_tile_window(a_lds_block,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_lds_load_tile_distr);
// B LDS tile for block GEMM
auto b_lds_gemm_window =
make_tile_window(b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_lds_load_tile_distr);
// Acc register tile
auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};