diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp new file mode 100644 index 0000000000..2bfe99b68a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp @@ -0,0 +1,247 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +// BF16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part1 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part2 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part3 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part4 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +// F16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part1 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part2 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part3 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part4 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp new file mode 100644 index 0000000000..92b8c7e509 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp @@ -0,0 +1,247 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +// BF16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part1 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part2 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part3 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> + +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part4 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 256, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +// F16 instances +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part1 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 64, 32, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part2 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 48, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part3 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 96, 64, 8, 8, 16, 16, 4, 3, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part4 = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#ifndef ONE_INSTANCE_PER_LIST + , + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 128, 64, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 64, 64, 32, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp index e15a32a9d7..e3c5bb3722 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Bilinear = ck::tensor_operation::element_wise::Bilinear; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -103,6 +104,120 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instan PassThrough, Bilinear>>>& instances); #endif +#endif // CK_USE_XDL + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +#endif // CK_USE_WMMA template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v && DLayouts::Size() == 1 && is_same_v, NDHWGK>) @@ -194,6 +311,43 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v && + DLayouts::Size() == 1 && is_same_v, NDHWGK>) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + op_ptrs); + } +#endif + } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp index 6c5cedc1a6..317a7a62ff 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp @@ -21,6 +21,7 @@ namespace instance { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Scale = ck::tensor_operation::element_wise::Scale; +#ifdef CK_USE_XDL #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -103,6 +104,120 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances PassThrough, Scale>>>& instances); #endif +#endif + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#endif template && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + op_ptrs); +#endif } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( op_ptrs); +#endif } #endif } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index bd143bc0b9..c196e1f7d8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,12 +1,20 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp) add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp new file mode 100644 index 0000000000..98e4051eeb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp new file mode 100644 index 0000000000..020ea1dca8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp new file mode 100644 index 0000000000..80ff579c6e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp new file mode 100644 index 0000000000..061d1b98ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp new file mode 100644 index 0000000000..022bc3c8d2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp new file mode 100644 index 0000000000..e2c8447c89 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp new file mode 100644 index 0000000000..1d8464d122 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp new file mode 100644 index 0000000000..ee9c227a25 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/wmma/device_grouped_conv3d_fwd_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_bilinear_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 0622b121b5..416b9f6d69 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,12 +1,20 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS -set(GROUPED_CONV3D_FWD_BILINEAR +# ONLY XDL_AND_WMMA_KERNELS +set(GROUPED_CONV3D_FWD_SCALE xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp + wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp) -add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR}) +add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_SCALE}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp new file mode 100644 index 0000000000..38819a0354 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part1.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp new file mode 100644 index 0000000000..9d554dfb24 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part2.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part2( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp new file mode 100644 index 0000000000..6d310f8563 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part3.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part3( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp new file mode 100644 index 0000000000..28512e9b81 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance_part4.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_bf16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp new file mode 100644 index 0000000000..16fedac2e9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part1.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part1<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp new file mode 100644 index 0000000000..47d354af37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part2( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp new file mode 100644 index 0000000000..25e823de91 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part3( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part3<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp new file mode 100644 index 0000000000..50a9abd807 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/wmma/device_grouped_conv3d_fwd_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_wmma_cshufflev3_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_wmma_cshufflev3_scale_f16_instances_part4<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp new file mode 100644 index 0000000000..3f4905c110 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -0,0 +1,316 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "profiler/common.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_fwd_bilinear_impl( + int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const ck::tensor_operation::element_wise::Bilinear& bilinear_op = + ck::tensor_operation::element_wise::Bilinear{}) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::Bilinear; + using CShuffleDataType = float; + + bool pass = true; + + auto f_host_tensor_descriptor = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + auto f_host_tensor_descriptor_packed = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + auto e_host_tensor_descriptor = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + auto d_host_tensor_descriptor = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d_g_n_k_wos_lengths{}; + std::array d_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(f_host_tensor_descriptor.GetLengths(), a_g_n_c_wis_lengths); + copy(f_host_tensor_descriptor.GetStrides(), a_g_n_c_wis_strides); + copy(f_host_tensor_descriptor_packed.GetLengths(), b_g_k_c_xs_lengths); + copy(f_host_tensor_descriptor_packed.GetStrides(), b_g_k_c_xs_strides); + copy(d_host_tensor_descriptor.GetLengths(), d_g_n_k_wos_lengths); + copy(d_host_tensor_descriptor.GetStrides(), d_g_n_k_wos_strides); + copy(e_host_tensor_descriptor.GetLengths(), e_g_n_k_wos_lengths); + copy(e_host_tensor_descriptor.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(f_host_tensor_descriptor); + Tensor weight(f_host_tensor_descriptor_packed); + Tensor d_tensor(d_host_tensor_descriptor); + Tensor c(e_host_tensor_descriptor); + Tensor host_output(e_host_tensor_descriptor); + Tensor device_output(e_host_tensor_descriptor); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "d_tensor: " << d_tensor.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_tensor.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_tensor.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_tensor.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + d_device_buf.ToDevice(d_tensor.mData.data()); + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< + NDimSpatial, + InDataType, + WeiDataType, + CShuffleDataType, + InElementOp, + WeiElementOp, + ck::tensor_operation::element_wise::PassThrough>{}; + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = + ref_conv.MakeArgument(input, + weight, + c, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + ck::tensor_operation::element_wise::PassThrough{}); + + c.SetZero(); + ref_invoker.Run(ref_argument); + + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(c(idx)); + + const auto conv_val = conv_shuffle; + const auto d_val = ck::type_convert(d_tensor(idx)); + + CShuffleDataType out_val{}; + bilinear_op(out_val, conv_val, d_val); + host_output(idx) = ck::type_convert(out_val); + }); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int valids = 0; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + for(std::size_t i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + std::array{static_cast(d_device_buf.GetDeviceBuffer())}, + static_cast(out_device_buf.GetDeviceBuffer()), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 1>{d_g_n_k_wos_lengths}, + std::array, 1>{d_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++valids; + + std::string op_name = op_ptr->GetTypeString(); + + if(do_log) + { + std::cout << "Evaluating [" << i << "] " << op_name << std::endl; + } + + out_device_buf.SetZero(); + auto ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + auto flop = conv_param.GetFlops(); + auto num_btype = conv_param.GetByte() + + sizeof(DDataType) * (conv_param.G_ * conv_param.N_ * conv_param.K_); + + for(std::size_t j = 0; j < conv_param.filter_spatial_lengths_.size(); ++j) + { + num_btype += sizeof(DDataType) * conv_param.output_spatial_lengths_[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + if(do_log) + { + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; + } + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + bool is_valid = ck::utils::check_err(device_output, + host_output, + "Error: Device and Host results do not match!", + get_rtol(), + get_atol()); + + if(!is_valid) + { + pass = false; + } + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "d_tensor: ", d_tensor.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + if(do_log) + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + printf("\033[36mvalids: %d\033[0m\n", valids); + + if(valids > 0) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index 50b97c3bae..acdc937a33 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -5,11 +5,24 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + namespace ck { namespace profiler { @@ -110,8 +123,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, auto scale_out = type_convert( type_convert(2.0f * float(RAND_MAX / 2 - std::rand()) / float(RAND_MAX))); - // initialize out_element_op for each iteration - const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + OutElementOp out_element_op; + if constexpr(std::is_same_v) + { + using Scale = ck::tensor_operation::element_wise::Scale; + out_element_op = OutElementOp{Scale{scale_in}, Scale{scale_wei}, PassThrough{}}; + } + else if constexpr(std::is_same_v) + { + using Scale = ck::tensor_operation::element_wise::Scale; + using Relu = ck::tensor_operation::element_wise::Relu; + out_element_op = OutElementOp{Scale{scale_in}, Scale{scale_wei}, Relu{}}; + } + else if constexpr(std::is_same_v) + { + out_element_op = OutElementOp{scale_out}; + } + else + { + out_element_op = OutElementOp{scale_in, scale_wei, scale_out}; + } std::cout << "scale_in: " << scale_in << std::endl; std::cout << "scale_wei: " << scale_wei << std::endl; @@ -148,13 +180,32 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, c.SetZero(); ref_invoker.Run(ref_argument); - host_output.ForEach([&](auto&, auto idx) { out_element_op(host_output(idx), c(idx)); }); + host_output.ForEach([&](auto&, auto idx) { + if constexpr(std::is_same_v) + { + const auto conv_shuffle = ck::type_convert(c(idx)); + if constexpr(std::is_same_v) + { + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + } + else + { + out_element_op(host_output(idx), conv_shuffle); + } + } + else + { + out_element_op(host_output(idx), c(idx)); + } + }); } std::string best_op_name; float best_avg_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; + int valids = 0; auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { if(op_ptr->IsSupportedArgument(argument_ptr.get())) @@ -163,6 +214,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, out_device_buf.SetZero(); std::string op_name = op_ptr->GetTypeString(); + valids++; auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -199,15 +251,11 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, if(do_log) { - LogRangeAsType(std::cout << "input : ", input.mData, ",") + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") << std::endl; - LogRangeAsType(std::cout << "weight: ", weight.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "host_output : ", host_output.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "device_output: ", device_output.mData, ",") + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") << std::endl; } } @@ -264,9 +312,12 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, run_impl(op_ptr, argument_ptr); } + printf("\033[36mvalids: %d\033[0m\n", valids); + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + return pass; } diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 7330fa0600..e484ff9ef7 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -79,7 +79,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_OPS profile_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_conv_fwd.cpp) - list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR @@ -100,7 +99,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp) list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) @@ -233,8 +234,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) diff --git a/profiler/src/profile_grouped_conv_fwd_bilinear.cpp b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp new file mode 100644 index 0000000000..d4490abe7e --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bilinear.cpp @@ -0,0 +1,186 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bilinear_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bilinear" +#define OP_DESC "Grouped Convolution Forward+Bilinear" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bilinear(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + // Create a Bilinear operation (requires two input tensors) + const auto bilinear_op = ck::tensor_operation::element_wise::Bilinear{}; + + bool pass = ck::profiler::profile_grouped_conv_fwd_bilinear_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, // D layout same as output + OutLayout, + InDataType, + WeiDataType, + OutDataType, // D data type same as output + OutDataType, + AComputeType, + BComputeType, + ck::index_t>(do_verification, init_method, do_log, time_kernel, params, bilinear_op); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bilinear); diff --git a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp index 00b4bd8f13..134ccea45c 100644 --- a/profiler/src/profile_grouped_conv_fwd_outelementop.cpp +++ b/profiler/src/profile_grouped_conv_fwd_outelementop.cpp @@ -18,15 +18,19 @@ enum struct ConvLayout enum struct OutElementOp { ConvScale = 0, - ConvInvScale = 1 + ConvInvScale = 1, + Scale = 2 }; enum struct ConvDataType { - F8_F8_F8 = 0, - BF8_BF8_F8 = 1, - F8_BF8_F8 = 2, - BF8_F8_F8 = 3 + F8_F8_F8 = 0, + BF8_BF8_F8 = 1, + F8_BF8_F8 = 2, + BF8_F8_F8 = 3, + F16_F16_F16 = 4, + BF16_BF16_BF16 = 5, + I8_I8_I8 = 6 }; #define OP_NAME "grouped_conv_fwd_outelementop" @@ -41,8 +45,12 @@ static void print_helper_msg() << " 1: Input bf8, Weight bf8, Output fp8\n" << " 2: Input fp8, Weight bf8, Output fp8\n" << " 3: Input bf8, Weight fp8, Output fp8)\n" + << " 4: Input f16, Weight f16, Output f16)\n" + << " 5: Input bf16, Weight bf16, Output bf16)\n" + << " 6: Input i8, Weight i8, Output i8)\n" << "arg3: element-wise operation (0: ConvScale\n" - << " 1: ConvInvScale)\n" + << " 1: ConvInvScale\n" + << " 2: Scale\n" << "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg5: verification (0: no, 1: yes)\n" @@ -81,8 +89,11 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); - using F8 = ck::f8_t; - using BF8 = ck::bf8_t; + using F8 = ck::f8_t; + using F16 = ck::half_t; + using BF8 = ck::bf8_t; + using BF16 = ck::bhalf_t; + using I8 = int8_t; using GKZYXC = ck::tensor_layout::convolution::GKZYXC; using NDHWGC = ck::tensor_layout::convolution::NDHWGC; @@ -90,6 +101,7 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) using ConvScale = ck::tensor_operation::element_wise::ConvScale; using ConvInvScale = ck::tensor_operation::element_wise::ConvInvscale; + using Scale = ck::tensor_operation::element_wise::Scale; constexpr auto I3 = ck::Number<3>{}; @@ -213,6 +225,32 @@ int grouped_conv_fwd_outelementop(int argc, char* argv[]) F8{}); } } + else if(op == OutElementOp::Scale) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, Scale{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I3, + NDHWGC{}, + GKZYXC{}, + NDHWGK{}, + BF16{}, + BF16{}, + BF16{}, + Scale{}, + BF16{}, + BF16{}); + } + else if(data_type == ConvDataType::I8_I8_I8) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, I8{}, I8{}, I8{}, Scale{}, I8{}, I8{}); + } + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index c7f4f66f58..5e2db1184c 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -4,6 +4,8 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) + add_gtest_executable(test_grouped_convnd_fwd_bilinear test_grouped_convnd_fwd_bilinear.cpp) + target_link_libraries(test_grouped_convnd_fwd_bilinear PRIVATE utility device_grouped_conv3d_fwd_bilinear_instance) add_gtest_executable(test_grouped_convnd_fwd_scaleadd_ab test_grouped_convnd_fwd_scaleadd_ab.cpp) target_link_libraries(test_grouped_convnd_fwd_scaleadd_ab PRIVATE utility device_grouped_conv3d_fwd_scaleadd_ab_instance) diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp new file mode 100644 index 0000000000..1b37f5eb4e --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_bilinear_impl.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using I8 = int8_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +template +class TestGroupedConvndFwdBilinear : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using AComputeType = std::tuple_element_t<3, Tuple>; + using BComputeType = std::tuple_element_t<4, Tuple>; + using InLayout = std::tuple_element_t<5, Tuple>; + using WeiLayout = std::tuple_element_t<6, Tuple>; + using OutLayout = std::tuple_element_t<7, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + + // Create a Bilinear operation (binary element-wise operation) + const auto bilinear_op = ck::tensor_operation::element_wise::Bilinear{}; + + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_bilinear_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, // D layout same as output + OutLayout, + InDataType, + WeiDataType, + OutDataType, // D data type same as output + OutDataType, + AComputeType, + BComputeType, + IndexType>(true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param, + bilinear_op); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdBilinear3d : public TestGroupedConvndFwdBilinear +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdBilinear3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdBilinear3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index de44195f99..e87ef77e6d 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -23,4 +23,7 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp) target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) + + add_gtest_executable(test_grouped_convnd_fwd_scale test_grouped_convnd_fwd_scale.cpp) + target_link_libraries(test_grouped_convnd_fwd_scale PRIVATE utility device_grouped_conv3d_fwd_scale_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp new file mode 100644 index 0000000000..b4179cae62 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -0,0 +1,124 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "profiler/profile_grouped_conv_fwd_outelementop_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +template +class TestGroupedConvndFwdScale : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using WeiDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + using WeiLayout = std::tuple_element_t<4, Tuple>; + using OutLayout = std::tuple_element_t<5, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a") + { + if(std::is_same::value || + std::is_same::value) + { + printf("Skipping FP8 / BF8 tests on CDNA1/2.\n"); + continue; + } + } + pass = pass && ck::profiler::profile_grouped_conv_fwd_outelementop_impl< + NDimSpatial, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + ck::tensor_operation::element_wise::Scale, + InDataType, + InDataType>(true, // do_verification + 1, // init_method: integer value + false, // do_log + true, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; +using CombConvScaleKernelTypes3d = + ::testing::Types, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdScale3d : public TestGroupedConvndFwdScale +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdScale3d, CombConvScaleKernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdScale3d, Test3D) +{ + this->conv_params.clear(); + + this->conv_params.push_back( + {3, 3, 5, 96, 200, {1, 1, 1}, {37, 37, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {2, 2, 2}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {5, 5, 5}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {9, 9, 9}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 32, 32, {1, 1, 1}, {16, 16, 16}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->conv_params.push_back( + {3, 96, 1, 1, 1, {1, 1, 1}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 96, 1, 1, 1, {3, 3, 3}, {120, 40, 20}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + + this->template Run<3>(); +}