diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index fc67e3eaa2..418fab6250 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -263,9 +263,10 @@ int run_gemm_example_with_layouts(int argc, // set 1 column in A and 1 Row in B to perform outer product. // and test the results. //const ck_tile::index_t K_len = a_m_k.get_length(1); - const ck_tile::index_t M_len = a_m_k.get_length(0); - const ck_tile::index_t N_len = b_k_n.get_length(1); + //const ck_tile::index_t M_len = a_m_k.get_length(0); + //const ck_tile::index_t N_len = b_k_n.get_length(1); + /* // Fill 0th column in A ck_tile::half_t dd = 1; for(int i = 0; i < M_len; i++) @@ -273,11 +274,7 @@ int run_gemm_example_with_layouts(int argc, int j = 0; { a_m_k(i, j) = dd; - } - int k = 8; - { - a_m_k(i, k) = dd++; - } + } } // Fill 0th row in B @@ -289,15 +286,7 @@ int run_gemm_example_with_layouts(int argc, b_k_n(i, j) = dd; } } - - i = 8; - { - for(int j=0; j < N_len; j++) - { - b_k_n(i, j) = dd; - } - } - + */ ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); @@ -441,5 +430,17 @@ int run_gemm_example_with_layouts(int argc, std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } + const ck_tile::index_t M_len = c_m_n_dev_result.get_length(0); + const ck_tile::index_t N_len = c_m_n_dev_result.get_length(1); + + for (int i = 0; i < M_len; i++) + { + for (int j = 0; j < N_len; j++) + { + std::cout << std::setw(6) << ck_tile::type_convert(c_m_n_dev_result(i, j)); + } + std::cout<; using GemmPipeline = GEMM_PIPELINE; - + /* using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< ck_tile::DefaultGemm2DEpilogueProblem>; + */ - /* using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem>; - */ + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 78884f3f9f..e6deb177e8 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -56,17 +56,19 @@ template struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { }; // Thread raked -template +template struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { @@ -90,6 +92,14 @@ struct TileDistributionEncodingPattern2D, sequence<1, 2>>{}); } + + CK_TILE_HOST_DEVICE static constexpr auto MakePingPong2DStaticTileDistribution() + { + // X0, X1 + // X0 - One Row/Column needs X1 no. of instructions to read/write. + // X1 - VecSize - The read instruction size. + // X is always the fastest changing dimension of the input matrix. + + // Y0, Y1, Y2 + // Y0 - Total number of warps in a thread group. + // Y1 - WarpSize / no-of-threads-in-N-dimension. + // - No. of threads needed in the M dimension + // Y2 - YPerTile / (Y1 * Y0) + // - Y size / (no. of threads on Y dimensions * no. of warps) + // - Total no. of iterations needed by all the warps in the thread group to cover the + // - entire tile window. + + // (2, 0) = PY0 -- Number of warps in the threadblock + // (2, 1) * (1, 0) = PY1 * PX0, (M threads) * (N Threads) + + static_assert(NumWarpGroups == 2); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{} + ); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffled2DStaticTileDistribution() + { + static_assert(NumWarpGroups == 2); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{} + ); + } }; // Warp raked @@ -119,6 +173,7 @@ struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { @@ -167,6 +222,7 @@ struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 0081edcb2e..f61016b80b 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -151,6 +151,7 @@ struct CShuffleEpilogue kMPerIteration, kNPerIteration, GetVectorSizeC(), + 1, tile_distribution_pattern::thread_raked>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ec21c73abb..9e73820f08 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,8 +30,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index b4362d9069..c84c0e7e4b 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -11,11 +11,11 @@ namespace ck_tile { // A is block distributed tensor // B is block distributed tensor // C is block distributed tensor -template +template struct BlockGemmARegBRegCRegV1 { private: - template + template struct GemmTraits_ { using Problem = remove_cvref_t; @@ -34,7 +34,7 @@ struct BlockGemmARegBRegCRegV1 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WarpGemm = remove_cvref_t())>; - static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t MWarp = (config.template at<1>()) / NumWarpGroups; static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); @@ -47,7 +47,7 @@ struct BlockGemmARegBRegCRegV1 using Problem = remove_cvref_t; using Policy = remove_cvref_t; - using Traits = GemmTraits_; + using Traits = GemmTraits_; using WarpGemm = typename Traits::WarpGemm; using BlockGemmShape = typename Traits::BlockGemmShape; @@ -65,6 +65,8 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { + static_assert( MWarp == 1); + static_assert(MIterPerWarp == 2); constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding, tuple, sequence>, @@ -95,6 +97,7 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { + static_assert(MWarp * NWarp == 1); constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, tuple, sequence>, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index e528847438..a6ae423b05 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -67,7 +67,7 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy return b_lds_block_desc; } - template + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { using AccDataType = float; @@ -86,7 +86,9 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + static_assert(NumWarpGroups == 2); + + return BlockGemmARegBRegCRegV1{}; } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 5d2203aca7..4ecd41a846 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -14,7 +14,7 @@ struct BaseGemmPipelineAgBgCrCompV5 { static constexpr index_t PrefetchStages = 1; static constexpr index_t PrefillStages = 1; - static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t GlobalBufferNum = 1; CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -25,14 +25,11 @@ struct BaseGemmPipelineAgBgCrCompV5 CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { - if(num_loop > PrefetchStages) + if(num_loop > 0) { return TailNumber::One; } - else - { - return TailNumber::Two; - } + return TailNumber::One; } }; @@ -51,7 +48,9 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; + static constexpr index_t NumWarpGroups = 2; + + using BlockGemm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; using I2 = number<2>; @@ -102,6 +101,255 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { using Base = PipelineImplBase; + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + static_assert( + KPerBlock % ((NumWarps / 2) * KTileSize) == 0, + "Ping Pong Warps, TileSize and Block Size for K dimensions does not match."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + constexpr index_t num_stages_ = 2; // mem-read and GEMM + index_t group_id = __builtin_amdgcn_readfirstlane(get_warp_id() % NumWarpGroups); // warp-id (0, 1) for warp specific data in this pipeline + index_t op_id = __builtin_amdgcn_readfirstlane(get_warp_id() % NumWarpGroups); // operation to perform (mem-read (0) or GEMM(1)) + + // global memory structures here. + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakePingPongADramTileDistribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakePingPongBDramTileDistribution()); + + // DRAM window steps. + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using AGemmTile = decltype(make_static_distributed_tensor(AGemmTileDistr)); + using BGemmTile = decltype(make_static_distributed_tensor(BGemmTileDistr)); + AGemmTile a_tile_0, a_tile_1; // Gemm Tiles in registers. + BGemmTile b_tile_0, b_tile_1; + + // Register tile for A and B. + constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); + constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr)); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr)); + ABlockTile a_global_load_tile; + BBlockTile b_global_load_tile; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); // Gemm distribution. + auto c_block_tile_1 = block_gemm.MakeCBlockTile(); + + + // Not needed + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem_0); + + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + auto a_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + AGemmTileDistr); + auto b_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + BGemmTileDistr); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1); + + // define ping, pong steps here as lambda functions. + auto MemoryOpsStep = [&](auto idx) { + + // Memory read half here. + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakePingPongShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakePingPongShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + } + + if (idx == 0) + { + Base::LocalPrefetch(a_tile_0, a_lds_window); + Base::LocalPrefetch(b_tile_0, b_lds_window); + } + else + { + Base::LocalPrefetch(a_tile_1, a_lds_window); + Base::LocalPrefetch(b_tile_1, b_lds_window); + } + + }; + + auto ComputeStep = [&](auto idx) { + if (idx == 0) + { + block_gemm(c_block_tile_0, a_tile_0, b_tile_0); + } + else + { + block_gemm(c_block_tile_1, a_tile_1, b_tile_1); + } + }; + + if (op_id == 0) + { + MemoryOpsStep(group_id); + ComputeStep(group_id); + } + + // start the main loop. + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_compute_steps > 10) + { + block_sync_lds(); + op_id = (op_id + 1) % num_stages_; + + if(op_id == 0) + { + MemoryOpsStep(group_id); + } + else + { + ComputeStep(group_id); + } + num_compute_steps -= 1; + } + block_sync_lds(); + + return c_block_tile_0; + + /* + if(op_id == 0) + { + MemoryOpsStep(group_id); + } + + // start the main loop. + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_compute_steps > 1) + { + block_sync_lds(); + op_id = (op_id + 1) % num_stages_; + + if(op_id == 0) + { + MemoryOpsStep(group_id); + } + else + { + ComputeStep(group_id); + } + num_compute_steps -= 1; + } + block_sync_lds(); + + // Handle Tail Number here. + if(op_id == 0) + { + ComputeStep(group_id); + } + block_sync_lds(); + + // Add both the tiles and return the result. + + + if (group_id == 0) + { + constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + auto idx2 = make_tuple(idx0, idx1); + c_block_tile_0(idx2) += c_block_tile_1(idx2); + }); + }); + } + + return c_block_tile_0; + */ + } + + /* template "B block window has incorrect lengths for defined BLayout!"); static constexpr index_t num_stages_ = 2; - // This is used to identify the register tile on which a warp always operates on. - // For instance, warp-0 always uses a_block_tile_0 for reading in one cycle - // and execution in the next cycle. - index_t group_id = get_warp_id() % num_stages_; - - // op_id indicated one of the steps (0 - Read, 1 - Gemm Execution) - // Each warp performs read in one cycle and in the next cycles performs GEMM operation - // on the same block_tile that it has read in the previous cycle. - index_t op_id = get_warp_id() % num_stages_; + index_t group_id = __builtin_amdgcn_readfirstlane(get_warp_id() % num_stages_); + index_t op_id = __builtin_amdgcn_readfirstlane(get_warp_id() % num_stages_); // global memory structures here. auto a_copy_dram_window = - make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), Policy::template MakeADramTileDistribution()); // B DRAM tile window for load auto b_copy_dram_window = - make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), b_dram_block_window_tmp.get_window_origin(), Policy::template MakeBDramTileDistribution()); @@ -184,8 +425,6 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - // array a_tiles; - // array b_tiles; ALdsTile a_tile_0, a_tile_1; BLdsTile b_tile_0, b_tile_1; @@ -199,16 +438,15 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 b_lds_block, make_tuple(number{}, number{}), {0, 0}); auto a_lds_window = - make_tile_window_linear(a_lds_block, + make_tile_window(a_lds_block, make_tuple(number{}, number{}), {0, 0}, ALdsTileDistr); auto b_lds_window = - make_tile_window_linear(b_lds_block, + make_tile_window(b_lds_block, make_tuple(number{}, number{}), {0, 0}, BLdsTileDistr); - // Register tile for A and B. constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); @@ -219,13 +457,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // Block GEMM auto block_gemm = BlockGemm(); - auto c_block_tile = block_gemm.MakeCBlockTile(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0); + // define ping, pong steps here as lambda functions. - auto MemoryOpsStep = [&]() { + auto MemoryOpsStep = [&](auto idx) { + // Memory read half here. Base::GlobalPrefetch( a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -237,49 +476,52 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } else { - Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); } + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, a_element_func); } else { - Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + Base::LocalPrefill(a_copy_lds_window, b_global_load_tile, a_element_func); } }; auto ComputeStep = [&](auto idx) { - if(idx == 0) + if (idx == 0) { + //tile_elementwise_inout([&step](auto& c) { c = step; }, c_block_tile_0); Base::LocalPrefetch(a_tile_0, a_lds_window); Base::LocalPrefetch(b_tile_0, b_lds_window); - block_gemm(c_block_tile, a_tile_0, b_tile_0); - // tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile); + block_gemm(c_block_tile_0, a_tile_0, b_tile_0); } else { + //tile_elementwise_inout([&step](auto& c) { c = step; }, c_block_tile_0); Base::LocalPrefetch(a_tile_1, a_lds_window); - Base::LocalPrefetch(b_tile_1, b_lds_window); - block_gemm(c_block_tile, a_tile_1, b_tile_1); - // tile_elementwise_inout([](auto& c) { c += 1; }, c_block_tile); + Base::LocalPrefetch(b_tile_1, b_lds_window); + block_gemm(c_block_tile_0, a_tile_1, b_tile_1); + } }; - + if(op_id == 0) { - MemoryOpsStep(); - } + MemoryOpsStep(group_id); + } + // start the main loop. - index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop) * 2 - 1; + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop)*2 - 1; while(num_compute_steps > 0) { @@ -287,29 +529,40 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 block_sync_lds(); op_id = (op_id + 1) % num_stages_; - if(op_id == 0) { - MemoryOpsStep(); + if(op_id == 0) + { + MemoryOpsStep(group_id); + } + else + { + ComputeStep(group_id); + } } - else - { - ComputeStep(group_id); - } - num_compute_steps -= 1; } - - - // Handle Tail Number here. block_sync_lds(); + + // Handle Tail Number here. if(op_id == 0) { ComputeStep(group_id); } - block_sync_lds(); - return c_block_tile; + block_sync_lds(); + + + //constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans(); + //sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + // auto idx2 = make_tuple(idx0, idx1); + // c_block_tile_2(idx2) = c_block_tile_0(idx2) + c_block_tile_1(idx2); + // }); + //}); + + return c_block_tile_0; } + */ }; template const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0) const + void* p_smem_0 + ) const { - return PipelineImpl{}.template operator()( + return PipelineImpl{}.template operator()( a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0); + p_smem_0 + ); } public: @@ -337,15 +592,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0) const + void* __restrict__ p_smem_0 + ) const { - return PipelineImpl{}.template operator()( + return PipelineImpl{}.template operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, num_loop, - p_smem_0); + p_smem_0 + ); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index f5b3523f60..2b51b4d118 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -123,7 +123,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; using WG = typename BlockGemm::WarpGemm; constexpr bool TransposeC = Problem::TransposeC; @@ -182,6 +182,75 @@ struct UniversalGemmBasePolicy return Problem::TransposeC; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakePingPongADramTileDistribution() + { + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingPong2DStaticTileDistribution(); + } + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingPong2DStaticTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePingPongBDramTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingpPong2DStaticTileDistribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingPong2DStaticTileDistribution(); + + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { @@ -199,6 +268,7 @@ struct UniversalGemmBasePolicy MPerBlock, KPerBlock, VecLoadSize, + 1, ATileAccessPattern>; return TileEncodingPattern::Make2DStaticTileDistribution(); } @@ -209,6 +279,7 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, + 1, ATileAccessPattern>; return TileEncodingPattern::Make2DStaticTileDistribution(); } @@ -231,6 +302,7 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, + 1, BTileAccessPattern>; return TileEncodingPattern::Make2DStaticTileDistribution(); } @@ -241,11 +313,50 @@ struct UniversalGemmBasePolicy NPerBlock, KPerBlock, VecLoadSize, + 1, BTileAccessPattern>; return TileEncodingPattern::Make2DStaticTileDistribution(); } } + template + CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffledARegTileDistribution() + { + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingPongShuffled2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakePingPongShuffledBRegTileDistribution() + { + using BLayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakePingPongShuffled2DStaticTileDistribution(); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { @@ -260,6 +371,7 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, + 1, ATileAccessPattern>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } @@ -278,6 +390,7 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, + 1, BTileAccessPattern>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } @@ -285,7 +398,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = BlockGemm::Traits::KPack; return KPack; } @@ -293,7 +406,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() { - using BlockGemm = remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; constexpr index_t KPack = BlockGemm::Traits::KPack; return KPack; } @@ -362,7 +475,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), + make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -374,7 +487,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(number{}, number{})), make_merge_transform_v3_division_mod( make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return a_lds_block_desc; @@ -421,7 +534,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0, number{})), + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), make_pass_through_transform(number{}), make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), @@ -432,7 +545,7 @@ struct UniversalGemmPipelineAgBgCrPolicy make_tuple(make_merge_transform_v3_division_mod( make_tuple(number{}, number{})), make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; }