From e63bc7a2cec34b653abf5ea4ff9bfac339477fb4 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 12 Nov 2025 13:30:56 +0000 Subject: [PATCH] Fix tensor descriptors and stride calculations --- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 17 ++- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 55 +++++---- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 109 ++++++++++-------- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 28 +++-- .../transform_conv_bwd_weight_to_gemm.hpp | 48 +++++--- .../transform_conv_bwd_weight_to_gemm_v2.hpp | 48 +++++--- 6 files changed, 186 insertions(+), 119 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c761e326fa..46140ac0c2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -814,16 +814,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle (Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0; // Check if there is KPading and we can divide N * OutSpatialDims by k_batch split_k_offset_a_hack_ = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && is_NSpatialGC_GKSpatial_NSpatialGK(); // Check if there is KPading and we can divide N by k_batch split_k_offset_b_hack_ = - Conv_N_ % k_batch_ == 0 && is_k_not_paded && + k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded && is_NSpatialGC_GKSpatial_NSpatialGK(); - split_k_stride_a_ = - a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; - split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; + // Calculate stride from descriptor size + // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, + // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_a_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_b_hack_) + split_k_stride_b_ /= k_batch_; } std::size_t GetWorkspaceATensorSizeBytes() const diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 8412896560..2b5770dc00 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -65,7 +65,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const long_index_t split_k_stride_a, const long_index_t split_k_stride_b, bool split_k_offset_a_hack, - bool split_k_offset_b_hack) + bool split_k_offset_b_hack, + index_t k_batch) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) @@ -96,7 +97,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) split_k_stride_a, split_k_stride_b, split_k_offset_a_hack, - split_k_offset_b_hack); + split_k_offset_b_hack, + k_batch); } #else ignore = p_a_grid; @@ -115,6 +117,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = split_k_stride_b; ignore = split_k_offset_a_hack; ignore = split_k_offset_b_hack; + ignore = k_batch; compute_ptr_offset_of_batch.GetAPtrOffset(0); compute_ptr_offset_of_batch.GetBPtrOffset(0); @@ -587,6 +590,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle k_batch_ = split_k; } + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0; + // Check if there is KPading and we can divide N * OutSpatialDims by k_batch + split_k_offset_a_hack_ = + k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + // Check if there is KPading and we can divide N by k_batch + split_k_offset_b_hack_ = + k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -603,12 +619,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_a_hack_, + split_k_offset_b_hack_); a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + // Calculate stride from descriptor size + // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, + // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled + split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_a_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_b_hack_) + split_k_stride_b_ /= k_batch_; + block_2_ctile_map_ = GridwiseGemm64::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); @@ -650,23 +679,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } - - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0; - // Check if there is KPading and we can divide N * OutSpatialDims by k_batch - split_k_offset_a_hack_ = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - // Check if there is KPading and we can divide N by k_batch - split_k_offset_b_hack_ = - Conv_N_ % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - - split_k_stride_a_ = - a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; - split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; } std::size_t GetWorkspaceATensorSizeBytes() const @@ -913,7 +925,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle arg.split_k_stride_a_, arg.split_k_stride_b_, arg.split_k_offset_a_hack_, - arg.split_k_offset_b_hack_); + arg.split_k_offset_b_hack_, + arg.k_batch_); }; if(has_main_k0_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2de9d84def..c9d9db1da2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -514,8 +514,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_c_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_batch_{}, @@ -584,6 +584,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 k_batch_ = split_k; } + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0; + // Check if there is KPading and we can divide N * OutSpatialDims by k_batch + split_k_offset_a_hack_ = + k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + // Check if there is KPading and we can divide N by k_batch + split_k_offset_b_hack_ = + k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded && + is_NSpatialGC_GKSpatial_NSpatialGK(); + const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -600,11 +613,24 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 conv_filter_dilations, input_left_pads, input_right_pads, - k_batch_); + k_batch_, + split_k_offset_a_hack_, + split_k_offset_b_hack_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + // Calculate stride from descriptor size + // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, + // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_a_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_b_hack_) + split_k_stride_b_ /= k_batch_; // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; @@ -615,38 +641,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_, GridwiseGemm64::CalculateMBlock(GemmM), GridwiseGemm64::CalculateNBlock(GemmN)); - - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (K0PerBlock * k_batch_) == 0; - // Check if there is KPading and we can divide N * OutSpatialDims by k_batch - split_k_offset_a_hack_ = - (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - // Check if there is KPading and we can divide N by k_batch - split_k_offset_b_hack_ = - Conv_N_ % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - - split_k_stride_a_ = - a_g_n_k_wos_strides[NDimSpatial + I2] * (Conv_N_ * output_spatial_acum) / k_batch_; - split_k_stride_b_ = b_g_n_c_wis_strides[I1] * Conv_N_ / k_batch_; } const ADataType* p_a_grid_; const BDataType* p_b_grid_; CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; @@ -685,16 +694,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 void ShowInfo(const Argument& arg) { std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; @@ -703,10 +712,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 template float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -724,7 +733,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); const auto num_k_per_block = - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { if(arg.k_batch_ > 1) @@ -760,8 +769,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block, @@ -780,8 +789,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 dim3(BlockSize), 0, gemm_arg, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.compute_ptr_offset_of_batch_, num_k_per_block, @@ -1341,10 +1350,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } #endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); - const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * - arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); if constexpr(is_same_v || is_same_v) { @@ -1475,8 +1484,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB)) { return false; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 414fc3a03f..1286681a5b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -164,7 +164,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + const CBlockClusterAdaptor c_block_cluster_adaptor, + const long_index_t split_k_stride_a, + const long_index_t split_k_stride_b, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack, + index_t k_batch) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \ defined(__gfx12__) @@ -182,7 +187,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) a_element_op, b_element_op, c_element_op, - c_block_cluster_adaptor); + c_block_cluster_adaptor, + split_k_stride_a, + split_k_stride_b, + split_k_offset_a_hack, + split_k_offset_b_hack, + k_batch); } #else ignore = p_a_grid; @@ -195,6 +205,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = b_element_op; ignore = c_element_op; ignore = c_block_cluster_adaptor; + ignore = split_k_stride_a; + ignore = split_k_stride_b; + ignore = split_k_offset_a_hack; + ignore = split_k_offset_b_hack; + ignore = k_batch; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -662,7 +677,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const long_index_t split_k_stride_a, const long_index_t split_k_stride_b, bool split_k_offset_a_hack, - bool split_k_offset_b_hack) + bool split_k_offset_b_hack, + index_t k_batch) { const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); @@ -677,10 +693,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight const long_index_t split_k_offset_b = split_k_offset_b_hack ? k_batch_id * split_k_stride_b : 0; - const long_index_t a_space_size_divisor = - split_k_offset_a_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1; - const long_index_t b_space_size_divisor = - split_k_offset_b_hack ? a_b_k0_m_k1_grid_desc.GetLength(I0) : 1; + const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1; + const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1; const auto a_grid_buf = make_dynamic_buffer( p_a_grid + split_k_offset_a, diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index efc7f20cdc..e4e2a8bbfc 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -150,7 +150,9 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -173,7 +175,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; if constexpr(ConvBackwardWeightSpecialization == device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -191,7 +195,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -209,7 +213,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -247,7 +251,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -286,7 +290,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -324,7 +328,9 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -360,7 +366,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -379,7 +387,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -394,7 +402,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -423,7 +431,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -464,7 +472,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -498,7 +506,9 @@ struct TransformConvBwdWeightToGemm const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -541,7 +551,9 @@ struct TransformConvBwdWeightToGemm const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -560,7 +572,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -575,7 +587,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -604,7 +616,7 @@ struct TransformConvBwdWeightToGemm const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA, GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -654,7 +666,7 @@ struct TransformConvBwdWeightToGemm const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB, GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index e410f06190..98b7e29439 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -324,7 +324,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -353,7 +355,9 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); @@ -373,7 +377,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -389,7 +393,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -419,7 +423,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -460,7 +464,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -495,7 +499,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -531,7 +537,9 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -551,7 +559,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -567,7 +575,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -597,7 +605,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -647,7 +655,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -681,7 +689,9 @@ struct TransformConvBwdWeightToGemmV2 const std::array& conv_filter_dilations, const std::array& input_left_pads, const std::array& input_right_pads, - const index_t batch_k) + const index_t batch_k, + const bool split_k_offset_a_hack = false, + const bool split_k_offset_b_hack = false) { using namespace ck; @@ -724,7 +734,9 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; + const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); @@ -744,7 +756,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -760,7 +772,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -790,7 +802,7 @@ struct TransformConvBwdWeightToGemmV2 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexA * GemmK0, GemmK1Number)), make_right_pad_transform(GemmM, PadGemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); @@ -855,7 +867,7 @@ struct TransformConvBwdWeightToGemmV2 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_tuple(make_unmerge_transform(make_tuple(KBatchIndexB * GemmK0, GemmK1Number)), make_right_pad_transform(GemmN, PadGemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}));