From 9dbbb07953dc1ffba8873f60634c5f8d8fce4969 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 7 Aug 2025 07:27:11 +0000 Subject: [PATCH] Fix bug and disable splitK=-1 tests for wmma --- ...ouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 21 +++++++------------ .../test_grouped_convnd_bwd_weight.cpp | 5 +++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index d064136964..79ec51228b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -447,6 +447,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 input_right_pads_{input_right_pads}, k_batch_{split_k} { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), @@ -629,6 +634,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const std::array& input_left_pads_; const std::array& input_right_pads_; const index_t k_batch_; + long_index_t c_space_size_bytes; }; // Invoker @@ -757,12 +763,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; const auto clear_workspace = [&]() { - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - hip_check_error(hipMemsetAsync( - p_e_grid, 0, arg.GetWorkspaceETensorSizeBytes(), stream_config.stream_id_)); - } + hip_check_error( + hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); }; const auto Run = [&](const auto& kernel) { @@ -1047,13 +1049,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { -// workaround: disable when K, C is even -#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN - if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) - { - return false; - } -#endif // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index fd84f57095..e216ecb27c 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -44,6 +44,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } + if((split_k < 1) && (ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return true; + } + return false; }