mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Unify Convolution FWD XDL 1D/2D implementation. (#93)
* Convolution ND
* Code unification across dimensions for generating tensor descriptors.
* Example
* Instances
* Move convnd f32 instance file to comply with repo structure.
* Conv 1D tensor layouts.
* Formatting and use ReferenceConv
* Reference ConvFwd supporting 1D and 2D convolution.
* Debug printing TensorLayout name.
* Conv fwd 1D instance f32
* Refactor conv ND example.
Needed to support various conv dimensio.
Needed to support various conv dimensions
* Rename conv nd example director to prevent conflicts.
* Refactor some common utility to single file.
Plus some tests.
* Refactor GetHostTensorDescriptor + UT.
* Add 1D test case.
* Test reference convolution 1d/2d
* Remove some leftovers.
* Fix convolution example error for 1D
* Refactor test check errors utility function.
* Test Conv2D Fwd XDL
* More UT for 1D case.
* Parameterize input & weight initializers.
* Rename example to prevent conflicts.
* Split convnd instance into separate files for 1d/2d
* Address review comments.
* Fix data type for flops/gbytes calculations.
* Assign example number 11.
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 756a761727]
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#ifndef CK_ELEMENT_WISE_OPERATION_HPP
|
||||
#define CK_ELEMENT_WISE_OPERATION_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
@@ -76,6 +76,11 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv1d_fwd_instance
|
||||
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_fwd_bias_relu_instance
|
||||
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
@@ -96,16 +101,18 @@ add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_S
|
||||
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
|
||||
|
||||
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
@@ -116,6 +123,7 @@ target_compile_features(device_gemm_bias_2d_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_relu_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_relu_add_instance PUBLIC)
|
||||
target_compile_features(device_batched_gemm_instance PUBLIC)
|
||||
target_compile_features(device_conv1d_fwd_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
|
||||
@@ -126,6 +134,7 @@ set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDE
|
||||
set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
@@ -136,7 +145,8 @@ install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
|
||||
|
||||
198
device_operation/include/conv_utils.hpp
Normal file
198
device_operation/include/conv_utils.hpp
Normal file
@@ -0,0 +1,198 @@
|
||||
#ifndef CONV_UTILS_HPP
|
||||
#define CONV_UTILS_HPP
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace conv_util {
|
||||
|
||||
/**
|
||||
* @brief Calculate number of FLOPs for Convolution
|
||||
*
|
||||
* @param[in] N Batch size.
|
||||
* @param[in] C Number of input channels.
|
||||
* @param[in] K Number of output channels.
|
||||
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
|
||||
* @param[in] output_spatial_lengths Convolution output spatial dimensions
|
||||
* lengths.
|
||||
*
|
||||
* @return The number of flops.
|
||||
*/
|
||||
std::size_t GetFlops(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t K,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths)
|
||||
{
|
||||
// 2 * N * K * <output spatial lengths product> * C * <filter spatial lengths product>
|
||||
return static_cast<std::size_t>(2) * N * K *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()) *
|
||||
C *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate number of bytes read/write by convolution algorithm.
|
||||
*
|
||||
* @param[in] N Batch size.
|
||||
* @param[in] C Number of input channels.
|
||||
* @param[in] K Number of output channels.
|
||||
* @param[in] input_spatial_lengths Input spatial dimensions lengths.
|
||||
* @param[in] filter_spatial_lengths Filter spatial dimensions lengths.
|
||||
* @param[in] output_spatial_lengths Output spatial dimensions lengths
|
||||
*
|
||||
* @tparam InDataType Input tensor data type.
|
||||
* @tparam WeiDataType Weights tensor data type.
|
||||
* @tparam OutDataType Output tensor data type.
|
||||
*
|
||||
* @return The number of used bytes.
|
||||
*/
|
||||
template <typename InDataType = float,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType>
|
||||
std::size_t GetBtype(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t K,
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths)
|
||||
{
|
||||
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
|
||||
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
|
||||
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
|
||||
return sizeof(InDataType) * (N * C *
|
||||
std::accumulate(std::begin(input_spatial_lengths),
|
||||
std::end(input_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(WeiDataType) * (K * C *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(OutDataType) * (N * K *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
struct ConvParams
|
||||
{
|
||||
ConvParams()
|
||||
: num_dim_spatial(2),
|
||||
N(128),
|
||||
K(256),
|
||||
C(192),
|
||||
filter_spatial_lengths(2, 3),
|
||||
input_spatial_lengths(2, 71),
|
||||
conv_filter_strides(2, 2),
|
||||
conv_filter_dilations(2, 1),
|
||||
input_left_pads(2, 1),
|
||||
input_right_pads(2, 1)
|
||||
{
|
||||
}
|
||||
|
||||
ck::index_t num_dim_spatial;
|
||||
ck::index_t N;
|
||||
ck::index_t K;
|
||||
ck::index_t C;
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths;
|
||||
std::vector<ck::index_t> input_spatial_lengths;
|
||||
|
||||
std::vector<ck::index_t> conv_filter_strides;
|
||||
std::vector<ck::index_t> conv_filter_dilations;
|
||||
|
||||
std::vector<ck::index_t> input_left_pads;
|
||||
std::vector<ck::index_t> input_right_pads;
|
||||
|
||||
std::vector<ck::index_t> GetOutputSpatialLengths() const
|
||||
{
|
||||
std::vector<ck::index_t> out_spatial_len(num_dim_spatial, 0);
|
||||
for(ck::index_t i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t idx_eff =
|
||||
(filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1;
|
||||
out_spatial_len[i] =
|
||||
(input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) /
|
||||
conv_filter_strides[i] +
|
||||
1;
|
||||
}
|
||||
return out_spatial_len;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Gets the host tensor descriptor.
|
||||
*
|
||||
* @param[in] dims The tensor dimensions lengths. Always in NCHW format.
|
||||
* @param[in] layout The tensor data layout.
|
||||
*
|
||||
* @tparam TensorLayout Layout type.
|
||||
*
|
||||
* @return The host tensor descriptor object.
|
||||
*/
|
||||
template <typename TensorLayout>
|
||||
HostTensorDescriptor GetHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
const TensorLayout& layout)
|
||||
{
|
||||
std::size_t C = dims[1];
|
||||
// 1D
|
||||
if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCW>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCX>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKW>::value)
|
||||
{
|
||||
|
||||
return HostTensorDescriptor(dims, std::vector<std::size_t>({C * dims[2], dims[2], 1}));
|
||||
}
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NWC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KXC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(dims, std::vector<std::size_t>({C * dims[2], 1, C}));
|
||||
}
|
||||
// 2D
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NCHW>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KCYX>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
|
||||
return HostTensorDescriptor(
|
||||
dims, std::vector<std::size_t>{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1});
|
||||
}
|
||||
else if constexpr(std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::KYXC>::value ||
|
||||
std::is_same<TensorLayout, ck::tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(
|
||||
dims, std::vector<std::size_t>{C * dims[2] * dims[3], 1, dims[3] * C, C});
|
||||
}
|
||||
|
||||
std::stringstream err_msg;
|
||||
err_msg << "Unsupported data layout provided: " << layout << "!";
|
||||
throw std::runtime_error(err_msg.str());
|
||||
}
|
||||
|
||||
} // namespace conv_util
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -215,7 +215,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
|
||||
static_assert(ConvForwardSpecialization == -1, "Not implemented!");
|
||||
}
|
||||
else
|
||||
|
||||
@@ -0,0 +1,865 @@
|
||||
#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
//
|
||||
// @brief Device Convolution operation.
|
||||
//
|
||||
// Supports:
|
||||
// @li Inputs with up to 3 spatial dimentions
|
||||
// @li Input tensor in NHWC data format
|
||||
// @li Weight tensor in KYXC data format
|
||||
// @li Output tensor in NHWK data format
|
||||
//
|
||||
// 1D:
|
||||
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
|
||||
// 2D:
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
// 3D:
|
||||
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
|
||||
//
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
|
||||
ck::index_t NumDimSpatial,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr index_t NDimSpatial = NumDimSpatial;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k)
|
||||
{
|
||||
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k));
|
||||
|
||||
// wei_gemmk0_gemmn_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_pass_through_transform(gemm_n)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
static auto
|
||||
GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, ck::index_t gemm_m_pad)
|
||||
{
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n));
|
||||
|
||||
// out_gemmm_gemmn_grid_desc
|
||||
return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad),
|
||||
make_pass_through_transform(gemm_n)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetInputTensorDescriptor(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t gemm_m,
|
||||
ck::index_t gemm_k,
|
||||
ck::index_t gemm_m_pad,
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths,
|
||||
const std::vector<ck::index_t>& conv_filter_strides,
|
||||
const std::vector<ck::index_t>& conv_filter_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads)
|
||||
{
|
||||
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
|
||||
const index_t Wi = input_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Wo))),
|
||||
make_tuple(Sequence<2>{}, Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(gemm_k0),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t X = filter_spatial_lengths[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
|
||||
const auto in_n_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X, C)),
|
||||
make_merge_transform(make_tuple(N, Wo))),
|
||||
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_pass_through_transform(gemm_m)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(gemm_k0),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetInputTensorDescriptor(ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t gemm_m,
|
||||
ck::index_t gemm_k,
|
||||
ck::index_t gemm_m_pad,
|
||||
const std::vector<ck::index_t>& input_spatial_lengths,
|
||||
const std::vector<ck::index_t>& filter_spatial_lengths,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths,
|
||||
const std::vector<ck::index_t>& conv_filter_strides,
|
||||
const std::vector<ck::index_t>& conv_filter_dilations,
|
||||
const std::vector<ck::index_t>& input_left_pads,
|
||||
const std::vector<ck::index_t>& input_right_pads)
|
||||
{
|
||||
const ck::index_t gemm_k0 = gemm_k / GemmK1Number;
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(gemm_k0),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)),
|
||||
make_pass_through_transform(gemm_m)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// in_gemmk0_gemmm_gemmk1_grid_desc
|
||||
return transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(gemm_k0),
|
||||
make_right_pad_transform(gemm_m, gemm_m_pad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static index_t GetGemmMRaw(ck::index_t N,
|
||||
const std::vector<ck::index_t>& output_spatial_lengths)
|
||||
{
|
||||
return N * std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
1,
|
||||
std::multiplies<ck::index_t>());
|
||||
}
|
||||
|
||||
static index_t GetGemmK(ck::index_t C, const std::vector<ck::index_t>& filter_spatial_lengths)
|
||||
{
|
||||
return C * std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
1,
|
||||
std::multiplies<ck::index_t>());
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths);
|
||||
const index_t GemmN = K;
|
||||
const index_t GemmK = GetGemmK(C, filter_spatial_lengths);
|
||||
|
||||
const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw;
|
||||
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
// C = A^T*B
|
||||
// A:
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
GetInputTensorDescriptor<NumDimSpatial>(N,
|
||||
C,
|
||||
GemmMRaw,
|
||||
GemmK,
|
||||
GemmMPad,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
// B:
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK);
|
||||
// C:
|
||||
const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad);
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
for(ck::index_t i = 0; i < NumDimSpatial; ++i)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 &&
|
||||
arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv" << std::to_string(NumDimSpatial)
|
||||
<< "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -12,37 +12,77 @@ namespace gemm {
|
||||
|
||||
struct RowMajor : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "RowMajor";
|
||||
};
|
||||
|
||||
struct ColumnMajor : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "ColumnMajor";
|
||||
};
|
||||
} // namespace gemm
|
||||
|
||||
namespace convolution {
|
||||
|
||||
// 1D Conv
|
||||
struct NWC : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NWC";
|
||||
};
|
||||
|
||||
struct KXC : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "KXC";
|
||||
};
|
||||
|
||||
struct NWK : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NWK";
|
||||
};
|
||||
|
||||
struct NCW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NCW";
|
||||
};
|
||||
|
||||
struct KCX : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "KCX";
|
||||
};
|
||||
|
||||
struct NKW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NKW";
|
||||
};
|
||||
|
||||
// 2D Conv
|
||||
struct NHWC : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NHWC";
|
||||
};
|
||||
|
||||
struct KYXC : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "KYXC";
|
||||
};
|
||||
|
||||
struct NHWK : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NHWK";
|
||||
};
|
||||
|
||||
struct NCHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NCHW";
|
||||
};
|
||||
|
||||
struct KCYX : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "KCYX";
|
||||
};
|
||||
|
||||
struct NKHW : public BaseTensorLayout
|
||||
{
|
||||
static constexpr const char* name = "NKHW";
|
||||
};
|
||||
|
||||
struct NDHWC : public BaseTensorLayout
|
||||
@@ -59,6 +99,15 @@ struct NDHWK : public BaseTensorLayout
|
||||
|
||||
} // namespace convolution
|
||||
|
||||
template <
|
||||
typename Layout,
|
||||
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
|
||||
std::ostream& operator<<(std::ostream& os, const Layout&)
|
||||
{
|
||||
os << Layout::name;
|
||||
return os;
|
||||
}
|
||||
|
||||
} // namespace tensor_layout
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv1d_fwd_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Conv1D
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k]
|
||||
using device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv1d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,6 +1,6 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
@@ -28,67 +28,67 @@ static constexpr auto ConvFwd1x1S1P0 =
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
65
example/11_convnd_fwd_xdl/README.md
Normal file
65
example/11_convnd_fwd_xdl/README.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# Instructions for ```convnd_fwd_xdl``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```convnd_fwd_xdl```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j convnd_fwd_xdl
|
||||
```
|
||||
|
||||
## Run ```convnd_fwd_xdl```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4: N spatial dimensions (default 2)
|
||||
#Following arguments (depending on number of spatial dims):
|
||||
# N, K, C,
|
||||
# <filter spatial dimensions>, (ie Y, X for 2D)
|
||||
# <input image spatial dimensions>, (ie Hi, Wi for 2D)
|
||||
# <strides>, (ie Sy, Sx for 2D)
|
||||
# <dilations>, (ie Dy, Dx for 2D)
|
||||
# <left padding>, (ie LeftPy, LeftPx for 2D)
|
||||
# <right padding>, (ie RightPy, RightPx for 2D)
|
||||
./example/convnd_fwd_xdl 0 1 100
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32)
|
||||
```
|
||||
input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
|
||||
output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
arg.a_grid_desc_k0_m_k1_{432, 165888, 4}
|
||||
arg.b_grid_desc_k0_n_k1_{432, 256, 4}
|
||||
arg.c_grid_desc_m_n_{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 100 times...
|
||||
Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s
|
||||
```
|
||||
379
example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp
Normal file
379
example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp
Normal file
@@ -0,0 +1,379 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "reference_conv_fwd.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
|
||||
using InDataType = float;
|
||||
using WeiDataType = float;
|
||||
using OutDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
|
||||
|
||||
using DeviceConvFwdBasePtr =
|
||||
ck::tensor_operation::device::DeviceConvFwdPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
AccDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
ConvFwdDefault, // ConvForwardSpecialization
|
||||
NumDimSpatial, // NumDimSpatial
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
4, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector
|
||||
4, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockTransferAddExtraN
|
||||
7, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial)
|
||||
{
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 2: {
|
||||
return std::make_unique<DeviceConvNDFwdInstance<2>>();
|
||||
}
|
||||
case 1: {
|
||||
return std::make_unique<DeviceConvNDFwdInstance<1>>();
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PrintUseMsg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
|
||||
<< " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
|
||||
<< " <strides>, (ie Sy, Sx for 2D)\n"
|
||||
<< " <dilations>, (ie Dy, Dx for 2D)\n"
|
||||
<< " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
|
||||
<< " <right padding>, (ie RightPy, RightPx for 2D)\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[])
|
||||
{
|
||||
// (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
int cmdline_nargs = conv_args + 5;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
PrintUseMsg();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::conv_util::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWK{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::KXC{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector<std::size_t>& dims,
|
||||
int num_dim_spatial = 2)
|
||||
{
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 2: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{});
|
||||
}
|
||||
case 1: {
|
||||
return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{});
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::conv_util::ConvParams params;
|
||||
|
||||
if(argc >= 5)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
if(argc >= 6)
|
||||
{
|
||||
params = ParseConvParams(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
weights.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
|
||||
// do GEMM
|
||||
auto conv = GetConvInstance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
|
||||
std::size_t flop = ck::conv_util::GetFlops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
ck::conv_util::GetBtype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output](
|
||||
const auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
check_error(host_output, device_output);
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic
|
||||
set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
|
||||
set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp)
|
||||
set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp)
|
||||
set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp)
|
||||
|
||||
add_executable(gemm_xdl ${GEMM_XDL_SOURCE})
|
||||
add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE})
|
||||
@@ -34,6 +35,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_AT
|
||||
add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE})
|
||||
add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE})
|
||||
add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
|
||||
|
||||
target_link_libraries(gemm_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor)
|
||||
@@ -45,4 +47,4 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor)
|
||||
target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor)
|
||||
|
||||
target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define REFERENCE_CONV_FWD_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <type_traits>
|
||||
#include <sstream>
|
||||
#include "device_base.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -10,21 +11,38 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X]
|
||||
//
|
||||
// @brief Reference implementation for forward convolution.
|
||||
//
|
||||
// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout.
|
||||
// Weights tensor supports KCYX data layout. Output tensor supports
|
||||
// NKHoWo data layout.
|
||||
//
|
||||
// @tparam InDataType Input tensor data type.
|
||||
// @tparam WeiDataType Weights tensor data type.
|
||||
// @tparam OutDataType Output tensor data type.
|
||||
// @tparam InElementwiseOperation Functor for input tensor elementwise
|
||||
// operation.
|
||||
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise
|
||||
// operation.
|
||||
// @tparam NumDimSpatial Number of spatial dimensions.
|
||||
//
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t NumDimSpatial = 2,
|
||||
typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
|
||||
struct ReferenceConvFwd : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<InDataType>& in_n_c_hi_wi,
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
Argument(const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
@@ -32,9 +50,9 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: in_n_c_hi_wi_{in_n_c_hi_wi},
|
||||
wei_k_c_y_x_{wei_k_c_y_x},
|
||||
out_n_k_ho_wo_{out_n_k_ho_wo},
|
||||
: input_{input},
|
||||
weight_{weight},
|
||||
output_{output},
|
||||
conv_strides_{conv_filter_strides},
|
||||
conv_dilations_{conv_filter_dilations},
|
||||
in_left_pads_{input_left_pads},
|
||||
@@ -45,9 +63,9 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<InDataType>& in_n_c_hi_wi_;
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x_;
|
||||
Tensor<OutDataType>& out_n_k_ho_wo_;
|
||||
const Tensor<InDataType>& input_;
|
||||
const Tensor<WeiDataType>& weight_;
|
||||
Tensor<OutDataType>& output_;
|
||||
|
||||
std::vector<index_t> conv_strides_;
|
||||
std::vector<index_t> conv_dilations_;
|
||||
@@ -59,58 +77,98 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
OutElementwiseOperation out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceConvFwd::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
if constexpr(NumDimSpatial == 1)
|
||||
{
|
||||
auto f_ncw = [&](auto n, auto k, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
|
||||
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
|
||||
arg.in_left_pads_[0];
|
||||
for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
|
||||
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
|
||||
arg.in_left_pads_[1];
|
||||
if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
|
||||
int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
|
||||
arg.in_left_pads_[0];
|
||||
if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
|
||||
arg.in_element_op_(v_in,
|
||||
static_cast<const float>(arg.input_(n, c, wi)));
|
||||
arg.wei_element_op_(v_wei,
|
||||
static_cast<const float>(arg.weight_(k, c, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_out;
|
||||
float v_out;
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.output_(n, k, wo) = v_out;
|
||||
};
|
||||
|
||||
arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
make_ParallelTensorFunctor(f_ncw,
|
||||
arg.output_.mDesc.GetLengths()[0],
|
||||
arg.output_.mDesc.GetLengths()[1],
|
||||
arg.output_.mDesc.GetLengths()[2])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.out_n_k_ho_wo_.mDesc.GetLengths()[0],
|
||||
arg.out_n_k_ho_wo_.mDesc.GetLengths()[1],
|
||||
arg.out_n_k_ho_wo_.mDesc.GetLengths()[2],
|
||||
arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
return 0;
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
return 0;
|
||||
for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
|
||||
arg.in_left_pads_[0];
|
||||
for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
|
||||
arg.in_left_pads_[1];
|
||||
if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < arg.input_.mDesc.GetLengths()[3])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
|
||||
arg.wei_element_op_(
|
||||
v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_out;
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
arg.output_.mDesc.GetLengths()[0],
|
||||
arg.output_.mDesc.GetLengths()[1],
|
||||
arg.output_.mDesc.GetLengths()[2],
|
||||
arg.output_.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg, int) override
|
||||
@@ -127,9 +185,9 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
|
||||
const Tensor<WeiDataType>& wei_k_c_y_x,
|
||||
Tensor<OutDataType>& out_n_k_ho_wo,
|
||||
static auto MakeArgument(const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weight,
|
||||
Tensor<OutDataType>& output,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
@@ -138,9 +196,9 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{in_n_c_hi_wi,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
return Argument{input,
|
||||
weight,
|
||||
output,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -10,6 +10,7 @@ include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
${PROJECT_SOURCE_DIR}/reference_operation/include
|
||||
${PROJECT_SOURCE_DIR}/test/include
|
||||
)
|
||||
|
||||
# test_magic_number_division
|
||||
@@ -30,3 +31,17 @@ add_executable(test_split_k ${SPLIT_K_SOURCE})
|
||||
target_link_libraries(test_split_k PRIVATE host_tensor)
|
||||
target_link_libraries(test_split_k PRIVATE device_gemm_instance)
|
||||
|
||||
# test_conv_util
|
||||
set(CONV_UTIL_SOURCE conv_util/main.cpp)
|
||||
add_executable(test_conv_util ${CONV_UTIL_SOURCE})
|
||||
target_link_libraries(test_conv_util PRIVATE host_tensor)
|
||||
|
||||
# test_reference_conv_fwd
|
||||
set(REFERENCE_CONV_FWD_SOURCE reference_conv_fwd/main.cpp)
|
||||
add_executable(test_reference_conv_fwd ${REFERENCE_CONV_FWD_SOURCE})
|
||||
target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor)
|
||||
|
||||
# test_convnd_fwd_xdl
|
||||
set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/main.cpp)
|
||||
add_executable(test_convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
|
||||
target_link_libraries(test_convnd_fwd_xdl PRIVATE host_tensor)
|
||||
|
||||
157
test/conv_util/main.cpp
Normal file
157
test/conv_util/main.cpp
Normal file
@@ -0,0 +1,157 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool cmp_vec(const std::vector<T>& out, const std::vector<T>& ref, const std::string& msg)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
if(out[i] != ref[i])
|
||||
{
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i]
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TestConvParams_GetOutputSpatialLengths()
|
||||
{
|
||||
bool res{true};
|
||||
// -------------------------- default 2D ------------------------------------
|
||||
// input NCHW {128,192,71,71},
|
||||
// weights KCYX {256,192,3,3},
|
||||
// stride {2,2},
|
||||
// dilations {1,1},
|
||||
// padding {{1,1}, {1,1}}
|
||||
ck::conv_util::ConvParams conv_params;
|
||||
std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(out_spatial_len,
|
||||
std::vector<ck::index_t>{36, 36},
|
||||
"Error: ConvParams 2D default constructor.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(
|
||||
out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{2, 2};
|
||||
conv_params.input_left_pads = std::vector<ck::index_t>{2, 2};
|
||||
conv_params.input_right_pads = std::vector<ck::index_t>{2, 2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(out_spatial_len,
|
||||
std::vector<ck::index_t>{37, 37},
|
||||
"Error: ConvParams 2D padding left/right {2,2}.");
|
||||
|
||||
conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(
|
||||
out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{3, 3};
|
||||
conv_params.input_left_pads = std::vector<ck::index_t>{1, 1};
|
||||
conv_params.input_right_pads = std::vector<ck::index_t>{1, 1};
|
||||
conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(out_spatial_len,
|
||||
std::vector<ck::index_t>{23, 23},
|
||||
"Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.");
|
||||
|
||||
// -------------------------- 1D ------------------------------------
|
||||
conv_params.num_dim_spatial = 1;
|
||||
conv_params.filter_spatial_lengths = std::vector<ck::index_t>{3};
|
||||
conv_params.input_spatial_lengths = std::vector<ck::index_t>{71};
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{2};
|
||||
conv_params.conv_filter_dilations = std::vector<ck::index_t>{1};
|
||||
conv_params.input_left_pads = std::vector<ck::index_t>{1};
|
||||
conv_params.input_right_pads = std::vector<ck::index_t>{1};
|
||||
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(
|
||||
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D default constructor.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res =
|
||||
cmp_vec(out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{2};
|
||||
conv_params.input_left_pads = std::vector<ck::index_t>{2};
|
||||
conv_params.input_right_pads = std::vector<ck::index_t>{2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(out_spatial_len,
|
||||
std::vector<ck::index_t>{37},
|
||||
"Error: ConvParams 1D padding left/right {2}.");
|
||||
|
||||
conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(
|
||||
out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}.");
|
||||
|
||||
conv_params.conv_filter_strides = std::vector<ck::index_t>{3};
|
||||
conv_params.input_left_pads = std::vector<ck::index_t>{1};
|
||||
conv_params.input_right_pads = std::vector<ck::index_t>{1};
|
||||
conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
|
||||
out_spatial_len = conv_params.GetOutputSpatialLengths();
|
||||
res = cmp_vec(out_spatial_len,
|
||||
std::vector<ck::index_t>{23},
|
||||
"Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool TestGetHostTensorDescriptor()
|
||||
{
|
||||
bool res{true};
|
||||
namespace tl = ck::tensor_layout::convolution;
|
||||
std::vector<std::size_t> dims{2, 3, 4, 5};
|
||||
HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{});
|
||||
res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!");
|
||||
res =
|
||||
cmp_vec(h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!");
|
||||
|
||||
h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{});
|
||||
res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!");
|
||||
res =
|
||||
cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!");
|
||||
|
||||
dims = std::vector<std::size_t>{2, 3, 4};
|
||||
h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{});
|
||||
res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!");
|
||||
res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!");
|
||||
|
||||
h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{});
|
||||
res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!");
|
||||
res = cmp_vec(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(void)
|
||||
{
|
||||
bool res = TestConvParams_GetOutputSpatialLengths();
|
||||
std::cout << "TestConvParams_GetOutputSpatialLengths ..... " << (res ? "SUCCESS" : "FAILURE")
|
||||
<< std::endl;
|
||||
res = TestGetHostTensorDescriptor();
|
||||
std::cout << "TestGetHostTensorDescriptor ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
return 0;
|
||||
}
|
||||
262
test/convnd_fwd_xdl/main.cpp
Normal file
262
test/convnd_fwd_xdl/main.cpp
Normal file
@@ -0,0 +1,262 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <half.hpp>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "reference_conv_fwd.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "test_util.hpp"
|
||||
|
||||
namespace {
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;
|
||||
|
||||
template <ck::index_t SpatialDims, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
InDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
ConvFwdDefault, // ConvForwardSpecialization
|
||||
SpatialDims, // SptialDims
|
||||
64, // BlockSize
|
||||
16, // MPerBlock
|
||||
16, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
1, // K1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
1, // NXdlPerWave
|
||||
S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
1, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
1, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockTransferAddExtraN
|
||||
7, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
template <typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float,
|
||||
typename InLayout = ck::tensor_layout::convolution::NHWC,
|
||||
typename WeiLayout = ck::tensor_layout::convolution::KYXC,
|
||||
typename OutLayout = ck::tensor_layout::convolution::NHWK>
|
||||
auto GetHostTensors(const ck::conv_util::ConvParams& params)
|
||||
{
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{}));
|
||||
Tensor<WeiDataType> weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{}));
|
||||
Tensor<OutDataType> host_output(
|
||||
ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{}));
|
||||
Tensor<OutDataType> device_output(
|
||||
ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{}));
|
||||
|
||||
std::generate(input.begin(), input.end(), [n = 0]() mutable {
|
||||
return InDataType(n++) * InDataType(0.1f);
|
||||
});
|
||||
std::fill(weights.begin(), weights.end(), WeiDataType(0.5f));
|
||||
std::fill(host_output.begin(), host_output.end(), OutDataType(0.f));
|
||||
std::fill(device_output.begin(), device_output.end(), OutDataType(0.f));
|
||||
|
||||
return std::make_tuple(input, weights, host_output, device_output);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float>
|
||||
void RunReferenceConv(const ck::conv_util::ConvParams& params,
|
||||
const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weights,
|
||||
Tensor<OutDataType>& output)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NDim>();
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float>
|
||||
void RunConv(const ck::conv_util::ConvParams& params,
|
||||
const Tensor<InDataType>& input,
|
||||
const Tensor<WeiDataType>& weights,
|
||||
Tensor<OutDataType>& output)
|
||||
{
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
|
||||
auto conv = DeviceConvNDFwdInstance<NDim, InDataType, WeiDataType, OutDataType>();
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Error! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
invoker.Run(argument);
|
||||
out_device_buf.FromDevice(output.mData.data());
|
||||
}
|
||||
|
||||
bool TestConv2DNHWC()
|
||||
{
|
||||
bool res{true};
|
||||
ck::conv_util::ConvParams params;
|
||||
params.N = 2;
|
||||
params.K = 16;
|
||||
params.C = 4;
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{16, 16};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
|
||||
|
||||
auto host_tensors = GetHostTensors(params);
|
||||
const Tensor<float>& input = std::get<0>(host_tensors);
|
||||
const Tensor<float>& weights = std::get<1>(host_tensors);
|
||||
Tensor<float>& host_output = std::get<2>(host_tensors);
|
||||
Tensor<float>& device_output = std::get<3>(host_tensors);
|
||||
|
||||
RunReferenceConv<2>(params, input, weights, host_output);
|
||||
RunConv<2>(params, input, weights, device_output);
|
||||
res = res &&
|
||||
test_util::check_err(
|
||||
device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool TestConv1DNWC()
|
||||
{
|
||||
bool res{true};
|
||||
ck::conv_util::ConvParams params;
|
||||
params.num_dim_spatial = 1;
|
||||
params.N = 2;
|
||||
params.K = 16;
|
||||
params.C = 4;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{16};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{1};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{1};
|
||||
params.input_left_pads = std::vector<ck::index_t>{1};
|
||||
params.input_right_pads = std::vector<ck::index_t>{1};
|
||||
|
||||
auto host_tensors = GetHostTensors<float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(params);
|
||||
const Tensor<float>& input = std::get<0>(host_tensors);
|
||||
const Tensor<float>& weights = std::get<1>(host_tensors);
|
||||
Tensor<float>& host_output = std::get<2>(host_tensors);
|
||||
Tensor<float>& device_output = std::get<3>(host_tensors);
|
||||
|
||||
RunReferenceConv<1>(params, input, weights, host_output);
|
||||
RunConv<1>(params, input, weights, device_output);
|
||||
res = res &&
|
||||
test_util::check_err(
|
||||
device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int main()
|
||||
{
|
||||
bool res{true};
|
||||
res = TestConv1DNWC();
|
||||
std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
res = TestConv2DNHWC();
|
||||
std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
}
|
||||
84
test/include/test_util.hpp
Normal file
84
test/include/test_util.hpp
Normal file
@@ -0,0 +1,84 @@
|
||||
#ifndef TEST_UTIL_HPP
|
||||
#define TEST_UTIL_HPP
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
namespace test_util {
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_floating_point<T>::value, bool>::type
|
||||
check_err(const std::vector<T>& out,
|
||||
const std::vector<T>& ref,
|
||||
const std::string& msg,
|
||||
T rtol = static_cast<T>(1e-5),
|
||||
T atol = static_cast<T>(1e-8))
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
T err = 0;
|
||||
T max_err = std::numeric_limits<T>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
err = std::abs(out[i] - ref[i]);
|
||||
if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i]))
|
||||
{
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref["
|
||||
<< i << "]: " << out[i] << "!=" << ref[i] << std::endl
|
||||
<< msg << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value, bool>::type check_err(
|
||||
const std::vector<T>& out, const std::vector<T>& ref, const std::string& msg, T = 0, T = 0)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
if(out[i] != ref[i])
|
||||
{
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i]
|
||||
<< std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace test_util
|
||||
|
||||
#endif
|
||||
333
test/reference_conv_fwd/main.cpp
Normal file
333
test/reference_conv_fwd/main.cpp
Normal file
@@ -0,0 +1,333 @@
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <half.hpp>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "reference_conv_fwd.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "test_util.hpp"
|
||||
|
||||
namespace {
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <typename T>
|
||||
struct FillMonotonicSeq
|
||||
{
|
||||
T m_init_value{0};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::iota(first, last, m_init_value);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillConstant
|
||||
{
|
||||
T m_value{0};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::fill(first, last, m_value);
|
||||
}
|
||||
};
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float,
|
||||
typename InLayout = ck::tensor_layout::convolution::NHWC,
|
||||
typename WeiLayout = ck::tensor_layout::convolution::KYXC,
|
||||
typename OutLayout = ck::tensor_layout::convolution::NHWK,
|
||||
typename FillInputOp = FillMonotonicSeq<InDataType>,
|
||||
typename FillWeightsOp = FillConstant<WeiDataType>>
|
||||
Tensor<OutDataType> RunReferenceConv(const ck::conv_util::ConvParams& params,
|
||||
const FillInputOp& fill_input_op = FillInputOp{0},
|
||||
const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f})
|
||||
{
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{}));
|
||||
Tensor<WeiDataType> weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{}));
|
||||
Tensor<OutDataType> host_output(
|
||||
ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{}));
|
||||
|
||||
fill_input_op(input.begin(), input.end());
|
||||
fill_weights_op(weights.begin(), weights.end());
|
||||
std::fill(host_output.begin(), host_output.end(), OutDataType(0.f));
|
||||
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NDim>();
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
return host_output;
|
||||
}
|
||||
|
||||
bool TestConv2DNHWC()
|
||||
{
|
||||
bool res{true};
|
||||
ck::conv_util::ConvParams params;
|
||||
params.N = 1;
|
||||
params.K = 1;
|
||||
params.C = 2;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{6, 6};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{1, 1};
|
||||
params.input_left_pads = std::vector<ck::index_t>{0, 0};
|
||||
params.input_right_pads = std::vector<ck::index_t>{0, 0};
|
||||
|
||||
auto out_tensor = RunReferenceConv<2>(params);
|
||||
std::vector<std::size_t> ref_dims{1, 1, 4, 4};
|
||||
std::vector<float> ref_data{130.5,
|
||||
148.5,
|
||||
166.5,
|
||||
184.5,
|
||||
238.5,
|
||||
256.5,
|
||||
274.5,
|
||||
292.5,
|
||||
346.5,
|
||||
364.5,
|
||||
382.5,
|
||||
400.5,
|
||||
454.5,
|
||||
472.5,
|
||||
490.5,
|
||||
508.5};
|
||||
res = res && test_util::check_err(out_tensor.mDesc.GetLengths(),
|
||||
ref_dims,
|
||||
"Error: wrong output tensor dimensions!");
|
||||
res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!");
|
||||
|
||||
params.N = 1;
|
||||
params.K = 2;
|
||||
params.C = 2;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{12, 12};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{2, 2};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
|
||||
params.input_left_pads = std::vector<ck::index_t>{1, 1};
|
||||
params.input_right_pads = std::vector<ck::index_t>{1, 1};
|
||||
|
||||
out_tensor = RunReferenceConv<2>(params);
|
||||
ref_dims = std::vector<std::size_t>{1, 2, 5, 5};
|
||||
ref_data = std::vector<float>{
|
||||
210., 210., 327., 327., 351., 351., 375., 375., 399., 399.,
|
||||
459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5,
|
||||
747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5,
|
||||
1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5,
|
||||
1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5};
|
||||
res = res && test_util::check_err(out_tensor.mDesc.GetLengths(),
|
||||
ref_dims,
|
||||
"Error: wrong output tensor dimensions!");
|
||||
res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
bool TestConv1DNWC()
|
||||
{
|
||||
bool res{true};
|
||||
ck::conv_util::ConvParams params;
|
||||
params.num_dim_spatial = 1;
|
||||
params.N = 1;
|
||||
params.K = 1;
|
||||
params.C = 2;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{6};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{1};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{1};
|
||||
params.input_left_pads = std::vector<ck::index_t>{0};
|
||||
params.input_right_pads = std::vector<ck::index_t>{0};
|
||||
|
||||
auto out_tensor = RunReferenceConv<1,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(params);
|
||||
std::vector<std::size_t> ref_dims{1, 1, 4};
|
||||
std::vector<float> ref_data{7.5, 13.5, 19.5, 25.5};
|
||||
res = res && test_util::check_err(out_tensor.mDesc.GetLengths(),
|
||||
ref_dims,
|
||||
"Error: wrong output tensor dimensions!");
|
||||
res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!");
|
||||
|
||||
params.num_dim_spatial = 1;
|
||||
params.N = 1;
|
||||
params.K = 2;
|
||||
params.C = 2;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{12};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{2};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{2};
|
||||
params.input_left_pads = std::vector<ck::index_t>{1};
|
||||
params.input_right_pads = std::vector<ck::index_t>{1};
|
||||
|
||||
out_tensor = RunReferenceConv<1,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(params);
|
||||
ref_dims = std::vector<std::size_t>{1, 2, 5};
|
||||
ref_data = std::vector<float>{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5};
|
||||
res = res && test_util::check_err(out_tensor.mDesc.GetLengths(),
|
||||
ref_dims,
|
||||
"Error: wrong output tensor dimensions!");
|
||||
res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!");
|
||||
|
||||
params.num_dim_spatial = 1;
|
||||
params.N = 2;
|
||||
params.K = 16;
|
||||
params.C = 4;
|
||||
params.filter_spatial_lengths = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths = std::vector<ck::index_t>{16};
|
||||
params.conv_filter_strides = std::vector<ck::index_t>{1};
|
||||
params.conv_filter_dilations = std::vector<ck::index_t>{1};
|
||||
params.input_left_pads = std::vector<ck::index_t>{1};
|
||||
params.input_right_pads = std::vector<ck::index_t>{1};
|
||||
|
||||
auto out_tensor2 =
|
||||
RunReferenceConv<1,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::NWK>(params, [](auto first, auto last) {
|
||||
std::generate(first, last, [n = 0]() mutable { return float(n++) * float(0.1f); });
|
||||
});
|
||||
|
||||
ref_dims = std::vector<std::size_t>{2, 16, 16};
|
||||
ref_data = std::vector<float>{
|
||||
1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4,
|
||||
1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4,
|
||||
3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3,
|
||||
3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3,
|
||||
5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7,
|
||||
5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7,
|
||||
8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1,
|
||||
8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1,
|
||||
10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5,
|
||||
10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5,
|
||||
12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001,
|
||||
12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001,
|
||||
15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3,
|
||||
15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3,
|
||||
17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7,
|
||||
17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7,
|
||||
20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1,
|
||||
20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1,
|
||||
22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5,
|
||||
22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5,
|
||||
24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002,
|
||||
24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002,
|
||||
27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001,
|
||||
27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001,
|
||||
29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7,
|
||||
29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7,
|
||||
32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002,
|
||||
32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002,
|
||||
34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5,
|
||||
34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5,
|
||||
23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8,
|
||||
23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8,
|
||||
27., 27., 27., 27., 27., 27., 27., 27.,
|
||||
27., 27., 27., 27., 27., 27., 27., 27.,
|
||||
41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7,
|
||||
41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7,
|
||||
44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002,
|
||||
44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002,
|
||||
46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5,
|
||||
46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5,
|
||||
48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998,
|
||||
48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998,
|
||||
51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3,
|
||||
51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3,
|
||||
53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7,
|
||||
53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7,
|
||||
56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002,
|
||||
56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002,
|
||||
58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5,
|
||||
58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5,
|
||||
60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998,
|
||||
60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998,
|
||||
63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3,
|
||||
63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3,
|
||||
65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7,
|
||||
65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7,
|
||||
68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1,
|
||||
68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1,
|
||||
70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5,
|
||||
70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5,
|
||||
72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9,
|
||||
72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9,
|
||||
49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4,
|
||||
49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4};
|
||||
res = res && test_util::check_err(out_tensor2.mDesc.GetLengths(),
|
||||
ref_dims,
|
||||
"Error: wrong output tensor dimensions!");
|
||||
res = res && test_util::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int main(void)
|
||||
{
|
||||
bool res{true};
|
||||
res = TestConv2DNHWC();
|
||||
std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
res = TestConv1DNWC();
|
||||
std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user