From afdd6a84a7e7faef195a46b60b8ddcc4577a7c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 10 Feb 2026 10:03:41 -0500 Subject: [PATCH] WIP: Double buffer implementation. --- /* Review summary: this patch changes GridwiseGemmPipeline_v1<2, true, true>::Run from one A/B LDS buffer pair to two (a_block_buf_0/1, b_block_buf_0/1), running the GEMM on one pair while the other pair is written, and adds a separate tail for odd num_loop. The caller in GridwiseGemmMultipleD_xdl_cshuffle creates the four buffers and passes them in. Open questions are marked NOTE(review) inline. */ .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 27 +++- .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 124 +++++++++--------- 2 files changed, 86 insertions(+), 65 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 8efa0e355d..428b1657b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -357,7 +357,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) if(ck::get_device_name() == "gfx950") { - return Base::GetSharedMemoryNumberOfByte(gfx950_t{}); + // Double buffering -> 2 times shared memory + return 2*Base::GetSharedMemoryNumberOfByte(gfx950_t{}); /* NOTE(review): this hunk doubles only the value returned for "gfx950"; the branch after the else is outside the hunk. Confirm that path also reserves room for a_block_buf_1 and b_block_buf_1, since the buffer setup further down in this patch is changed without a visible device check. */ } else #endif @@ -755,12 +756,24 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); /* Offsets added to the p_shared cast below: a_block_buf_0 at 0, a_block_buf_1 at a_block_space_size_aligned, b_block_buf_0 at 2*a_block_space_size_aligned, b_block_buf_1 at 2*a_block_space_size_aligned + b_block_space_size_aligned. */ - auto a_block_buf = make_dynamic_buffer( + // Double buffers for A and B in LDS + auto a_block_buf_0 = make_dynamic_buffer( static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + auto a_block_buf_1 = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_0 = make_dynamic_buffer( + static_cast(p_shared) + 
2*a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_block_buf_1 = make_dynamic_buffer( + static_cast(p_shared) + 2*a_block_space_size_aligned + + b_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -778,13 +791,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1, a_blockwise_copy, a_grid_buf, - a_block_buf, + a_block_buf_0, + a_block_buf_1, a_block_slice_copy_step, b_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_blockwise_copy, b_grid_buf, - b_block_buf, + b_block_buf_0, + b_block_buf_1, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 1262029f21..af88874f4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -113,10 +113,9 @@ struct GridwiseGemmPipeline_v1<2, true, true> static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(index_t) { - // TODO: improve applicability - return num_loop % 2 == 0; + return true; /* NOTE(review): every num_loop is now reported as supported, but CalculateHasMainLoop() below still returns (num_loop / 2) > 1, which is false for num_loop == 3. If the main loop in Run is gated on that value (the gate is not visible in this patch - TODO confirm), the odd tail would execute a single GEMM on buffer 0 when num_loop == 3. Consider whether the condition should become num_loop > 2. */ } __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) @@ -124,6 +123,11 @@ struct GridwiseGemmPipeline_v1<2, true, true> return (num_loop / 2) > 1; } + __host__ __device__ static constexpr bool CalculateIsOddLoop(index_t num_loop) + { + return (num_loop % 2) == 1; /* NOTE(review): no caller of CalculateIsOddLoop appears in this patch; the tail of Run tests num_loop % 2 == 0 directly. */ + } + template const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, + ABlockBuffer& a_block_buf_0, + ABlockBuffer& a_block_buf_1, const ABlockTransferStep& a_block_copy_step, const BGridDesc& b_grid_desc, const BBlockDesc& 
b_block_desc, BBlockTransfer& b_blockwise_copy, const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, + BBlockBuffer& b_block_buf_0, + BBlockBuffer& b_block_buf_1, const BBlockTransferStep& b_block_copy_step, const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, index_t num_loop) { - // preload data into LDS + // Prologue - load data into buffer 0 { - // Read 0 + // Read from global mem to registers (I0) a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - // Move + // Move source slice window for next read (I1) a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - // Read 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + // Write from registers to LDS buffer 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0); } // Initialize C @@ -180,76 +186,76 @@ struct GridwiseGemmPipeline_v1<2, true, true> do { - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Write i - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); - - // Read i+2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - // Sync - block_sync_lds(); - - // Gemm i - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - // Sync - block_sync_lds(); - - // Move - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Write i+1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); - - // Read i+3 + // Read from global mem 
to registers (I1) a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); - // Sync + // Move source slice window for next read (I0) + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Sync LDS to ensure buffer 0 is ready block_sync_lds(); - // Gemm i+1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Run GEMM on buffer 0 while buffer 1 is loading + blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf); - // Sync + // Write from registers to LDS buffer 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1); /* NOTE(review): the RunRead at the top of this iteration targets I1, but these RunWrite calls pass no index, whereas the removed code passed I0 or I1 to RunWrite explicitly. Confirm which register set the two-argument RunWrite uses; if it is the I0 set, buffer 1 would not receive the data just read into I1. */ + + // Read from global mem to registers (I0) + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move source slice window for next read (I1) + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Sync LDS to ensure buffer 1 is ready block_sync_lds(); + // Run GEMM on buffer 1 while buffer 0 is loading + blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf); + + // Write from registers to LDS buffer 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_0); + i += 2; } while(i < (num_loop - 2)); } // tail + if (num_loop % 2 == 0) { - // Write num_loop - 2 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + // Read from global mem to registers (I1) + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); - // Sync + // Sync LDS to ensure buffer 0 is ready block_sync_lds(); - // Gemm num_loop - 2 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Run GEMM on 
buffer 0 + blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf); - // Sync + // Write from registers to LDS buffer 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf_1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf_1); /* NOTE(review): same question as in the main loop - the data for this step was read with I1, and RunWrite is called without an index. */ + + // Sync LDS to ensure buffer 1 is ready block_sync_lds(); - // Write num_loop - 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); - - // Sync + // Run GEMM on buffer 1 + blockwise_gemm.Run(a_block_buf_1, b_block_buf_1, c_thread_buf); + } + else + { + // Sync LDS to ensure buffer 0 is ready block_sync_lds(); - // Gemm num_loop - 1 - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + // Run GEMM on buffer 0 + blockwise_gemm.Run(a_block_buf_0, b_block_buf_0, c_thread_buf); } } };