diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6229362a7a..7cb0ae20c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -327,8 +327,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; GET_MXDL_PER_WAVE_IMPL // Force usage of 16x16 instruction for WMMA - static constexpr index_t Wave32MaxMNPerXDL = 16; - static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave(); + static constexpr bool Wave32Force16MNPerXDL = + is_NSpatialGC_GKSpatial_NSpatialGK() && + sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 && + is_same_v && + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == ConvolutionForwardSpecialization::Default); + static constexpr index_t Wave32MaxMNPerXDL = + Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL); + + static constexpr auto MXdlPerWave64 = GetMXdlPerWave(); static constexpr auto MXdlPerWave32 = GetMXdlPerWave