Fix grouped conv fwd wmma porting (#3479)

* Fix grouped conv fwd wmma porting

* add more limitations
This commit is contained in:
Bartłomiej Kocot
2025-12-22 21:32:48 +01:00
committed by GitHub
parent a8aebb7a8e
commit 2955d77f3c
3 changed files with 30 additions and 6 deletions

View File

@@ -327,8 +327,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr bool Wave32Force16MNPerXDL =
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
static constexpr index_t Wave32MaxMNPerXDL =
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,

View File

@@ -402,8 +402,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr bool Wave32Force16MNPerXDL =
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
static constexpr index_t Wave32MaxMNPerXDL =
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,

View File

@@ -208,8 +208,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
GET_MXDL_PER_WAVE_IMPL
// Force usage of 16x16 instruction for WMMA
static constexpr index_t Wave32MaxMNPerXDL = 16;
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr bool Wave32Force16MNPerXDL =
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
static constexpr index_t Wave32MaxMNPerXDL =
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
static constexpr auto MXdlPerWave32 =
GetMXdlPerWave<false,
Wave32MaxMNPerXDL,