From 29743bc0f4e8bf819c201cff7d2cf5a9d39b8539 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 12 Dec 2025 09:49:17 +0000 Subject: [PATCH] Fix explicit conv bwd weight struct --- ...evice_grouped_conv_bwd_weight_explicit.hpp | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index 377eb5f3a1..4cac694986 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -144,18 +144,39 @@ struct DeviceGroupedConvBwdWeight_Explicit end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); - if(split_k < 0) + if constexpr(IsTwoStageNeeded) { - const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = - DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); - const index_t grid_size = gdx * gdy * gdz; - k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + if(split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else + { + k_batch_ = split_k; + } } else { - k_batch_ = split_k; +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else +#endif + { + k_batch_ = split_k; + } } if constexpr(IsTwoStageNeeded) @@ -317,6 +338,16 @@ struct DeviceGroupedConvBwdWeight_Explicit static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if constexpr(!IsTwoStageNeeded) + { + if(arg.split_k_ < 0) + { + return false; + } + } +#endif + if constexpr(NDimSpatial == 2) { if constexpr(!is_NHWGC_GKYXC_NHWGK())