diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 8ae46cadc6..d166eed458 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp) +add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp similarity index 89% rename from example/ck_tile/03_gemm/gemm_mem_pipeline.cpp rename to example/ck_tile/03_gemm/universal_gemm.cpp index cd9d9d96b6..eaafc13b98 100644 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -14,10 +14,17 @@ #include "ck_tile/host.hpp" #include "gemm_basic.hpp" +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { -#if 1 +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 32; @@ -30,7 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; -#else + +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; @@ -63,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ck_tile::Default2DEpilogueProblem>; using Traits = ck_tile::TileGemmTraits; - +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< +#endif ck_tile::GemmPipelineProblem>; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); @@ -77,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< +#endif ck_tile::UniversalGemmPipelineProblem>; using Kernel = ck_tile::GemmKernel; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 9a033ee2de..1340fb2048 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -25,6 +25,8 @@ #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 5f98a7a0ba..c9e648f437 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -41,13 +41,16 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t NWarp = config.template at<2>(); - static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}), + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!"); - static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}), + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!"); - static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}), + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), "Error! WarpGemm's M is not consisten with BlockGemmShape!"); - static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}), + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), "Error! WarpGemm's N is not consisten with BlockGemmShape!"); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); @@ -99,6 +102,9 @@ struct BlockUniversalGemmAsBsCr static constexpr auto Scheduler = Traits::Scheduler; + using I0 = number<0>; + using I1 = number<1>; + private: template struct BlockGemmImpl @@ -114,35 +120,31 @@ struct BlockUniversalGemmAsBsCr const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); - static_assert(std::is_same_v && - std::is_same_v, + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); + static_assert(std::is_same_v && + std::is_same_v, "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], "MPerBlock, NPerBlock, KPerBlock defined in " " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); // TODO: refactor warp_window tile type to class member as it should be // compile-time known information. auto a_warp_window_tmp = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + - multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); using AWarpWindow = remove_cvref_t; @@ -156,16 +158,15 @@ struct BlockUniversalGemmAsBsCr statically_indexed_array< statically_indexed_array, - GemmTraits::MIterPerWarp> + MIterPerWarp> a_warp_windows; // construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + - multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); using BWarpWindow = remove_cvref_t; @@ -179,10 +180,10 @@ struct BlockUniversalGemmAsBsCr statically_indexed_array< statically_indexed_array, - GemmTraits::NIterPerWarp> + NIterPerWarp> b_warp_windows; - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; @@ -193,7 +194,7 @@ struct BlockUniversalGemmAsBsCr }); }); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; @@ -203,8 +204,8 @@ struct BlockUniversalGemmAsBsCr }); }); - using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; - using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -212,10 +213,10 @@ struct BlockUniversalGemmAsBsCr // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter)); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter)); // read C warp tensor from C block tensor- @@ -226,7 +227,7 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); + WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -243,13 +244,13 @@ struct BlockUniversalGemmAsBsCr struct BlockGemmImpl { statically_indexed_array< - statically_indexed_array, - GemmTraits::MIterPerWarp> + statically_indexed_array, + MIterPerWarp> a_warp_tiles_; statically_indexed_array< - statically_indexed_array, - GemmTraits::NIterPerWarp> + statically_indexed_array, + NIterPerWarp> b_warp_tiles_; template @@ -257,30 +258,27 @@ struct BlockUniversalGemmAsBsCr const BSmemBlockWindow& b_block_window) { static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], "MPerBlock, NPerBlock, KPerBlock defined in " " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - static_assert(std::is_same_v && - std::is_same_v, + static_assert(std::is_same_v && + std::is_same_v, "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); // TODO: refactor warp_window tile type to class member as it should be // compile-time known information. auto a_warp_window_tmp = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_block_window.get_window_origin() + - multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); using AWarpWindow = remove_cvref_t; @@ -292,18 +290,16 @@ struct BlockUniversalGemmAsBsCr AWarpWindow{}.get_window_lengths(), "AWarpWindow lengths must be equal to AWarpTile lengths!"); - statically_indexed_array< - statically_indexed_array, - GemmTraits::MIterPerWarp> + statically_indexed_array, + MIterPerWarp> a_warp_windows; // construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_block_window.get_window_origin() + - multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); using BWarpWindow = remove_cvref_t; @@ -315,13 +311,12 @@ struct BlockUniversalGemmAsBsCr BWarpWindow{}.get_window_lengths(), "BWarpWindow lengths must be equal to BWarpTile lengths!"); - statically_indexed_array< - statically_indexed_array, - GemmTraits::NIterPerWarp> + statically_indexed_array, + NIterPerWarp> b_warp_windows; - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; // TODO: I don't have to move 0,0 window! @@ -331,8 +326,8 @@ struct BlockUniversalGemmAsBsCr }); }); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; move_tile_window(b_warp_windows(nIter)(kIter), @@ -341,12 +336,12 @@ struct BlockUniversalGemmAsBsCr }); }); - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block window load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); }); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); }); @@ -359,22 +354,21 @@ struct BlockUniversalGemmAsBsCr [[maybe_unused]] const ASmemBlockWindow& a_block_window, [[maybe_unused]] const BSmemBlockWindow& b_block_window) { - static_assert( - std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); - using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; - using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; // hot loop: - static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -383,9 +377,9 @@ struct BlockUniversalGemmAsBsCr merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); // warp GEMM - typename GemmTraits::WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kIter], - b_warp_tiles_[nIter][kIter]); + WarpGemm{}(c_warp_tensor, + a_warp_tiles_[mIter][kIter], + b_warp_tiles_[nIter][kIter]); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( @@ -412,12 +406,12 @@ struct BlockUniversalGemmAsBsCr statically_indexed_array< statically_indexed_array, - GemmTraits::MIterPerWarp> + MIterPerWarp> a_warp_tiles_; statically_indexed_array< statically_indexed_array, - GemmTraits::NIterPerWarp> + NIterPerWarp> b_warp_tiles_; template @@ -425,30 +419,28 @@ struct BlockUniversalGemmAsBsCr const BSmemBlockWindow& b_block_window) { static_assert( - GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] && - GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}], + GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] && + GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}], "MPerBlock, NPerBlock, KPerBlock defined in " " BlockGemmShape are different from A/B block smem windows apropriate dims!"); - static_assert(std::is_same_v && - std::is_same_v, + static_assert(std::is_same_v && + std::is_same_v, "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - const index_t iMWarp = get_warp_id() / GemmTraits::NWarp; - const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp); + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() - (iMWarp * NWarp); // TODO: refactor warp_window tile type to class member as it should be // compile-time known information. auto a_warp_window_tmp = make_tile_window( a_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), a_block_window.get_window_origin() + - multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{})); + multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); using AWarpWindow = remove_cvref_t; @@ -461,16 +453,16 @@ struct BlockUniversalGemmAsBsCr "AWarpWindow lengths must be equal to AWarpTile lengths!"); statically_indexed_array, - GemmTraits::MIterPerWarp> + MIterPerWarp> a_warp_windows; // construct B-warp-window auto b_warp_window_tmp = make_tile_window( b_block_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), b_block_window.get_window_origin() + - multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop}, - make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{})); + multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); using BWarpWindow = remove_cvref_t; @@ -483,10 +475,10 @@ struct BlockUniversalGemmAsBsCr "BWarpWindow lengths must be equal to BWarpTile lengths!"); statically_indexed_array, - GemmTraits::NIterPerWarp> + NIterPerWarp> b_warp_windows; - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { a_warp_windows(mIter)(kIter) = a_warp_window_tmp; @@ -496,7 +488,7 @@ struct BlockUniversalGemmAsBsCr }); }); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { b_warp_windows(nIter)(kIter) = b_warp_window_tmp; @@ -508,11 +500,11 @@ struct BlockUniversalGemmAsBsCr // TODO check if a_warp_tiles has same desc as a_warp_window static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) { - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block window load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter)); }); - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read B warp tensor from B Block window load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter)); }); @@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window) { - static_assert( - std::is_same_v, - "The CDataType as defined in traits should be the same as correspoinding " - "C block tensor data type!"); + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as correspoinding " + "C block tensor data type!"); - using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; - using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); @@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr } static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { - static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor- CWarpTensor c_warp_tensor; @@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr // penalty if constexpr(kIter.value == KRepeat - 1 && kInnerIter.value == KInnerLoopIter - 1 && - mIter.value == GemmTraits::MIterPerWarp - 1 && - nIter.value == GemmTraits::NIterPerWarp - 1) + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) { __builtin_amdgcn_sched_barrier(0); block_sync_lds(); __builtin_amdgcn_sched_barrier(0); } // warp GEMM - typename GemmTraits::WarpGemm{}(c_warp_tensor, - a_warp_tiles_[mIter][kInnerIter], - b_warp_tiles_[nIter][kInnerIter]); + WarpGemm{}(c_warp_tensor, + a_warp_tiles_[mIter][kInnerIter], + b_warp_tiles_[nIter][kInnerIter]); // write C warp tensor into C block tensor c_block_tensor.set_y_sliced_thread_data( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp new file mode 100644 index 0000000000..431534af15 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct GemmPipelineAgBgCrImplBase +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, + SrcTileWindow& dram_tile_window) const + { + load_tile(dst_block_tile, dram_tile_window); + move_tile_window(dram_tile_window, {0, KPerBlock}); + } + + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile, + const ElementFunction& element_func) const + { + const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); + store_tile(lds_tile_window, block_tile_tmp); + } + + CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const + { + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // TODO: LDS alignment should come from Policy! + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); + } + + template + CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view) const + { + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + auto a_lds_gemm_window = make_tile_window( + a_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + + return make_tuple(std::move(a_copy_dram_window), + std::move(a_copy_lds_window), + std::move(a_lds_gemm_window)); + } + + template + CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view) const + { + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block_view, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + auto b_lds_gemm_window = make_tile_window( + b_lds_block_view, make_tuple(number{}, number{}), {0, 0}); + + return make_tuple(std::move(b_copy_dram_window), + std::move(b_copy_lds_window), + std::move(b_lds_gemm_window)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp new file mode 100644 index 0000000000..a72728b4a0 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV3 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 +template +struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t VectorSizeA = Problem::VectorSizeA; + static constexpr index_t VectorSizeB = Problem::VectorSizeB; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + // Where is the right place for HasHotLoop and TailNum ??? + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + using Base::PrefetchStages; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + constexpr index_t A_LDS_Read_Width = KPerXDL; + constexpr index_t B_LDS_Read_Width = KPerXDL; + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * VectorSizeA); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * VectorSizeB); + + constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16 + ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16 + ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], + "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" + " or KPerBlock!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + return c_block_tile; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 847c5b187d..e2e94cf92b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" namespace ck_tile { @@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem template struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { - using Base = BaseGemmPipelineAgBgCrMem; + using Base = BaseGemmPipelineAgBgCrMem; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using BlockGemm = remove_cvref_t())>; using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; - static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using Base::PrefetchStages; - CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize() - { - return integer_divide_ceil( - sizeof(ADataType) * - Policy::template MakeALdsBlockDescriptor().get_element_space_size(), - 16) * - 16 + - sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); - } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); } template - struct PipelineImpl + struct PipelineImpl : public PipelineImplBase { }; template <> - struct PipelineImpl + struct PipelineImpl : public PipelineImplBase { - template - CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, - SrcTileWindow& dram_tile_window) const - { - load_tile(dst_block_tile, dram_tile_window); - move_tile_window(dram_tile_window, {0, KPerBlock}); - } - - template - CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, - const SrcBlockTile& src_block_tile, - const ElementFunction& element_func) const - { - const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); - store_tile(lds_tile_window, block_tile_tmp); - } + using Base = PipelineImplBase; template "A/B Dram block window should have the same data type as appropriate " "([A|B]DataType) defined in Problem definition!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - NPerBlock == - BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" " or KPerBlock!"); // ------------------------------------------------------------------------------------ // Definitions of all needed tiles - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - // TODO: LDS alignment should come from Policy! - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - 16) * - 16; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + // A/B tiles in LDS + // With c++20 could simplify to below line. + // Currently get error: captured structured bindings are a C++20 extension + // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto& a_copy_dram_window = a_windows.at(I0{}); + auto& a_copy_lds_window = a_windows.at(I1{}); + auto& a_lds_gemm_window = a_windows.at(I2{}); + + // B DRAM tile window for load + // B LDS tile window for store // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto& b_copy_dram_window = b_windows.at(I0{}); + auto& b_copy_lds_window = b_windows.at(I1{}); + auto& b_lds_gemm_window = b_windows.at(I2{}); // Block GEMM auto block_gemm = BlockGemm(); @@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); - LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); }); // main body @@ -295,19 +244,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_sync_lds(); - LocalPrefill( + Base::LocalPrefill( a_copy_lds_window, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); - LocalPrefill( + Base::LocalPrefill( b_copy_lds_window, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_element_func); - GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window); }); i += PrefetchStages; @@ -323,12 +272,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_sync_lds(); - LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); - LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); }); block_sync_lds(); @@ -376,24 +325,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem }; template <> - struct PipelineImpl + struct PipelineImpl : public PipelineImplBase { - template - CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, - SrcTileWindow& dram_tile_window) const - { - load_tile(dst_block_tile, dram_tile_window); - move_tile_window(dram_tile_window, {0, KPerBlock}); - } - - template - CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, - const SrcBlockTile& src_block_tile, - const ElementFunction& element_func) const - { - const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); - store_tile(lds_tile_window, block_tile_tmp); - } + using Base = PipelineImplBase; template "A/B Dram block window should have the same data type as appropriate " "([A|B]DataType) defined in Problem definition!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - NPerBlock == - BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}], "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" " or KPerBlock!"); // ------------------------------------------------------------------------------------ // Definitions of all needed tiles - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - // TODO: LDS alignment should come from Policy! - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), - 16) * - 16; - - // B tile in LDS - BDataType* p_b_lds = static_cast( - static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); - auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + // A/B tiles in LDS + // With c++20 could simplify to below line. + // Currently get error: captured structured bindings are a C++20 extension + // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem); + auto& a_lds_block = ab_lds_blocks.at(I0{}); + auto& b_lds_block = ab_lds_blocks.at(I1{}); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block); + auto& a_copy_dram_window = a_windows.at(I0{}); + auto& a_copy_lds_window = a_windows.at(I1{}); + auto& a_lds_gemm_window = a_windows.at(I2{}); + + // B DRAM tile window for load + // B LDS tile window for store // B LDS tile for block GEMM - auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block); + auto& b_copy_dram_window = b_windows.at(I0{}); + auto& b_copy_lds_window = b_windows.at(I1{}); + auto& b_lds_gemm_window = b_windows.at(I2{}); // Block GEMM auto block_gemm = BlockGemm(); @@ -496,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); - LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); }); // main body @@ -523,19 +429,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // no second block_sync_lds because it's interwave - LocalPrefill( + Base::LocalPrefill( a_copy_lds_window, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_element_func); - LocalPrefill( + Base::LocalPrefill( b_copy_lds_window, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_element_func); - GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window); - GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window); + Base::GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window); + Base::GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window); }); i += PrefetchStages; @@ -548,12 +454,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // no second block_sync_lds because it's interwave - LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); - LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + Base::LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + Base::LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); }); block_sync_lds();