From 0b73f19a8038306ca0002971bcb239c0d6eddcaa Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 27 Nov 2024 18:25:07 +0100 Subject: [PATCH] Add interwave scheduler for gemm mem pipeline (#1647) * add interwave scheduler for gemm mem pipeline * Fix merge artifacts. * Refactor unit tests. * Switch to interwave scheduler for mem example --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski [ROCm/composable_kernel commit: e7b6286441aae59d3a87db67f42369d3cc2636a4] --- example/ck_tile/03_gemm/gemm_mem_pipeline.cpp | 3 +- example/ck_tile/03_gemm/run_gemm_example.inc | 3 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 224 ++++++++++++++++++ test/ck_tile/gemm/test_gemm_mem_pipeline.cpp | 19 +- .../gemm/test_gemm_mem_pipeline_ut_cases.inc | 59 ++++- .../gemm/test_gemm_mem_pipeline_util.hpp | 25 +- 6 files changed, 311 insertions(+), 22 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp index 97d150412d..cd9d9d96b6 100644 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -30,7 +30,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - #else // Compute friendly for Intrawave scheduler constexpr ck_tile::index_t M_Tile = 256; @@ -84,7 +83,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) AccDataType, GemmShape, Traits, - ck_tile::GemmPipelineScheduler::Intrawave, + ck_tile::GemmPipelineScheduler::Interwave, has_hot_loop_v, tail_number_v>>; using Kernel = ck_tile::GemmKernel; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 5199c1e3ef..a1fc155775 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -200,7 +200,8 @@ int run_gemm_example(int argc, char* argv[]) return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. else if(a_layout == "C" && b_layout == "C") + // work. + // else if(a_layout == "C" && b_layout == "C") // { // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); // } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index 4634e9dcb9..847c5b187d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -322,6 +322,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); + LocalPrefill(a_copy_lds_window, a_block_tiles.get(number{}), a_element_func); @@ -374,6 +375,229 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; + template <> + struct PipelineImpl + { + template + CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, + SrcTileWindow& dram_tile_window) const + { + load_tile(dst_block_tile, dram_tile_window); + move_tile_window(dram_tile_window, {0, KPerBlock}); + } + + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile, + const ElementFunction& element_func) const + { + const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); + store_tile(lds_tile_window, block_tile_tmp); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + NPerBlock == + BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" + " or KPerBlock!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // TODO: LDS alignment should come from Policy! + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + tuple_array a_block_tiles; + tuple_array b_block_tiles; + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + + // Global prefetch [1, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { + GlobalPrefetch(a_block_tiles.get(number{}), a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window); + }); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // no second block_sync_lds because it's interwave + + LocalPrefill( + a_copy_lds_window, + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + a_element_func); + LocalPrefill( + b_copy_lds_window, + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), + b_element_func); + + GlobalPrefetch(a_block_tiles.get(number{}), + a_copy_dram_window); + GlobalPrefetch(b_block_tiles.get(number{}), + b_copy_dram_window); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + auto HotLoopTail = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + // no second block_sync_lds because it's interwave + + LocalPrefill(a_copy_lds_window, + a_block_tiles.get(number{}), + a_element_func); + LocalPrefill(b_copy_lds_window, + b_block_tiles.get(number{}), + b_element_func); + }); + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else if constexpr(TailNum == TailNumber::Two) + { + HotLoopTail(number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + HotLoopTail(number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + HotLoopTail(number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + HotLoopTail(number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + HotLoopTail(number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + HotLoopTail(number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + HotLoopTail(number{}); + } + + return c_block_tile; + } + }; + template +class TestCkTileGemmMemPipelineIntrawave : public TestCkTileGemmMemPipeline +{ +}; + +template +class TestCkTileGemmMemPipelineInterwave : public TestCkTileGemmMemPipeline +{ +}; // clang-format off using KernelTypes = ::testing::Types< @@ -24,6 +36,7 @@ using KernelTypes = ::testing::Types< >; // clang-format on -TYPED_TEST_SUITE(TestCkTileGemmMemPipeline, KernelTypes); +TYPED_TEST_SUITE(TestCkTileGemmMemPipelineIntrawave, KernelTypes); +TYPED_TEST_SUITE(TestCkTileGemmMemPipelineInterwave, KernelTypes); #include "test_gemm_mem_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc index b26114f39d..6b914e7975 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_ut_cases.inc @@ -1,6 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #pragma once -TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) +//------------------------------------------------------------------------------------------------ +// INTERWAVE SCHEDULER +//------------------------------------------------------------------------------------------------ + +TYPED_TEST(TestCkTileGemmMemPipelineInterwave, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 1024; @@ -10,7 +17,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, SmallM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) +TYPED_TEST(TestCkTileGemmMemPipelineInterwave, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 1024; @@ -20,7 +27,7 @@ TYPED_TEST(TestCkTileGemmMemPipeline, MidLargeM) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) +TYPED_TEST(TestCkTileGemmMemPipelineInterwave, PaddK) { std::vector Ms{127}; constexpr int N = 1024; @@ -30,7 +37,51 @@ TYPED_TEST(TestCkTileGemmMemPipeline, PaddK) this->Run(M, N, K); } -TYPED_TEST(TestCkTileGemmMemPipeline, Regular) +TYPED_TEST(TestCkTileGemmMemPipelineInterwave, Regular) +{ + std::vector Ms{512}; + constexpr int N = 1024; + constexpr int K = 512; + + for(int M : Ms) + this->Run(M, N, K); +} + +//------------------------------------------------------------------------------------------------ +// INTRAWAVE SCHEDULER +//------------------------------------------------------------------------------------------------ + +TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 1024; + constexpr int K = 432; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestCkTileGemmMemPipelineIntrawave, Regular) { std::vector Ms{512}; constexpr int N = 1024; diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp index 6b47898339..15f9f516ee 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -11,20 +11,21 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -template +template class TestCkTileGemmMemPipeline : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using CLayout = std::tuple_element_t<2, Tuple>; - using ADataType = std::tuple_element_t<3, Tuple>; - using BDataType = std::tuple_element_t<4, Tuple>; - using AccDataType = std::tuple_element_t<5, Tuple>; - using CDataType = std::tuple_element_t<6, Tuple>; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + static constexpr auto Scheduler = Scheduler_; // TODO: expose tile size through test t-param ? - struct gemm_basic_args + struct gemm_args { const void* p_a; const void* p_b; @@ -38,7 +39,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ck_tile::index_t stride_C; }; - void invoke_gemm(const gemm_basic_args& args, const ck_tile::stream_config& s) + void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 128; @@ -89,7 +90,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test AccDataType, GemmShape, Traits, - ck_tile::GemmPipelineScheduler::Intrawave, + Scheduler, has_hot_loop_v, tail_number_v>>; using Kernel = ck_tile::GemmKernel; @@ -288,7 +289,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - gemm_basic_args args; + gemm_args args; args.p_a = a_m_k_dev_buf.GetDeviceBuffer(); args.p_b = b_k_n_dev_buf.GetDeviceBuffer(); args.p_c = c_m_n_dev_buf.GetDeviceBuffer();