Conv:TF32: add more instances - 1 (#2867)

* conv:tf32:add more instances
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* remove gnhwc/ngchw/ngcdhw instances
This commit is contained in:
yinglu
2025-09-25 09:27:18 +08:00
committed by GitHub
parent f076f207ce
commit df97a286d5
92 changed files with 4273 additions and 443 deletions

View File

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -24,6 +24,7 @@ using BF8 = ck::bf8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -205,6 +206,27 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f32_tf32_comp_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, TF32, TF32>
// clang-format on
>;
// double rate mfma instances on gfx950
template <index_t NDimSpatial,
typename ALayout,

View File

@@ -253,6 +253,25 @@ using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f32_tf32_generic_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -99,6 +100,27 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -24,6 +24,7 @@ using BF8 = ck::bf8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -64,7 +65,7 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple<
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
// Latency friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
@@ -163,6 +164,41 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_f32_tf32_mem_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>,
// Memory friendly
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -142,6 +143,27 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec,
typename DsDataTypes = Tuple<>,
typename OutElementOp = PassThrough>
using device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsDataTypes, F32, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scale_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -16,6 +16,7 @@ namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using TF32 = ck::tf32_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -89,7 +90,7 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple<
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// instances for small conv.K and conv.C
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4>,
@@ -97,6 +98,27 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple<
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, ck::Tuple<>,ELayout, ck::Tuple<F32, F32>, ck::Tuple<F32, F32>, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -428,26 +428,39 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs);
if constexpr(is_same_v<AComputeType, BComputeType> && is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif

View File

@@ -127,24 +127,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
@@ -197,32 +217,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
#endif
}
#endif // CK_USE_XDL

View File

@@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instance
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -508,6 +524,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -522,6 +554,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwg
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -536,6 +584,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_ins
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -550,6 +614,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intr
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -564,6 +644,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inte
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<NHWGK>,
NHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -622,6 +718,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndh
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -636,6 +748,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_nd
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -650,6 +778,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -664,6 +808,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_i
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -678,6 +838,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_i
PassThrough,
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
Tuple<F32>,
F32,
PassThrough,
PassThrough,
AddClamp,
TF32,
TF32>>>& instances);
#endif
} // namespace instance

View File

@@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc
PassThrough,
PassThrough,
Bilinear>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<NDHWGK>,
NDHWGK,
F32,
F32,
ck::Tuple<F32>,
F32,
PassThrough,
PassThrough,
Bilinear,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -137,8 +153,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16

View File

@@ -125,23 +125,44 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
// layout NDHWGC/GKZYXC/NDHWGK
@@ -193,30 +214,42 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
is_same_v<BComputeType, float>)
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float> && is_same_v<AComputeType, TF32> &&
is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}

View File

@@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -508,6 +524,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -522,6 +554,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -536,6 +584,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -550,6 +614,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_ins
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
@@ -564,6 +644,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_ins
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Tuple<>,
NHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -622,6 +718,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -636,6 +748,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -650,6 +778,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_insta
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -664,6 +808,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -678,6 +838,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_
PassThrough,
Clamp>>>& instances);
void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Tuple<>,
NDHWGK,
F32,
F32,
Tuple<>,
F32,
PassThrough,
PassThrough,
Clamp,
TF32,
TF32>>>& instances);
#endif
} // namespace instance

View File

@@ -111,6 +111,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -281,6 +296,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW

View File

@@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -169,6 +184,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW

View File

@@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -169,6 +184,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW

View File

@@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
Scale>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
ck::Tuple<>,
NDHWGK,
F32,
F32,
ck::Tuple<>,
F32,
PassThrough,
PassThrough,
Scale,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -137,7 +153,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16

View File

@@ -211,6 +211,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8

View File

@@ -55,6 +55,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
@@ -120,6 +136,22 @@ void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_ins
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
} // namespace instance

View File

@@ -84,6 +84,22 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NGCHW,
@@ -176,6 +192,22 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough,
TF32,
TF32>>>& instances);
#endif
// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW