GEMM+Bias+ReLU+Add (#76)

* tweak conv for odd C

* update script

* clean up elementwise op

* fix build

* clean up

* added example for gemm+bias+relu+add

* added example for gemm+bias+relu

* add profiler for gemm_s_shuffle; re-org files

* add profiler

* fix build

* clean up

* clean up

* clean up

* fix build

[ROCm/composable_kernel commit: 823657ed12]
This commit is contained in:
Chao Liu
2022-02-06 22:32:47 -06:00
committed by GitHub
parent 8890cc207d
commit 8efcb80fa5
77 changed files with 3865 additions and 932 deletions

View File

@@ -0,0 +1,149 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_add_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC;
// arbitrary conv
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, stride 1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// Odd C
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple<
// clang-format off
//##############################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##############################################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##############################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& instances)
{
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances{});
add_device_operation_instances(
instances,
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
add_device_operation_instances(
instances,
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
add_device_operation_instances(
instances,
device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_add_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_atomic_add_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto InMemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddRelu>>&
instance_container)
{
using Instances =
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances;
const auto instances = Instances{};
ck::static_for<0, std::tuple_size_v<Instances>, 1>{}([&](auto i) {
using Instance = remove_cvref_t<decltype(std::get<i>(instances))>;
auto instance = Instance{};
instance_container.push_back(std::make_unique<Instance>(instance));
});
}
} // namespace device_conv2d_fwd_bias_activation_atomic_add_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,149 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_bias_activation_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
static constexpr auto MemorySet = ck::InMemoryDataOperationEnum_t::Set;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC;
// arbitrary conv
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, stride 1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// Odd C
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple<
// clang-format off
//##########################################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################################| | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddRelu>>& instances)
{
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances{});
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
add_device_operation_instances(
instances,
device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instances{});
}
} // namespace device_conv2d_fwd_bias_activation_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,144 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::OddC;
// arbitrary conv
using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
// 1x1, stride 1, pad 0
using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std::tuple<
// clang-format off
//##########################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##########################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//##########################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//##########################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances{});
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
add_device_operation_instances(
instances, device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances{});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,109 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,108 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_fwd_instance {
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
static constexpr auto ConvFwd1x1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 =
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
// clang-format on
>;
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple<
// clang-format off
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
// clang-format on
>;
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{});
add_device_operation_instances(instances,
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
}
} // namespace device_conv2d_fwd_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) + c1[m, n]
using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
//#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) + c1[m, n]
using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) + c1[m, n]
using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) + c1[m, n]
using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmBiasActivationAddPtr<PassThrough, PassThrough, AddReluAdd>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n])
using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format off
//#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmBiasActivationPtr<PassThrough, PassThrough, AddRelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n])
using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmBiasActivationPtr<PassThrough, PassThrough, AddRelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n])
using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmBiasActivationPtr<PassThrough, PassThrough, AddRelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n])
using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmBiasActivationPtr<PassThrough, PassThrough, AddRelu>>& instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances =
std::tuple<
// clang-format off
//#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances =
std::tuple<
// clang-format off
//#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,52 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>,
DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,56 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
DeviceGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,56 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,51 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>
>;
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,56 @@
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_xdl_splitk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>,
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>
// clang-format on
>;
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck