diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 87c7697386..9245a54b7b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" @@ -853,6 +854,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index a811d2f44a..172a53d652 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit k_batch_ = split_k; } } + k_batch_ = clamp_gemm_k_batch(k_batch_); if constexpr(IsTwoStageNeeded) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3eab579e7..ed0378e23f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1e23fef191..ff0616481f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 87117be4ce..bc44cf2bb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer_v2 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 0ee5ac3647..011bb068f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create initial descriptors with hack=false to check compactness const auto descs_initial = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index bfc88753a2..66fb526641 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 46a9009f83..fef81b281a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes are divisible by k_batch diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 58de8dd3dc..07c8e02514 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -638,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes match product of dimensions diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 3a3bacd945..ea5b282ed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -13,6 +13,13 @@ namespace ck { namespace tensor_operation { namespace device { +/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in +/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)). +inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept +{ + return k_batch < 1 ? index_t{1} : k_batch; +} + struct DeviceProperties { DeviceProperties() @@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index const int max_capacity = max_occupancy * device_properties.num_cu_; ck::index_t k_batch = 1; + if(grid_size <= 0) + { + return k_batch; + } const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); if(optimal_split > 1) diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index 3379fb2c59..74ec0af7d5 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -21,6 +21,10 @@ template struct TransformConvBwdWeightToGemm { + // Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 94eae555e9..eeef3e736e 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -31,6 +31,11 @@ template struct TransformConvBwdWeightToGemmV2 { + // Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in + // integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{};