diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp index 48c9f10312..daf221bb5a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp @@ -37,77 +37,89 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; -template -using device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances = std::tuple< +template +using device_gemm_wmma_universal_km_kn_mn_comp_instances = std::tuple< // clang-format off //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 96, 64, 8, 8, 16, 16, 3, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 128, 8, 8, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 48, 96, 192, 8, 8, 16, 16, 3, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 96, 128, 64, 8, 8, 16, 16, 6, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 16, 16, 4, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 16, 16, 8, 4, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 16, 16, 8, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 16, 16, 4, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 16, 16, 4, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on >; -using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances = std::tuple< +template +using device_gemm_wmma_universal_km_kn_mn_mem_instances = std::tuple< // clang-format off - //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| - //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| - //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - // DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16 - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 4, 4, 16, 16, 1, 1, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 8, 1, 4>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 4, 4, 16, 16, 1, 1, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 8, 1, 8>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // Memory friendly + // TODO: add once v2 is implemented // clang-format on >; -using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances = std::tuple< +template +using device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances = std::tuple< // clang-format off - //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| - //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| - //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16 - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, S<2,2,2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 4, 4, 16, 16, 1, 1, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 8, 1, 4>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 4, 4, 16, 16, 1, 1, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 8, 1, 8>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, S<4,4,4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // Memory friendly + // TODO: add once v2 is implemented // clang-format on >; -template +template +using device_gemm_wmma_universal_km_kn_mn_odd_n_instances = std::tuple< + // clang-format off + //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| + //#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 4, 4, 16, 16, 1, 1, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 4>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 4, 4, 16, 16, 1, 1, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // Memory friendly + // TODO: add once v2 is implemented + // clang-format on + >; + +template using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple< // clang-format off //#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm| @@ -115,18 +127,11 @@ using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tupl //#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Latency friendly - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 128, 8, 8, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 48, 128, 8, 8, 16, 16, 1, 3, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 16, 16, 1, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 128, 8, 8, 16, 16, 1, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 1, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 128, 8, 8, 16, 16, 1, 6, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 192, 32, 8, 8, 16, 16, 1, 12, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 256, 96, 64, 8, 8, 16, 16, 2, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 4, 4, 16, 16, 1, 1, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S< 8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 4>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 4, 4, 16, 16, 1, 1, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S< 8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 8, 1, 8>, S<1,1,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> // Memory friendly // TODO: add once v2 is implemented // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp index b77c7348db..0a1167b1bd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp @@ -45,20 +45,68 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instance //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | - DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1> - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1> - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, // Incorrect results for at least GemmDefault - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> // Incorrect results for at least GemmDefault + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + // clang-format on >; +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_part2_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 2, 1, 16>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 2, 1, 16>, 1, Scheduler, PipelineVersion, 4> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_irregular_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 48, 64, 32, 8, 16, 16, 3, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 48, 32, 8, 16, 16, 2, 3, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 208, 32, 8, 16, 16, 2, 13, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 13, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // clang-format on + >; + template , S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1> - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, - // DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> // clang-format on >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_part2_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 2, 1, 16>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 2, 1, 16>, 1, Scheduler, PipelineVersion, 4> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_irregular_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 48, 64, 32, 8, 16, 16, 3, 2, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 48, 32, 8, 16, 16, 2, 3, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 3, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 208, 32, 8, 16, 16, 2, 13, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 13, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_generic_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> + // clang-format on + >; + +// NGCHW requires transpose, we use vector loads and stores params for them +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_instances = std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 1, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, F16, F16, 2, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_part2_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, F16, F16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 4>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_generic_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 1, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 1, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 1, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 1, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2, BF16, BF16, 2, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_part2_instances = + std::tuple< + // clang-format off + //################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| | + //################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 8, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 1, 4>, + + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 1, 4> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp index 761b07ea60..531a0408b0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp @@ -45,25 +45,22 @@ template + BlockGemmPipelineScheduler Scheduler, + BlockGemmPipelineVersion PipelineVersion> using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion> - // DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, // Incorrect results for at least GemmDefault - // DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> // Incorrect results for at least GemmDefault + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> // clang-format on >; @@ -72,25 +69,22 @@ template + BlockGemmPipelineScheduler Scheduler, + BlockGemmPipelineVersion PipelineVersion> using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, - DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 16, 16, 2, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 2, 5, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 2, 7, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion> //clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp index f254628f73..9f8a315c59 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp @@ -45,21 +45,27 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances = std //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for f16 - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for f16 + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -75,19 +81,24 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances = st //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - // other instances - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // instance for small conv.K + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple, BF16, F32, BF16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp new file mode 100644 index 0000000000..46e9f34988 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -0,0 +1,182 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + // // instance for small conv.K + // // for fp16 conv.K and conv.C must be divisible by 2 + // // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + // instance for small conv.K + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | + // generic instance + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + // instance for small conv.K + // for bf16 conv.K and conv.C must be divisible by 2 + // since half_t atomic_add require scalar_per_x_vector % 2 == 0 + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector>, + DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, TransposeTransferSrcScalarPerVector, TransposeTransferDstScalarPerVector> + + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp index e893c92d1d..c2b752188f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp @@ -45,20 +45,26 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances = std::t //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 // since half_t atomic_add require scalar_per_x_vector % 2 == 0 - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for fp16 - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for fp16 + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 1, true, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -74,19 +80,24 @@ using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances = std:: //#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version | //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | // generic instance - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - // other instances - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure - // DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + // instance for small conv.K + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 1, true, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 256, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 64, 128, 32, 8, 16, 16, 2, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 32, 128, 32, 8, 16, 16, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 1, true, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 16, 16, 2, 2, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 2, true, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 4, true, 1, 1, S<1, 16, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 6dd8758eb7..573a42e743 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -857,8 +857,60 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } if constexpr(NumDimSpatial == 2) { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } if constexpr(is_same_v && is_same_v && is_same_v) { @@ -869,15 +921,102 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instances( + op_ptrs); + + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instances( + op_ptrs); + // Explicit GEMM + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -887,16 +1026,35 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( + add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instances( op_ptrs); - add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instances( op_ptrs); - // Explicit GEMM - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instances( op_ptrs); - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instances( op_ptrs); - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instances( op_ptrs); } #endif @@ -904,6 +1062,29 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } if constexpr(is_same_v && is_same_v && is_same_v) { @@ -914,15 +1095,102 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); + } + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instances( + op_ptrs); + + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instances( + op_ptrs); + // Explicit GEMM + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + op_ptrs); + add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -932,16 +1200,35 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( op_ptrs); - // Explicit GEMM - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instances( op_ptrs); - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( op_ptrs); - add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( op_ptrs); } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc index d7fefde5cd..372d6f2223 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_explicit_wmma.inc @@ -22,6 +22,42 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_i PassThrough, PassThrough>>>& instances); +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances); + void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -73,6 +145,54 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_i PassThrough, PassThrough>>>& instances); +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); + void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances); #endif // 3D @@ -101,6 +245,42 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_i PassThrough, PassThrough>>>& instances); +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances); + void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -152,6 +368,54 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_i PassThrough, PassThrough>>>& instances); +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances); + void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc index 06247019f1..822a46b16f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -8,8 +8,73 @@ namespace tensor_operation { namespace device { namespace instance { +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif + // conv2d backward weight #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instances( + std::vector>>& instances); #endif // conv3d backward weight #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( std::vector>>& instances); -#endif +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( + std::vector>>& instances); +#endif #ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index 56a9d16623..eaf1aa6b30 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -17,4 +17,8 @@ if(DL_KERNELS) dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instance.cpp) endif() +list(APPEND GROUPED_CONV1D_BWD_WEIGHT + wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instance.cpp + wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp) + add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..43d8024180 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instance.cpp new file mode 100644 index 0000000000..724a5f409f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/wmma/device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instance.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_wmma_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances< + 1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index ec9e7da391..bf32812e5a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -73,10 +73,34 @@ if(DL_KERNELS) endif() list(APPEND GROUPED_CONV2D_BWD_WEIGHT + wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp + wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp + wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp + wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instance.cpp + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instance.cpp + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp + wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp + + wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp + wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp + wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instance.cpp ) add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..b2a78136f7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..78b2d7b93d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 0000000000..131c505468 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..365feeafd4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_wmma_gnhwc_gkyxc_gnhwk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..1edaa86214 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..677e03b196 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_part2_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..43039fe11c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..a1e8ab5a1e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_part2_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instance.cpp new file mode 100644 index 0000000000..67db6f5f1f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instance.cpp new file mode 100644 index 0000000000..e4d306c41a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_ngchw_gkcyx_ngkhw_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..e222e8c663 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_generic_instances< + 2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..85d395c7c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/ngchw_gkyxc_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_ngchw_gkyxc_ngkhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_generic_instances< + 2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instance.cpp new file mode 100644 index 0000000000..ef8ba7109d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_irregular_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_irregular_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..c6787c8200 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_part2_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instance.cpp new file mode 100644 index 0000000000..c71f60416e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_irregular_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_irregular_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..35e7dccce4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_part2_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..6768244bb3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..db422cd571 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index adc9de3a3d..78b29101ed 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" namespace ck { namespace tensor_operation { @@ -25,11 +25,19 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances( // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..85ce996a3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..e323cd3d72 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp index f304d1bba4..c4c40f49be 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" namespace ck { namespace tensor_operation { @@ -25,11 +25,19 @@ void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances( // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..1fee43c50d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index b246b87178..2f5b86c9ee 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -69,10 +69,32 @@ if(DL_KERNELS) endif() list(APPEND GROUPED_CONV3D_BWD_WEIGHT + wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instance.cpp wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp + wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instance.cpp + + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp + wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp + + wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp + wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp ) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..b05ca0788f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 0000000000..aa82de070f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_generic_instances< + 3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instance.cpp new file mode 100644 index 0000000000..ed27879639 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_irregular_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_irregular_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..32e3988eb1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_part2_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instance.cpp new file mode 100644 index 0000000000..5e044213ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_irregular_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_irregular_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..d872fa1fc8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_part2_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..828936212c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp new file mode 100644 index 0000000000..98c0b71bef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instance.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_f32_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 728f514f9a..3b08d9f1ad 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" namespace ck { namespace tensor_operation { @@ -25,11 +25,19 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instanc // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..e5784c3885 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp new file mode 100644 index 0000000000..a9126540ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_default_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index f929196ddb..5374dc1be9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // SPDX-License-Identifier: MIT #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" namespace ck { namespace tensor_operation { @@ -25,11 +25,19 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightDefault>{}); + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp new file mode 100644 index 0000000000..09f6e42afb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pad0_pipev1_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..1ddeacbc79 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..49a7ce52b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_part2_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..eaffb670cc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..cb446b5c3a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_part2_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..4aff904534 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..2cfbeb7a96 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_wmma_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 1, + 1>{}); + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_instances<3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + 4, + 4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..39d8ef3091 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_bf16_generic_instances< + 3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..76b6fc9b0d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/wmma/ngcdhw_gkzyxc_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_wmma_c_shuffle_f16_generic_instances< + 3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt index 08f95601f7..7ac458e93c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt @@ -28,10 +28,23 @@ set(GROUPED_CONVND_EXP_BWD_WEIGHT explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp + explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp + explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp ) add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp index 894063e081..4db9c93b32 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instance.cpp @@ -32,7 +32,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_i PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances( @@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_i PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp index a3b16e4216..a4b726e6bf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp @@ -32,7 +32,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpaddin PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances( @@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpaddin PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..c98858ad1e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..d03a824c92 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..f8121658a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..bad4c278a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp new file mode 100644 index 0000000000..7bfaff0a51 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instance.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp index 967e2884f9..98eb2851c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instance.cpp @@ -32,7 +32,9 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instanc PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances( @@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instanc PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp new file mode 100644 index 0000000000..f34d511866 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + BF16, + BF16, + BF16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp index 38e98e719e..4bbccaed72 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instance.cpp @@ -32,7 +32,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_inst PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances( @@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_inst PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp index b0a8998562..0f4d35f4c6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instance.cpp @@ -32,7 +32,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_i PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances( @@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_i PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..0e819b87e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 0000000000..9ec81e7d7b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v1_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..c8a344d343 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instance.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_default_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 0000000000..2f31e05e87 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_mem_v2_mnkpadding_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_mem_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp new file mode 100644 index 0000000000..b042058853 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instance.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_irregular_odd_m_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp index ace411ea68..54485d7298 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instance.cpp @@ -32,7 +32,9 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( @@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances( PassThrough, PassThrough, PassThrough, - device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); + device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp new file mode 100644 index 0000000000..b2f8d4c3e3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_wmma/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instance.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_wmma_universal_km_kn_mn_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_n_instances( + std::vector>>& instances) +{ + add_explicit_gemm_device_operation_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + F16, + F16, + F16, + PassThrough, + PassThrough, + PassThrough, + device_gemm_wmma_universal_km_kn_mn_odd_n_instances>( + instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck