diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index c9d9db1da2..6aec428962 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -615,15 +615,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 input_right_pads, k_batch_, split_k_offset_a_hack_, - split_k_offset_b_hack_); + split_k_offset_b_hack_, + true); // use_full_batch_kindex=true for V1-compatible descriptors a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; // Calculate stride from descriptor size - // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, - // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled + // With use_full_batch_kindex=true, descriptors contain full k-batch dimension + // so we divide by k_batch_ to get per-batch stride split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); if(split_k_offset_a_hack_) split_k_stride_a_ /= k_batch_; @@ -810,7 +811,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, @@ -842,7 +843,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be One to Seven else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { @@ -1151,7 +1152,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number could be Odd or Even else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1220,7 +1221,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } else { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { @@ -1293,7 +1294,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if(gemm_arg.KBatch > 1) + if(arg.k_batch_ > 1) { const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< GridwiseGemm, diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 98b7e29439..e1576ec27d 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -326,7 +326,8 @@ struct TransformConvBwdWeightToGemmV2 const std::array& input_right_pads, const index_t batch_k, const bool split_k_offset_a_hack = false, - const bool split_k_offset_b_hack = false) + const bool split_k_offset_b_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -355,9 +356,13 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchIndexA = + (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; + const index_t KBatchIndexB = + (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); @@ -501,7 +506,8 @@ struct TransformConvBwdWeightToGemmV2 const std::array& input_right_pads, const index_t batch_k, const bool split_k_offset_a_hack = false, - const bool split_k_offset_b_hack = false) + const bool split_k_offset_b_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -537,9 +543,13 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchIndexA = + (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; + const index_t KBatchIndexB = + (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); @@ -691,7 +701,8 @@ struct TransformConvBwdWeightToGemmV2 const std::array& input_right_pads, const index_t batch_k, const bool split_k_offset_a_hack = false, - const bool split_k_offset_b_hack = false) + const bool split_k_offset_b_hack = false, + const bool use_full_batch_kindex = false) { using namespace ck; @@ -734,9 +745,13 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const index_t KBatchIndexA = split_k_offset_a_hack ? 1 : GemmKBatch; - const index_t KBatchIndexB = split_k_offset_b_hack ? 1 : GemmKBatch; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + // When use_full_batch_kindex=true, create full-batch descriptors (V1 mode) for gridwise + // kernel compatibility + const index_t KBatchIndexA = + (split_k_offset_a_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; + const index_t KBatchIndexB = + (split_k_offset_b_hack && !use_full_batch_kindex) ? 1 : GemmKBatch; const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides);