mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations
* fixes
* refactor instances
* gfx12 fixes
* temporary disable splitK for gfx12
[ROCm/composable_kernel commit: 5477811670]
This commit is contained in:
@@ -76,6 +76,47 @@ using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f16_optimized_loads_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -178,6 +219,48 @@ using device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple<
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_bf16_optimized_loads_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 64, 16, 16, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 2, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 4, 1, S<4, 4, 16>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -257,6 +340,41 @@ using device_grouped_conv_bwd_data_xdl_f32_16_16_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// A K1 one access for each thread per load
|
||||
// 32x32
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 4, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 16, 1, 16>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 32, 16, 4, 4, 32, 32, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 32>, 1>,
|
||||
// 16x16
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 2, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 1, 1, 1, S<1, 64, 1, 4>, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 4, 1, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 16, 16, 4, 4, 16, 16, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -111,6 +111,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -121,6 +123,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -132,6 +136,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -251,6 +257,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
@@ -271,6 +279,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -282,6 +292,8 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -83,6 +83,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
@@ -112,6 +126,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
@@ -141,6 +169,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instanc
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -393,6 +435,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
@@ -422,6 +478,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
@@ -451,6 +521,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_inst
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_optimized_loads_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
|
||||
Reference in New Issue
Block a user