From 281d1bf50b095f4c40da49a107857abf66930b92 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii <6463225+k-artem@users.noreply.github.com> Date: Thu, 23 Apr 2026 20:12:40 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6132 (commit e97065d) [CK] Fix divide-by-zero crash for grouped conv kernels (#6132) ## Motivation While running the PyTorch unit tests for conv3d (`test_dtypes_nn_functional_conv3d_cuda`, `test_fake_crossref_backward_amp_nn_functional_conv3d_cuda_float32`), a divide-by-zero crash was found during CK kernel selection. Refs ROCM-20764 ## Technical Details Add a compile-time assert that K0PerBlock is not equal to 0, and also guard other potential divide-by-zero places related to the k_batch calculation. ## Test Plan Run the MIOpen command extracted from the mentioned test: `MIOpenDriver convfp16 --spatial_dim 3 -I NCDHW -O NCDHW -f NCDHW -n 1 -c 1 -k 1 -g 1 --in_d 4 -H 4 -W 4 --fil_d 4 -y 4 -x 4 --pad_d 0 -p 0 -q 0 --conv_stride_d 2 -u 2 -v 2 --dilation_d 1 -l 1 -j 1 -m conv -F 4 -t 1` ## Test Result Passed ## Submission Checklist - [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
Signed-off-by: Artem Kuzmitckii --- .../device/impl/device_grouped_conv_bwd_weight_dl.hpp | 2 ++ .../impl/device_grouped_conv_bwd_weight_explicit.hpp | 1 + ...ed_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 1 + ...rouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 1 + ...ped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 1 + ...grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 1 + ...evice_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 1 + .../device_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + ...device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 1 + .../gpu/device/impl/split_k_utils.hpp | 11 +++++++++++ .../transform_conv_bwd_weight_to_gemm.hpp | 4 ++++ .../transform_conv_bwd_weight_to_gemm_v2.hpp | 5 +++++ 12 files changed, 30 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 87c7697386..9245a54b7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" @@ -853,6 +854,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index a811d2f44a..172a53d652 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -179,6 
+179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit k_batch_ = split_k; } } + k_batch_ = clamp_gemm_k_batch(k_batch_); if constexpr(IsTwoStageNeeded) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3eab579e7..ed0378e23f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1e23fef191..ff0616481f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 87117be4ce..bc44cf2bb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -611,6 +611,7 
@@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer_v2 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 0ee5ac3647..011bb068f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create initial descriptors with hack=false to check compactness const auto descs_initial = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index bfc88753a2..66fb526641 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 46a9009f83..fef81b281a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes are divisible by k_batch diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 58de8dd3dc..07c8e02514 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -638,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes match product of dimensions diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 3a3bacd945..ea5b282ed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -13,6 +13,13 @@ namespace ck { namespace tensor_operation { namespace device { +/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in +/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)). +inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept +{ + return k_batch < 1 ? 
index_t{1} : k_batch; +} + struct DeviceProperties { DeviceProperties() @@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index const int max_capacity = max_occupancy * device_properties.num_cu_; ck::index_t k_batch = 1; + if(grid_size <= 0) + { + return k_batch; + } const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); if(optimal_split > 1) diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index 3379fb2c59..74ec0af7d5 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -21,6 +21,10 @@ template struct TransformConvBwdWeightToGemm { + // Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 94eae555e9..eeef3e736e 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -31,6 +31,11 @@ template struct TransformConvBwdWeightToGemmV2 { + // Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in + // integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime). 
+ static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{};