mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Fix grouped conv fwd wmma porting (#3479)
* Fix grouped conv fwd wmma porting * add more limitations
This commit is contained in:
@@ -327,8 +327,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr bool Wave32Force16MNPerXDL =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
|
||||
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
|
||||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
|
||||
static constexpr index_t Wave32MaxMNPerXDL =
|
||||
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
|
||||
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
|
||||
@@ -402,8 +402,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr bool Wave32Force16MNPerXDL =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
|
||||
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
|
||||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
|
||||
static constexpr index_t Wave32MaxMNPerXDL =
|
||||
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
|
||||
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
|
||||
@@ -208,8 +208,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
|
||||
GET_MXDL_PER_WAVE_IMPL
|
||||
// Force usage of 16x16 instruction for WMMA
|
||||
static constexpr index_t Wave32MaxMNPerXDL = 16;
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr bool Wave32Force16MNPerXDL =
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
sizeof(AComputeDataType) == 2 && sizeof(BComputeDataType) == 2 &&
|
||||
is_same_v<CDEElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
|
||||
(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 ||
|
||||
ConvForwardSpecialization == ConvolutionForwardSpecialization::Default);
|
||||
static constexpr index_t Wave32MaxMNPerXDL =
|
||||
Wave32Force16MNPerXDL ? 16 : math::max(MPerXDL, NPerXDL);
|
||||
|
||||
static constexpr auto MXdlPerWave64 = GetMXdlPerWave<true>();
|
||||
static constexpr auto MXdlPerWave32 =
|
||||
GetMXdlPerWave<false,
|
||||
Wave32MaxMNPerXDL,
|
||||
|
||||
Reference in New Issue
Block a user