From 9df3f6f8865b3cf6b968e500cccc9ee8a3e1ee18 Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Wed, 24 Sep 2025 20:59:10 +0000 Subject: [PATCH] N dimension parallelism code drop --- example/ck_tile/03_gemm/universal_gemm.cpp | 2 +- .../ops/epilogue/cshuffle_epilogue.hpp | 15 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 5 + .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 12 + .../ops/gemm/kernel/universal_gemm_kernel.hpp | 412 +++++++++++++----- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 299 +++++++++++++ 6 files changed, 623 insertions(+), 122 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 50d5930be7..50e539f820 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -180,7 +180,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) else { std::cout << "Ping pong....ON " << std::endl; - grids = Kernel::PingPongGridSize(args.M, args.N, args.K, args.k_batch); + grids = Kernel::PingPongGridSizeNParallel(args.M, args.N, args.K, args.k_batch); std::cout << "Arguments: { " << args.M << ", " << args.N << ", " << args.K << ", " << args.k_batch << " }" << std::endl; std::cout << "Grid size : {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << std::endl; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 3a4553ff08..4c852f9a7c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -361,14 +361,17 @@ struct CShuffleEpilogue buffer_store_fence(); if constexpr(iAccess != num_access - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + if (execute_epilogue) + { + constexpr auto step = SFC::get_forward_step(iAccess); - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); - }); + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); + } } }); } diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index be231bc05e..91e18de477 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -135,6 +135,11 @@ struct GemmKernel return dim3(TilePartitioner::PingPongGridSize(N, K), 1, KBatch); } + CK_TILE_HOST static auto PingPongGridSizeNParallel(index_t M, index_t, index_t K, index_t KBatch) -> dim3 + { + return dim3(TilePartitioner::PingPongGridSizeNParallel(M, K), 1, KBatch); + } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return UniversalGemmKernel::BlockSize(); diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 629b64053d..369206629a 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -265,6 +265,18 @@ struct GemmSpatiallyLocalTilePartitioner return GridDimX * GridDimY; } + CK_TILE_HOST_DEVICE static auto + PingPongGridSizeNParallel(index_t M, index_t K) noexcept(noexcept(MPerBlock != 0 && KPerBlock != 0)) -> index_t + { + const index_t GridDimX = integer_divide_ceil(M, MPerBlock); + const index_t GridDimY = integer_divide_ceil(K, KPerBlock); + + std::cout << "PingPong Grid size, N_DIM_PARALLELISM M GRID SIZE : {" << GridDimX << ", " << GridDimY << "}" << std::endl; + std::cout << "Arguments: { " << M << ", " << K << " }" << std::endl; + std::cout << "Block size : {" << MPerBlock << ", " << KPerBlock << "}" << std::endl; + return GridDimX * GridDimY; + } + /** * @brief Calculate number of loop iterations over GEMM's K dimension. * diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 272bf24996..a4cbd6102c 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -287,7 +287,13 @@ struct UniversalGemmKernel CK_TILE_HOST static auto PingPongGridSize(index_t, index_t N, index_t K, index_t KBatch) -> dim3 { return dim3(TilePartitioner::PingPongGridSize(N, K), 1, KBatch); - } + } + + CK_TILE_HOST static auto + PingPongGridSizeNParallel(index_t M, index_t, index_t K, index_t KBatch) -> dim3 + { + return dim3(TilePartitioner::PingPongGridSizeNParallel(M, K), 1, KBatch); + } CK_TILE_HOST static auto BlockSize() { @@ -855,41 +861,45 @@ struct UniversalGemmKernel { const auto& a_pad_view = generate_tuple( [&](auto i) { - const auto& a_tensor_view = views.at(I0); - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, number{}), - sequence{}); - } - }, - number{}); - + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + const auto& b_pad_view = generate_tuple( [&](auto i) { - const auto& b_tensor_view = views.at(I1); - using BiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(b_tensor_view[i], - make_tuple(number{}, number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view[i], - make_tuple(number{}, number{}), - sequence{}); - } - }, - number{}); + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); const auto& d_pad_view = generate_tuple( [&](auto i) { @@ -910,29 +920,29 @@ struct UniversalGemmKernel sequence{}); } }, - number{}); + number{}); // TODO vector write in for C in ColMajor const auto& e_pad_view = [&]() { const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { - return pad_tensor_view( - e_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { - return pad_tensor_view( - e_tensor_view, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } - }(); + }(); return make_tuple(a_pad_view, b_pad_view, d_pad_view, e_pad_view); - } + } template CK_TILE_DEVICE static auto @@ -1024,68 +1034,77 @@ struct UniversalGemmKernel } template - CK_TILE_DEVICE static auto MakePingPongGemmTileWindows - (const PadView& views, const index_t i_n, const index_t i_k, [[maybe_unused]] const index_t M, [[maybe_unused]] const index_t N, [[maybe_unused]] const index_t K) + CK_TILE_DEVICE static auto + MakePingPongGemmTileWindowsMParallel(const PadView& views, + const index_t i_n, + const index_t i_k, + [[maybe_unused]] const index_t M, + [[maybe_unused]] const index_t N, + [[maybe_unused]] const index_t K) { const auto& as_pad_view = views.at(I0); const auto& bs_pad_view = views.at(I1); const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); + const auto& e_pad_view = views.at(I3); const auto& as_block_window = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return make_tile_window( - as_pad_view[i], make_tuple(number{}, number{}), {0, i_k}); + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_k}); } else { - return make_tile_window( - as_pad_view[i], make_tuple(number{}, number{}), {i_k, 0}); + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, 0}); } - }, - number{}); + }, + number{}); const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return make_tile_window( - bs_pad_view[i], - make_tuple(number{}, number{}), - {i_n, i_k}); + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_k}); } else { - return make_tile_window( - bs_pad_view[i], - make_tuple(number{}, number{}), - {i_k, i_n}); + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, i_n}); } - }, - number{}); + }, + number{}); const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return make_tile_window( - ds_pad_view[i], - make_tuple(number{}, number{}), - {i_n, i_k}); + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_k}); } else { - return make_tile_window( - ds_pad_view[i], - make_tuple(number{}, number{}), - {i_k, i_n}); + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, i_n}); } - }, + }, number{}); auto e_block_window = make_tile_window( @@ -1093,8 +1112,90 @@ struct UniversalGemmKernel make_tuple(number{}, number{}), {0, i_n}); - return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); - } + return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + } + + template + CK_TILE_DEVICE static auto + MakePingPongGemmTileWindowsNParallel(const PadView& views, + const index_t i_m, + const index_t i_k, + [[maybe_unused]] const index_t M, + [[maybe_unused]] const index_t N, + [[maybe_unused]] const index_t K) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_k}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, i_m}); + } + }, + number{}); + + const auto& bs_block_window = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_k}); + } + else + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, 0}); + } + }, + number{}); + + const auto& ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_k}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_k, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + + return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. @@ -1149,43 +1250,42 @@ struct UniversalGemmKernel } } - CK_TILE_DEVICE static void PingPongGemm(const std::array& a_ptr, - const std::array& b_ptr, - const std::array& d_ptr, - EDataType* e_ptr, - void* smem_ptr_0, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - [[maybe_unused]] const index_t block_idx_n, - [[maybe_unused]] const index_t block_idx_k) + // PingPongGemmNDim(as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, smem_ptr_1, + // smem_ptr_2, kargs, i_n, i_k); + CK_TILE_DEVICE static void + PingPongGemmNDim(const std::array& a_ptr, + const std::array& b_ptr, + const std::array& d_ptr, + EDataType* e_ptr, + void* smem_ptr_0, + void* smem_ptr_1, + void* smem_ptr_2, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + [[maybe_unused]] const index_t block_idx_n, + [[maybe_unused]] const index_t block_idx_k) { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); - const auto kBlocks = __builtin_amdgcn_readfirstlane(integer_divide_ceil( - kargs.K, TilePartitioner::KPerBlock)); - auto idx_n = __builtin_amdgcn_readfirstlane(blockId / kBlocks); - auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks); - auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock); + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + + const auto kBlocks = __builtin_amdgcn_readfirstlane( + integer_divide_ceil(kargs.K, TilePartitioner::KPerBlock)); + + auto idx_m = __builtin_amdgcn_readfirstlane(blockId / kBlocks); + auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks); + auto m_offset = __builtin_amdgcn_readfirstlane(idx_m * TilePartitioner::MPerBlock); auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock); - //auto idx_k = __builtin_amdgcn_readfirstlane(blockId / kargs.N); - //auto idx_n = __builtin_amdgcn_readfirstlane(blockId % TilePartitioner::NPerBlock); - - //auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock); - //auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock); - // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( a_ptr, b_ptr, d_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = - MakePingPongGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakePingPongGemmTileWindows(gemm_pad_views, n_offset, k_offset, kargs.M, kargs.N, kargs.K); + const auto& gemm_pad_views = MakePingPongGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakePingPongGemmTileWindowsNParallel( + gemm_pad_views, m_offset, k_offset, kargs.M, kargs.N, kargs.K); - const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil( - //kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps)); - kargs.M, TilePartitioner::MPerBlock)); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + integer_divide_ceil(kargs.N, TilePartitioner::NPerBlock)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -1193,21 +1293,89 @@ struct UniversalGemmKernel const auto& d_block_window = gemm_tile_windows.at(I2); auto& e_block_window = gemm_tile_windows.at(I3); - - const auto EpilogueFunc = [&](auto &out_window, auto& tile, auto &ds_window, auto execute_epilogue) { - EpiloguePipeline{}.template operator()( - out_window, tile, ds_window, smem_ptr_0, execute_epilogue); - }; - + const auto EpilogueFunc = + [&](auto& out_window, auto& tile, auto& ds_window, auto execute_epilogue) { + EpiloguePipeline{} + .template operator()( + out_window, tile, ds_window, smem_ptr_2, execute_epilogue); + }; + + GemmPipeline{}.template operator()(a_block_window[I0], + b_block_window[I0], + d_block_window, + e_block_window, + num_loop, + smem_ptr_0, + smem_ptr_1, + EpilogueFunc); + } + + CK_TILE_DEVICE static void + PingPongGemmMDim(const std::array& a_ptr, + const std::array& b_ptr, + const std::array& d_ptr, + EDataType* e_ptr, + void* smem_ptr_0, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + [[maybe_unused]] const index_t block_idx_n, + [[maybe_unused]] const index_t block_idx_k) + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto kBlocks = __builtin_amdgcn_readfirstlane( + integer_divide_ceil(kargs.K, TilePartitioner::KPerBlock)); + auto idx_n = __builtin_amdgcn_readfirstlane(blockId / kBlocks); + auto idx_k = __builtin_amdgcn_readfirstlane(blockId % kBlocks); + auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock); + auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock); + + // auto idx_k = __builtin_amdgcn_readfirstlane(blockId / kargs.N); + // auto idx_n = __builtin_amdgcn_readfirstlane(blockId % TilePartitioner::NPerBlock); + + // auto n_offset = __builtin_amdgcn_readfirstlane(idx_n * TilePartitioner::NPerBlock); + // auto k_offset = __builtin_amdgcn_readfirstlane(idx_k * TilePartitioner::KPerBlock); + + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, d_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakePingPongGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakePingPongGemmTileWindowsMParallel( + gemm_pad_views, n_offset, k_offset, kargs.M, kargs.N, kargs.K); + + const index_t num_loop = __builtin_amdgcn_readfirstlane(integer_divide_ceil( + // kargs.M, TilePartitioner::MPerBlock * GemmPipeline::BlockGemmShape::NumWarps)); + kargs.M, + TilePartitioner::MPerBlock)); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + auto& e_block_window = gemm_tile_windows.at(I3); + + const auto EpilogueFunc = + [&](auto& out_window, auto& tile, auto& ds_window, auto execute_epilogue) { + EpiloguePipeline{} + .template operator()( + out_window, tile, ds_window, smem_ptr_0, execute_epilogue); + }; + /* const auto EpilogueFunc = [&](auto &out_window, auto& tile) { EpiloguePipeline{}.template operator()( out_window, tile); - }; + }; */ - GemmPipeline{}.template operator()( - a_block_window[I0], b_block_window[I0], d_block_window, e_block_window, num_loop, smem_ptr_0, EpilogueFunc); - } + GemmPipeline{}.template operator()(a_block_window[I0], + b_block_window[I0], + d_block_window, + e_block_window, + num_loop, + smem_ptr_0, + EpilogueFunc); + } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. @@ -1296,9 +1464,23 @@ struct UniversalGemmKernel // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr_1[GetSmemSize()]; + __shared__ char smem_ptr_2[GetSmemSize()]; + PingPongGemmNDim(as_ptr, + bs_ptr, + kargs.ds_ptr, + es_ptr, + smem_ptr_0, + smem_ptr_1, + smem_ptr_2, + kargs, + splitk_batch_offset, + i_n, + i_k); - PingPongGemm( - as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n, i_k); + // PingPongGemmMDim( + // as_ptr, bs_ptr, kargs.ds_ptr, es_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_n, + // i_k); } // Persistent kernel entry point diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index c5f3be5c43..da535b5d35 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -124,6 +124,238 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { using Base = PipelineImplBase; + template + CK_TILE_DEVICE auto + operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + [[maybe_unused]] const DDramBlockWindowTmp& d_dram_block_window_tmp, + [[maybe_unused]] CDramBlockWindowTmp& c_dram_block_window_tmp, + [[maybe_unused]] index_t num_loop, + void* __restrict__ p_smem_0, + [[maybe_unused]] void* __restrict__ p_smem_1, + [[maybe_unused]] const EpilogueFunction& epilogue_func + ) const + { + + [[maybe_unused]] constexpr bool is_a_col_major = + std::is_same_v; + [[maybe_unused]] constexpr bool is_b_row_major = + std::is_same_v; + [[maybe_unused]] constexpr bool is_c_col_major = + std::is_same_v; + + static_assert(NumWaveGroups == 2); + + index_t warp_id = get_warp_id(); + [[maybe_unused]] index_t operation_id = __builtin_amdgcn_readfirstlane((get_warp_id() + 1) % NumWaveGroups); + + [[maybe_unused]] auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(NPerBlock, 0); // column major + [[maybe_unused]] auto c_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, NPerBlock); // row major + + [[maybe_unused]] auto tensor_views = + Base::GetABLdsTensorViews(static_cast(static_cast(p_smem_0))); + [[maybe_unused]] auto& a_lds_block = tensor_views.get(number<0>{}); + [[maybe_unused]] auto& b_lds_block = tensor_views.get(number<1>{}); + + [[maybe_unused]] constexpr auto a_lds_laod_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + [[maybe_unused]] constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + [[maybe_unused]] auto a_windows = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_laod_tile_distr); + [[maybe_unused]] auto& a_copy_dram_window = a_windows.get(number<0>{}); + [[maybe_unused]] auto& a_copy_lds_window = a_windows.get(number<1>{}); + [[maybe_unused]] auto& a_lds_window = a_windows.get(number<2>{}); + + [[maybe_unused]] auto b_windows = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset); + [[maybe_unused]] auto& b_copy_dram_window = b_windows.get(number<0>{}); + [[maybe_unused]] auto& b_copy_lds_window = b_windows.get(number<1>{}); + [[maybe_unused]] auto& b_lds_window = b_windows.get(number<2>{}); + + [[maybe_unused]] auto epilogue_dram_window = + make_tile_window(c_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(MPerBlock, NPerBlock), + c_dram_block_window_tmp.get_window_origin() + c_offset); + + // DRAM window steps. + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + [[maybe_unused]] constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, 0); // A is constant. + + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + [[maybe_unused]] constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(0, NPerBlock * NumWarps) // (k, N) + : make_array(NPerBlock * NumWarps, 0); // (N, K) + + using CDramBlockWindowStep = typename CDramBlockWindowTmp::BottomTensorIndex; + [[maybe_unused]] constexpr CDramBlockWindowStep c_dram_tile_window_step = + is_c_col_major ? make_array(NPerBlock * NumWarps, 0) : make_array(0, NPerBlock * NumWarps); + + [[maybe_unused]] constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + [[maybe_unused]] constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using AGemmTile = decltype(make_static_distributed_tensor(AGemmTileDistr)); + using BGemmTile = decltype(make_static_distributed_tensor(BGemmTileDistr)); + + [[maybe_unused]] AGemmTile a_tile; + [[maybe_unused]] BGemmTile b_tile_0, b_tile_1; + + // Register tiles for A and B. + using ABlockTileDistr = + decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + [[maybe_unused]] ABlockTile a_dram_tile; + [[maybe_unused]] BBlockTile b_dram_tile; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); + //auto c_block_tile_1 = block_gemm.MakeCBlockTile(); + + [[maybe_unused]] auto ReadA = [&](){ + + Base::GlobalPrefetch(a_dram_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_window, a_dram_tile, a_element_func); + Base::LocalPrefetch(a_tile, a_lds_window); + //tile_elementwise_inout([](auto& c) { c = 5; }, a_tile); + }; + + [[maybe_unused]] auto ReadB = [&](auto idx) + { + Base::GlobalPrefetch(b_dram_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(b_copy_lds_window, b_dram_tile, b_element_func); + if (idx == 0) + { + Base::LocalPrefetch(b_tile_0, b_lds_window); + //tile_elementwise_inout([](auto& c) { c = 1; }, b_tile_0); + } + else + { + Base::LocalPrefetch(b_tile_1, b_lds_window); + //tile_elementwise_inout([](auto& c) { c = 2; }, b_tile_1); + } + }; + + [[maybe_unused]] auto ComputeStep = [&](auto idx){ + if (idx == 0) + { + c_block_tile_0 = block_gemm(a_tile, b_tile_0); + } + else + { + c_block_tile_0 = block_gemm(a_tile, b_tile_1); + } + }; + + /* + ReadA(); + if (warp_id == 0) + { + ReadB(warp_id); + } + __syncthreads(); + + if (warp_id == 0) + { + ComputeStep(warp_id); + } + else + { + ReadB(warp_id); + } + __syncthreads(); + epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (warp_id == 0)); + __syncthreads(); + + if (warp_id == 1) + { + ComputeStep(warp_id); + } + __syncthreads(); + epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (warp_id == 1)); + + + + if (warp_id == 1) + { + //tile_elementwise_inout([](auto& c) { c = 5; }, a_tile); + ReadA(); + //tile_elementwise_inout([](auto& c) { c = 1; }, b_tile_1); + ReadB(warp_id); + ComputeStep(warp_id); + + //store_tile(epilogue_dram_window, cast_tile(c_block_tile_0)); + epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 0)); + } + */ + + __syncthreads(); + // Read constant A. + ReadA(); + //Read B + if (operation_id == 0) + { + ReadB(warp_id); + } + + index_t num_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_steps > 1){ + block_sync_lds(); + operation_id = (operation_id + 1) % NumWaveGroups; + + if(operation_id == 0) + { + ReadB(warp_id); + } + else + { + ComputeStep(warp_id); + } + __syncthreads(); + num_steps -= 1; + + epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 1)); + if (operation_id == 1) + { + move_tile_window(epilogue_dram_window, c_dram_tile_window_step); + } + } + + if(operation_id == 0) + { + ComputeStep(warp_id); + } + + epilogue_func(epilogue_dram_window, c_block_tile_0, d_dram_block_window_tmp, (operation_id == 0)); + if (operation_id == 0) + { + move_tile_window(epilogue_dram_window, c_dram_tile_window_step); + } + } + + + // M Dimension parallelism here. template } }; + /* + N Dimension parallelism here. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const CDramBlockWindowTmp& c_dram_block_window_tmp, + index_t num_loop, + void* p_smem_0, + void* p_smem_1, + const EpilogueFunction& epilogue_func) const + { + return PipelineImpl{} + .template operator()(a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + d_dram_block_window_tmp, + c_dram_block_window_tmp, + num_loop, + p_smem_0, + p_smem_1, + epilogue_func); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const DDramBlockWindowTmp& d_dram_block_window_tmp, + const CDramBlockWindowTmp& c_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1, + const EpilogueFunction& epilogue_func) const + { + return PipelineImpl{} + .template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + d_dram_block_window_tmp, + c_dram_block_window_tmp, + num_loop, + p_smem_0, + p_smem_1, + epilogue_func); + } + + /* + // M dimensional parallelism + template p_smem_0, epilogue_func); } + */ }; } // namespace ck_tile