From de16fbb133897863366b00527673d219441fe872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 27 Feb 2025 10:36:28 +0100 Subject: [PATCH] [CK TILE] Block universal gemm lds<->vgpr optimizations (#1906) * [CK TILE] Block universal gemm lds<->vgpr optimizations * Rebase * Fixes [ROCm/composable_kernel commit: bf1e17007e46e9f0723d66db41a784dbaf340c6a] --- .../block/block_universal_gemm_as_bs_cr.hpp | 573 +++++++----------- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 28 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 28 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 10 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 24 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 20 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 28 +- 7 files changed, 305 insertions(+), 406 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index d9d6739fb5..6024e00419 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -68,16 +68,6 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN; static constexpr index_t KPerBlockPerIter = WarpGemm::kK; - using AWarpTileDistr = remove_cvref_t; - using BWarpTileDistr = remove_cvref_t; - - using AWarpTile = remove_cvref_t( - AWarpTileDistr{}))>; - using BWarpTile = remove_cvref_t( - BWarpTileDistr{}))>; - // TODO: Should we have two policies? Interwave & Intrawave ?? 
static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -108,6 +98,25 @@ struct BlockUniversalGemmAsBsCr static constexpr auto Scheduler = Traits::Scheduler; + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = @@ -116,18 +125,65 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); + constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return 
a_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); + constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + private: template - CK_TILE_DEVICE static void load_interleaved_pk_type(const WarpWindow& warp_window, - WarpTile& warp_tile) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) { constexpr index_t UnaryOpSize = 8; const element_wise::PassThroughPack8 elementwise_op{}; - constexpr index_t thread_buffer_size = - Traits::AWarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); - static_assert(Traits::AWarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { @@ -144,6 +200,17 @@ struct BlockUniversalGemmAsBsCr template struct BlockGemmImpl { + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + 
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, @@ -158,114 +225,39 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. 
- auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - // TODO: I don't have to move 0,0 window! 
- move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - using CWarpDstr = typename WarpGemm::CWarpDstr; - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - AWarpTensor a_warp_tile; - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), a_warp_tile); - } - else - { - a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); - } + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - BWarpTensor b_warp_tile; - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), b_warp_tile); - } - else - { - b_warp_tile = 
load_tile(b_warp_windows(nIter)(kIter)); - } + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -275,7 +267,7 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -291,149 +283,68 @@ struct BlockUniversalGemmAsBsCr template struct BlockGemmImpl { - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_tiles_; + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_warp_tiles_; + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - 
"traits should be the same as correspoinding block window data type!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. - auto a_warp_window_tmp = make_tile_window( - a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( - b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); - - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - // TODO: I don't have to move 0,0 
window! - move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - // read A warp tensor from A block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), - a_warp_tiles_(mIter)(kIter)); - } - else - { - a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); - } - }); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), - b_warp_tiles_(nIter)(kIter)); - } - else - { - b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - } - }); - }); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_block_window); + } + else + { + load_tile(b_warp_tile_, b_block_window); + } } // C += A * B template CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - [[maybe_unused]] const ASmemBlockWindow& a_block_window, - [[maybe_unused]] const BSmemBlockWindow& b_block_window) + [[maybe_unused]] ASmemBlockWindow& a_block_window, + [[maybe_unused]] BSmemBlockWindow& b_block_window) { static_assert(std::is_same_v, "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); - using CWarpDstr = typename WarpGemm::CWarpDstr; - using CWarpTensor = typename 
WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor- + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( @@ -441,9 +352,7 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kIter], - b_warp_tiles_[nIter][kIter]); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -468,126 +377,53 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack; - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_tiles_; + static constexpr auto ALdsTileDistr = + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - statically_indexed_array< - 
statically_indexed_array, - NIterPerWarp> - b_warp_tiles_; + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], - "MPerBlock, NPerBlock, KPerBlock defined in " - " BlockGemmShape are different from A/B block smem windows apropriate dims!"); + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); - static_assert(std::is_same_v && - std::is_same_v, - "The ADataType and BDataType as defined in " - "traits should be the same as correspoinding block window data type!"); - - const index_t iMWarp = get_warp_id() / NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); - - // TODO: refactor warp_window tile type to class member as it should be - // compile-time known information. 
- auto a_warp_window_tmp = make_tile_window( + auto a_lds_gemm_window = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + - multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); - - using AWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::AWarpTile::get_num_of_dimension() == - AWarpWindow::get_num_of_dimension(), - "AWarpWindow number of dimensions must be equal to " - "AWarpTile number of dimensions!"); - static_assert(GemmTraits::AWarpTile::get_lengths() == - AWarpWindow{}.get_window_lengths(), - "AWarpWindow lengths must be equal to AWarpTile lengths!"); - - statically_indexed_array, - MIterPerWarp> - a_warp_windows; - - // construct B-warp-window - auto b_warp_window_tmp = make_tile_window( + make_tuple(number{}, number{}), + {0, KIdx * KPerInnerLoop}, + a_lds_load_tile_distr); + auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + - multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + {0, KIdx * KPerInnerLoop}, + b_lds_load_tile_distr); - using BWarpWindow = remove_cvref_t; - - static_assert(GemmTraits::BWarpTile::get_num_of_dimension() == - BWarpWindow::get_num_of_dimension(), - "BWarpWindow number of dimensions must be equal to " - "BWarpTile number of dimensions!"); - static_assert(GemmTraits::BWarpTile::get_lengths() == - BWarpWindow{}.get_window_lengths(), - "BWarpWindow lengths must be equal to BWarpTile lengths!"); - - statically_indexed_array, - NIterPerWarp> - b_warp_windows; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - a_warp_windows(mIter)(kIter) = a_warp_window_tmp; - - 
move_tile_window(a_warp_windows(mIter)(kIter), - {mIter * GemmTraits::MPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - b_warp_windows(nIter)(kIter) = b_warp_window_tmp; - - move_tile_window(b_warp_windows(nIter)(kIter), - {nIter * GemmTraits::NPerBlockPerIter, - kIter * GemmTraits::KPerBlockPerIter}); - }); - }); - - // TODO check if a_warp_tiles has same desc as a_warp_window - static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(a_warp_windows(mIter)(kIter), - a_warp_tiles_(mIter)(kIter)); - } - else - { - a_warp_tiles_(mIter)(kIter) = load_tile(a_warp_windows(mIter)(kIter)); - } - }); - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read B warp tensor from B Block window - if constexpr(std::is_same_v) - { - load_interleaved_pk_type(b_warp_windows(nIter)(kIter), - b_warp_tiles_(nIter)(kIter)); - } - else - { - b_warp_tiles_(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter)); - } - }); - }); + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(a_warp_tile_, a_lds_gemm_window); + } + else + { + load_tile(a_warp_tile_, a_lds_gemm_window); + } + if constexpr(std::is_same_v) + { + load_interleaved_pk_type(b_warp_tile_, b_lds_gemm_window); + } + else + { + load_tile(b_warp_tile_, b_lds_gemm_window); + } } // C += A * B @@ -600,13 +436,6 @@ struct BlockUniversalGemmAsBsCr "The CDataType as defined in traits should be the same as correspoinding " "C block tensor data type!"); - using CWarpDstr = typename WarpGemm::CWarpDstr; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // hot loop: static_for<0, KRepeat, 1>{}([&](auto kIter) { 
LocalPrefetch(a_block_window, b_block_window); @@ -626,7 +455,21 @@ struct BlockUniversalGemmAsBsCr static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -651,9 +494,7 @@ struct BlockUniversalGemmAsBsCr __builtin_amdgcn_sched_barrier(0); } // warp GEMM - WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kInnerIter], - b_warp_tiles_[nIter][kInnerIter]); + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 741a6b9fc3..f2aa3af196 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -129,34 +129,34 @@ struct GemmKernel const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = kargs.k_batch * K1; - const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); if constexpr(std::is_same_v) { - a_k_split_offset = k_id * KRead; + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id 
* KRead); } else if constexpr(std::is_same_v) { - a_k_split_offset = k_id * KRead * kargs.stride_A; + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); } if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead * kargs.stride_B; + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = k_id * KRead; + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); } if(k_id < static_cast(kargs.k_batch - 1)) { - splitted_k = KRead; + splitted_k = __builtin_amdgcn_readfirstlane(KRead); } else { - splitted_k = kargs.K - KRead * (kargs.k_batch - 1); + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); } } @@ -523,7 +523,8 @@ struct GemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -574,7 +575,8 @@ struct GemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); @@ -593,7 +595,8 @@ struct GemmKernel CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); @@ -607,12 +610,12 @@ struct GemmKernel // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - __shared__ char smem_ptr_1[GetSmemSize()]; if(kargs.k_batch == 1) { if constexpr(GemmPipeline::DoubleSmemBuffer == true) { + __shared__ char smem_ptr_1[GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, c_ptr, @@ -637,6 +640,7 @@ struct GemmKernel { if constexpr(GemmPipeline::DoubleSmemBuffer == true) { + __shared__ char smem_ptr_1[GetSmemSize()]; RunGemm2LDS(a_ptr, b_ptr, c_ptr, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4855df0e0e..24bd66a59e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -68,9 +68,10 @@ struct GemmPipelineAgBgCrImplBase return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); } - template - CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const ALdsTensorView& a_lds_block_view) const + template + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr&) const { constexpr bool is_col_major = std::is_same_v; @@ -88,17 +89,21 @@ struct GemmPipelineAgBgCrImplBase auto a_copy_lds_window = make_tile_window( a_lds_block_view, 
make_tuple(number{}, number{}), {0, 0}); - auto a_lds_gemm_window = make_tile_window( - a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_gemm_window = + make_tile_window(a_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + ALdsLoadTileDistr{}); return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); } - template - CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BLdsTensorView& b_lds_block_view) const + template + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr&) const { constexpr bool is_row_major = std::is_same_v; @@ -117,8 +122,11 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window( b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); - auto b_lds_gemm_window = make_tile_window( - b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_gemm_window = + make_tile_window(b_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + BLdsLoadTileDistr{}); return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 73d5ce8f81..b6e165e6da 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -346,17 +346,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // A/B tiles in LDS auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + 
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = - Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = - Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); // Block GEMM auto block_gemm = BlockGemm(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index b8b2d5b1c9..8a73b4b5a1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -215,10 +215,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM - auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); auto& a_copy_dram_window = a_windows.at(I0{}); auto& a_copy_lds_window = a_windows.at(I1{}); auto& a_lds_gemm_window = a_windows.at(I2{}); @@ -226,7 +233,8 @@ struct GemmPipelineAgBgCrMem : public 
BaseGemmPipelineAgBgCrMem // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM - auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); auto& b_copy_dram_window = b_windows.at(I0{}); auto& b_copy_lds_window = b_windows.at(I1{}); auto& b_lds_gemm_window = b_windows.at(I2{}); @@ -493,10 +501,17 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto& a_lds_block = ab_lds_blocks.at(I0{}); auto& b_lds_block = ab_lds_blocks.at(I1{}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto b_lds_load_tile_distr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM - auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); auto& a_copy_dram_window = a_windows.at(I0{}); auto& a_copy_lds_window = a_windows.at(I1{}); auto& a_lds_gemm_window = a_windows.at(I2{}); @@ -504,7 +519,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM - auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); auto& b_copy_dram_window = b_windows.at(I0{}); auto& b_copy_lds_window = b_windows.at(I1{}); auto& b_lds_gemm_window = b_windows.at(I2{}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 33945651ae..76bece9398 
100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -125,13 +125,25 @@ struct GemmPipelineAGmemBGmemCRegV1 auto b_copy_lds_window = make_tile_window( b_lds_block, make_tuple(number{}, number{}), {0, 0}); + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_lds_gemm_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_lds_load_tile_distr); // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto b_lds_gemm_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_lds_load_tile_distr); // Block GEMM auto block_gemm = BlockGemm(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index fe706113ae..2f658582c9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -122,17 +122,29 @@ struct GemmPipelineAGmemBGmemCRegV2 {0, 0}, b_copy_dram_window.get_tile_distribution()); - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); - - // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); - // Block GEMM constexpr auto block_gemm = Policy::template GetBlockGemm(); + // Tile distribution for load from 
lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_lds_load_tile_distr); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_lds_load_tile_distr); + // Acc register tile auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};