diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index a47a2f90cc..5846cc8913 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -759,6 +759,48 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_)[I2]; + // Named preconditions that gate the split-k offset hack below + const index_t output_spatial_acum = ck::accumulate_n( + output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const bool is_k_not_paded = + (Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0; + + const bool can_divide_n_spatial_by_k_batch = + (Conv_N_ * output_spatial_acum) % k_batch_ == 0; + + const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0; + + const bool is_correct_layout = + is_NSpatialGC_GKSpatial_NSpatialGK(); + + // Check that the descriptor element space sizes are divisible by k_batch_, + // to avoid truncation in the per-batch stride division below + const bool is_a_stride_divisible = + a_grid_desc_k0_m_k1_.GetElementSpaceSize() % k_batch_ == 0; + + const bool is_b_stride_divisible = + b_grid_desc_k0_n_k1_.GetElementSpaceSize() % k_batch_ == 0; + + // Determine if we can safely use the split-k offset hack + split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && + is_k_not_paded && is_correct_layout && is_a_stride_divisible; + + split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && + is_correct_layout && is_b_stride_divisible; + + // Calculate stride from descriptor size + // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, + // so we need to divide by k_batch_ to get the per-batch stride 
when the hack is + // enabled. The divisibility check above ensures no truncation occurs. + split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + if(split_k_offset_a_hack_) + split_k_stride_a_ /= k_batch_; + + split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); + if(split_k_offset_b_hack_) + split_k_stride_b_ /= k_batch_; + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); @@ -810,30 +852,6 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle e_in_transpose_desc_.GetLength(I1)} : Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; - - const index_t output_spatial_acum = ck::accumulate_n( - output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); - const bool is_k_not_paded = - (Conv_N_ * output_spatial_acum) % (KPerBlock * k_batch_) == 0; - // Check if there is KPading and we can divide N * OutSpatialDims by k_batch - split_k_offset_a_hack_ = - k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - // Check if there is KPading and we can divide N by k_batch - split_k_offset_b_hack_ = - k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - - // Calculate stride from descriptor size - // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, - // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled - split_k_stride_a_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); - if(split_k_offset_a_hack_) - split_k_stride_a_ /= k_batch_; - - split_k_stride_b_ = b_grid_desc_k0_n_k1_.GetElementSpaceSize(); - if(split_k_offset_b_hack_) - split_k_stride_b_ /= k_batch_; } std::size_t GetWorkspaceATensorSizeBytes() const diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 2b5770dc00..04e40059d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -590,19 +590,59 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle k_batch_ = split_k; } + // Create descriptors first (with hack flags temporarily set to false) + // so we can check if element space sizes are divisible by k_batch_ + const auto descs_initial = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_, + false, // split_k_offset_a_hack (temporary) + false); // split_k_offset_b_hack (temporary) + + // Named preconditions that gate the split-k offset hack below const index_t output_spatial_acum = ck::accumulate_n( output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const bool is_k_not_paded = (Conv_N_ * output_spatial_acum) % (K0PerBlock * K1 * k_batch_) == 0; - // Check if there is KPading and we can divide N * OutSpatialDims by k_batch - split_k_offset_a_hack_ = - k_batch_ > 1 && (Conv_N_ * output_spatial_acum) % k_batch_ == 0 && is_k_not_paded && - is_NSpatialGC_GKSpatial_NSpatialGK(); - // Check if there is KPading and we can divide N by k_batch - split_k_offset_b_hack_ = - k_batch_ > 1 && Conv_N_ % k_batch_ == 0 && is_k_not_paded && + + const bool can_divide_n_spatial_by_k_batch = + (Conv_N_ * output_spatial_acum) % k_batch_ == 0; + + const bool can_divide_n_by_k_batch = Conv_N_ % k_batch_ == 0; + + const bool is_correct_layout = 
is_NSpatialGC_GKSpatial_NSpatialGK(); + // Check that the descriptor element space sizes are divisible by k_batch_, + // to avoid truncation in the per-batch stride division below + const bool is_a_stride_divisible = + descs_initial[I0].GetElementSpaceSize() % k_batch_ == 0; + + const bool is_b_stride_divisible = + descs_initial[I1].GetElementSpaceSize() % k_batch_ == 0; + + // Determine if we can safely use the split-k offset hack + split_k_offset_a_hack_ = k_batch_ > 1 && can_divide_n_spatial_by_k_batch && + is_k_not_paded && is_correct_layout && is_a_stride_divisible; + + split_k_offset_b_hack_ = k_batch_ > 1 && can_divide_n_by_k_batch && is_k_not_paded && + is_correct_layout && is_b_stride_divisible; + + // Now create descriptors with the correct hack flags const auto descs = conv_to_gemm_transformer .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( @@ -629,7 +669,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle // Calculate stride from descriptor size // NOTE: GetElementSpaceSize() returns the full size even when KBatchIndex=1, - // so we need to divide by k_batch_ to get the per-batch stride when the hack is enabled + // so we need to divide by k_batch_ to get the per-batch stride when the hack is + // enabled. The divisibility check on descs_initial above guards against truncation. split_k_stride_a_ = a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize(); if(split_k_offset_a_hack_) split_k_stride_a_ /= k_batch_;