From f4b22287cdbbf0f0d827f29bd591f123f557c53a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 22 Dec 2025 21:32:48 +0100 Subject: [PATCH] Fix grouped conv fwd wmma porting (#3479) * Fix grouped conv fwd wmma porting * add more limitations [ROCm/composable_kernel commit: 2955d77f3cfb3515c6d36d54879ed65b854dafa6] --- ...ce_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 12 ++++++++++-- ...grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 12 ++++++++++-- ...conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp | 12 ++++++++++-- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6229362a7a..7cb0ae20c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -327,8 +327,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; GET_MXDL_PER_WAVE_IMPL // Force usage of 16x16 instruction for WMMA - static constexpr index_t Wave32MaxMNPerXDL = 16; - static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave