Post-merge cleanup for WMMA grouped conv fwd (#3468)

* remove duplicate aliases

* Split scaleadd_ab instances for WMMA grouped conv fwd

* removed big shape from the test
This commit is contained in:
Wojciech Laskowski
2025-12-22 15:57:45 +01:00
committed by GitHub
parent 44f1b5c5de
commit a8aebb7a8e
13 changed files with 570 additions and 62 deletions

View File

@@ -40,7 +40,7 @@ template <index_t NDimSpatial,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances =
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part1 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
@@ -57,30 +57,84 @@ using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances =
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part2 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part3 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16_instances_part4 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<BF16, BF16>, ck::Tuple<BF16, BF16>, F32, BF16, ck::Tuple<>, BF16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
@@ -92,7 +146,7 @@ template <index_t NDimSpatial,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances =
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part1 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
@@ -109,30 +163,84 @@ using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances =
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part2 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part3 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_cshufflev3_scaleadd_ab_f16_instances_part4 =
std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#ifndef ONE_INSTANCE_PER_LIST
,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<NDimSpatial, ALayout, BLayout, ck::Tuple<>, ELayout, ck::Tuple<F16, F16>, ck::Tuple<F16, F16>, F32, F16, ck::Tuple<>, F16, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
#endif

View File

@@ -91,7 +91,46 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_int8_ins
#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
// grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>,
ck::Tuple<>,
BF16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>,
ck::Tuple<>,
BF16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<BF16, BF16>,
ck::Tuple<BF16, BF16>,
ck::Tuple<>,
BF16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -107,7 +146,46 @@ void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndh
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
F16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
F16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
ck::Tuple<F16, F16>,
ck::Tuple<F16, F16>,
ck::Tuple<>,
F16,
ScaleAdd,
ScaleAdd,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -218,7 +296,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, ck::Tuple<half_t, half_t>> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances(
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
op_ptrs);
}
#endif
@@ -227,7 +311,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiDataType, ck::Tuple<ck::bhalf_t, ck::bhalf_t>> &&
is_same_v<OutDataType, ck::bhalf_t> && is_same_v<ComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3(
op_ptrs);
add_device_grouped_conv3d_fwd_wmma_cshufflev3_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
op_ptrs);
}
#endif