From 59159f196269509bbfbaa8dd556ccaefc13efdb9 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Mon, 3 Nov 2025 15:10:55 +0000 Subject: [PATCH] Optimize grouped conv bwd wei split_k off calc (cherry picked from commit 6f61dd56c5d45409826e660175accb51ace24bcc) --- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 88 +++- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 42 +- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 81 +++- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 36 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 409 ++++++++++++++++++ 5 files changed, 609 insertions(+), 47 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 77a321c885..c761e326fa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -55,13 +55,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t 
split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -77,18 +84,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, karg.p_c_grid + e_batch_offset, p_shared, karg, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; #endif // end of if (defined(__gfx9__)) } @@ -113,14 +127,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + [[maybe_unused]] const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = 
__builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); @@ -139,8 +160,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, karg.p_c_grid + e_batch_offset, p_shared_0, p_shared_1, @@ -148,10 +169,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; + ignore = split_k_stride_a; + ignore = split_k_stride_b; #endif // end of if (defined(__gfx9__)) } @@ -779,6 +807,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle e_in_transpose_desc_.GetLength(I1)} : Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0; + // Enabled only when K needs no padding and N * OutSpatialDims is divisible by k_batch + split_k_offset_a_hack_ = + (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + // Enabled only when K needs no padding and N is divisible by k_batch + split_k_offset_b_hack_ = + Conv_N_ % 
k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + + split_k_stride_a_ = + a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; + split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; } std::size_t GetWorkspaceATensorSizeBytes() const @@ -864,6 +909,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_a_hack_, split_k_offset_b_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -966,7 +1014,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_a_hack_, + arg.split_k_offset_b_hack_); } else { @@ -982,7 +1034,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_a_hack_, + arg.split_k_offset_b_hack_); } }; @@ -1886,14 +1942,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } - constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && - arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) - { - return false; - } - return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 650c6f11d3..8412896560 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,7 +61,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) @@ -88,7 +92,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + split_k_stride_a, + split_k_stride_b, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = p_a_grid; @@ -103,6 +111,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = batch_count; ignore = block_2_ctile_map; ignore = compute_ptr_offset_of_batch; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); @@ -638,6 +650,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } + + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0; + // Enabled only when K needs no padding and N * OutSpatialDims is divisible by k_batch + 
split_k_offset_a_hack_ = + (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + // Enabled only when K needs no padding and N is divisible by k_batch + split_k_offset_b_hack_ = + Conv_N_ % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + + split_k_stride_a_ = + a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; + split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; } std::size_t GetWorkspaceATensorSizeBytes() const @@ -731,6 +760,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_a_hack_, split_k_offset_b_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -877,7 +909,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_mblock_mperblock_nblock_nperblock, arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + arg.compute_ptr_offset_of_batch_, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_a_hack_, + arg.split_k_offset_b_hack_); }; if(has_main_k0_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 6d3ca9d9dd..2de9d84def 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -53,13 +53,20 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t 
split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -74,15 +81,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, karg.p_c_grid + e_batch_offset, p_shared, karg, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; @@ -91,6 +101,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; + #endif // end of if (defined(__gfx9__) } @@ -114,14 +129,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch 
compute_ptr_offset_of_batch, - const index_t num_k_per_block) + const index_t num_k_per_block, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) if constexpr(GridwiseGemm::template IsValidCompilationParameter()) { // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + + const long_index_t split_k_offset_a = split_k_offset_a_hack ? k_idx * split_k_stride_a : 0; + const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_idx * split_k_stride_b : 0; const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -140,8 +162,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, HasMainKBlockLoop, CGlobalMemoryDataOperation, - TailNum>(karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, karg.p_c_grid + e_batch_offset, p_shared_0, p_shared_1, @@ -149,7 +171,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; @@ -158,6 +183,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = compute_ptr_offset_of_batch; ignore = num_k_per_block; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; #endif // end of 
if (defined(__gfx9__) } @@ -594,6 +623,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 c_grid_desc_m_n_, GridwiseGemm64::CalculateMBlock(GemmM), GridwiseGemm64::CalculateNBlock(GemmN)); + + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0; + // Enabled only when K needs no padding and N * OutSpatialDims is divisible by k_batch + split_k_offset_a_hack_ = + (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + // Enabled only when K needs no padding and N is divisible by k_batch + split_k_offset_b_hack_ = + Conv_N_ % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + + split_k_stride_a_ = + a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; + split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; } const ADataType* p_a_grid_; @@ -626,6 +672,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const std::array& input_left_pads_; const std::array& input_right_pads_; long_index_t c_space_size_bytes; + + bool split_k_offset_a_hack_, split_k_offset_b_hack_; + long_index_t split_k_stride_a_, split_k_stride_b_; }; // Invoker @@ -715,7 +764,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_a_hack_, + arg.split_k_offset_b_hack_); } else { @@ -731,7 +784,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, - num_k_per_block); + num_k_per_block, + arg.split_k_stride_a_, + arg.split_k_stride_b_, + arg.split_k_offset_a_hack_, + 
arg.split_k_offset_b_hack_); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 60ad4651b6..2657409c4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -45,7 +45,7 @@ template ( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -744,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -775,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1035,12 +1041,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0) + const index_t k_id = 0, + const index_t k_batch = 1, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { + const long_index_t a_space_size_divisor = split_k_offset_a_hack ? 
k_batch : 1; + const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1; + const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize() / b_space_size_divisor); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1106,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(k_id, m_block_data_idx_on_grid, 0), + make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1137,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(k_id, n_block_data_idx_on_grid, 0), + make_multi_index(split_k_offset_b_hack ? 
0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 422b9afa61..414fc3a03f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -646,6 +646,415 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) + { + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + const long_index_t split_k_offset_a = + split_k_offset_a_hack ? k_batch_id * split_k_stride_a : 0; + const long_index_t split_k_offset_b = + split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0; + + const long_index_t a_space_size_divisor = + split_k_offset_a_hack ? 
a_b_k0_m_k1_grid_desc.GetLength(I0) : 1; + const long_index_t b_space_size_divisor = + split_k_offset_b_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid + split_k_offset_a, + a_b_k0_m_k1_grid_desc.GetElementSpaceSize() / a_space_size_divisor); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + split_k_offset_b, + b_b_k0_n_k1_grid_desc.GetElementSpaceSize() / b_space_size_divisor); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + FloatAAdjusted, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + 
ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index( + split_k_offset_a_hack ? 0 : k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + FloatBAdjusted, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index( + split_k_offset_b_hack ? 0 : k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + K1 <= 4) || + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = 
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXdl, ""); + static_assert(N2 == NPerXdl, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // N1 = NWave, N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + 
make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXdl, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool 
ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } + template __device__ static void Run(const FloatA* __restrict__ p_a_grid, const FloatB* __restrict__ p_b_grid,