mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Restructure gridwise and blockwise GEMM, add tensor contraction and FWD-v4r5 (#36)
* experimenting magic number division
* overhauling fwd-v4r4 to clearly reflect transformation graph
* added fwd-v4r5
* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2
* bug fix and added sanity-check in transform_dynamic_tensor_descriptor
* added conv_driver_v2
[ROCm/composable_kernel commit: 30072aec37]
This commit is contained in:
@@ -16,6 +16,7 @@ install(TARGETS host LIBRARY DESTINATION lib)
|
||||
|
||||
if(DEVICE_BACKEND STREQUAL "AMD")
|
||||
set(CONV_SOURCE src/conv_driver.cpp)
|
||||
set(CONV_V2_SOURCE src/conv_driver_v2.cpp)
|
||||
set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp)
|
||||
elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
|
||||
set(CONV_SOURCE src/conv_driver.cu)
|
||||
@@ -23,7 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
|
||||
endif()
|
||||
|
||||
add_executable(conv_driver ${CONV_SOURCE})
|
||||
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
|
||||
add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
|
||||
|
||||
target_link_libraries(conv_driver PRIVATE host)
|
||||
target_link_libraries(conv_driver_v2 PRIVATE host)
|
||||
target_link_libraries(conv_bwd_data_driver PRIVATE host)
|
||||
|
||||
@@ -2,15 +2,25 @@
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
|
||||
enum ConvTensorLayout
|
||||
{
|
||||
NCHW,
|
||||
NHWC,
|
||||
CHWN,
|
||||
NCHWc,
|
||||
NHWCc
|
||||
};
|
||||
|
||||
template <class InDesc,
|
||||
class WeiDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LowerPads,
|
||||
class UpperPads>
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
|
||||
InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -35,21 +45,69 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t HPadLow = LowerPads{}.Get(I0);
|
||||
constexpr index_t WPadLow = LowerPads{}.Get(I1);
|
||||
constexpr index_t LeftPadH = LeftPads{}.Get(I0);
|
||||
constexpr index_t LeftPadW = LeftPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HPadUp = UpperPads{}.Get(I0);
|
||||
constexpr index_t WPadUp = UpperPads{}.Get(I1);
|
||||
constexpr index_t RightPadH = RightPads{}.Get(I0);
|
||||
constexpr index_t RightPadW = RightPads{}.Get(I1);
|
||||
|
||||
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
|
||||
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
|
||||
|
||||
constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
|
||||
constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;
|
||||
|
||||
return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
}
|
||||
|
||||
template <typename... InDesc,
|
||||
typename... WeiDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
const ck::DynamicTensorDescriptor<InDesc...>& in_desc,
|
||||
const ck::DynamicTensorDescriptor<WeiDesc...>& wei_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations conv_dilations,
|
||||
const LeftPads& left_pads,
|
||||
const RightPads& right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
assert(in_desc.GetNumOfDimension() == 4);
|
||||
assert(wei_desc.GetNumOfDimension() == 4);
|
||||
assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1));
|
||||
|
||||
const auto N = in_desc.GetLength(I0);
|
||||
const auto Hi = in_desc.GetLength(I2);
|
||||
const auto Wi = in_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_desc.GetLength(I0);
|
||||
const auto Y = wei_desc.GetLength(I2);
|
||||
const auto X = wei_desc.GetLength(I3);
|
||||
|
||||
const auto LeftPadH = left_pads[I0];
|
||||
const auto LeftPadW = left_pads[I1];
|
||||
|
||||
const auto RightPadH = right_pads[I0];
|
||||
const auto RightPadW = right_pads[I1];
|
||||
|
||||
const auto YEff = (Y - I1) * conv_dilations[I0] + I1;
|
||||
const auto XEff = (X - I1) * conv_dilations[I1] + I1;
|
||||
|
||||
const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
|
||||
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
|
||||
|
||||
return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
||||
}
|
||||
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t
|
||||
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
||||
|
||||
@@ -2,30 +2,29 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_gemm_v1.hpp"
|
||||
#include "driver_dynamic_gemm_v1r2.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
InDesc,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -50,505 +49,155 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
#if 1
|
||||
// run-time variables
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(InDesc::GetLengths()));
|
||||
const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(WeiDesc::GetLengths()));
|
||||
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(OutDesc::GetLengths()));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize 64, 16x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize 64, 16x256x2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b thread copy 4x1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// b thread copy 2x2
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
// GemmBBlockCopySrcDataPerRead_GemmN = 4
|
||||
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
|
||||
#elif 0
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||
const auto in_gemmk_gemmn_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = launch_kernel_dynamic_gemm_v1<
|
||||
float ave_time = driver_dynamic_gemm_v1r2<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
decltype(wei_gemmk_gemmm_grid_desc),
|
||||
decltype(in_gemmk_gemmn_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
GemmABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
|
||||
0, // ABlockTransferSrcVectorDim
|
||||
GemmABlockTransferSrcScalarPerVector_K,
|
||||
GemmABlockTransferDstScalarPerVector_M1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
GemmBBlockTransferSrcScalarPerVector_N1,
|
||||
GemmBBlockTransferDstScalarPerVector_N1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
|
||||
@@ -2,30 +2,29 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_v1.hpp"
|
||||
#include "driver_dynamic_gemm_v1r2.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
InDesc,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
@@ -42,73 +41,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 1
|
||||
// run-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
|
||||
constexpr auto wei_k_y_x_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K));
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
|
||||
constexpr auto wei_k_y_x_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
Tensor<TInWei> in_n_hi_wi_c(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
|
||||
Tensor<TInWei> wei_k_y_x_c(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
|
||||
Tensor<TOut> out_n_ho_wo_k(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
|
||||
|
||||
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
|
||||
out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
|
||||
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
@@ -117,357 +49,472 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 2;
|
||||
constexpr index_t GemmN1PerThreadN111 = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 64, 16x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 2;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x2
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 1;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 1;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 16x256x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 16;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 1;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x4
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
constexpr index_t GemmMPerBlockM1 = 32;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 32x256x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 32;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlockM1 = 32;
|
||||
constexpr index_t GemmNPerBlockN1 = 256;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
#if 0
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
#endif
|
||||
<GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_y_x_c0_desc,
|
||||
in_n_hi_wi_c0_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
#else
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
#endif
|
||||
|
||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||
const auto in_gemmk_gemmn_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = launch_kernel_dynamic_gemm_v1<
|
||||
float ave_time = driver_dynamic_gemm_v1r2<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
decltype(wei_gemmk_gemmm_grid_desc),
|
||||
decltype(in_gemmk_gemmn_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
GemmABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder
|
||||
0, // ABlockTransferSrcVectorDim
|
||||
GemmABlockTransferSrcScalarPerVector_K,
|
||||
GemmABlockTransferDstScalarPerVector_M1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
1,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmM1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder
|
||||
0, // BBlockTransferSrcVectorDim
|
||||
GemmBBlockTransferSrcScalarPerVector_K,
|
||||
GemmBBlockTransferDstScalarPerVector_N1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
2, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_M11,
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_contraction_v1r1.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t N0 = 4;
|
||||
|
||||
constexpr index_t GemmGM1PerBlockGM11 = 128;
|
||||
constexpr index_t GemmGN1PerBlockGN11 = 32;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t N0 = 8;
|
||||
|
||||
constexpr index_t GemmGM1PerBlockGM11 = 128;
|
||||
constexpr index_t GemmGN1PerBlockGN11 = 16;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
const auto wei_gk_gm0_gm1_grid_desc = descs[I0];
|
||||
const auto in_gk_gn0_gn1_grid_desc = descs[I1];
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_contraction_v1r1<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gk_gm0_gm1_grid_desc),
|
||||
decltype(in_gk_gn0_gn1_grid_desc),
|
||||
decltype(out_gm0_gm1_gn0_gn1_grid_desc),
|
||||
GemmGM1PerBlockGM11,
|
||||
GemmGN1PerBlockGN11,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
|
||||
GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
|
||||
Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder
|
||||
0, // ABlockTransferSrcVectorDim
|
||||
GemmABlockTransferSrcScalarPerVector_GK,
|
||||
GemmABlockTransferDstScalarPerVector_GM11,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
|
||||
GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
|
||||
Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
3, // BBlockTransferSrcVectorDim
|
||||
GemmBBlockTransferSrcScalarPerVector_GN11,
|
||||
GemmBBlockTransferDstScalarPerVector_GN11,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_BN1,
|
||||
decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks),
|
||||
decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks),
|
||||
decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks),
|
||||
decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gk_gm0_gm1_grid_desc,
|
||||
in_gk_gn0_gn1_grid_desc,
|
||||
out_gm0_gm1_gn0_gn1_grid_desc,
|
||||
wei_gk_gm0_gm10_gm11_grid_iterator_hacks,
|
||||
in_gk_gn0_gn10_gn11_grid_iterator_hacks,
|
||||
out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks,
|
||||
wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks,
|
||||
in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -4,97 +4,64 @@
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
template <typename TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
class TAcc,
|
||||
class TOut,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
InDesc,
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
WeiDesc,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
OutDesc,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
|
||||
<< std::endl;
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
const auto N = out_n_k_ho_wo_lengths[I0];
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto C = wei_k_c_y_x_lengths[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
const auto Hi = in_n_c_hi_wi_lengths[I2];
|
||||
const auto Wi = in_n_c_hi_wi_lengths[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
const auto C0 = C / Number<InWeiVectorSize>{};
|
||||
const auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
constexpr auto K0 = K / Number<InWeiVectorSize>{};
|
||||
constexpr auto K1 = Number<InWeiVectorSize>{};
|
||||
const auto K0 = K / Number<InWeiVectorSize>{};
|
||||
const auto K1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 0
|
||||
// run-time variables
|
||||
const auto in_n_c0_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1));
|
||||
|
||||
const auto conv_strides = to_multi_index(ConvStrides{});
|
||||
const auto conv_dilations = to_multi_index(ConvDilations{});
|
||||
const auto in_left_pads = to_multi_index(InLeftPads{});
|
||||
const auto in_right_pads = to_multi_index(InRightPads{});
|
||||
#else
|
||||
// compile-time variables
|
||||
const auto in_n_c0_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));
|
||||
|
||||
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
@@ -109,17 +76,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
const auto in_n_c0_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = K;
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = C0;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
@@ -134,7 +114,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#else
|
||||
@@ -165,17 +145,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
|
||||
#endif
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type, TAcc, TOut, KPerBlock,
|
||||
HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread,
|
||||
EPerThread, ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W > {};
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
@@ -185,12 +176,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
|
||||
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) =
|
||||
|
||||
@@ -6,58 +6,94 @@ template <class TIn,
|
||||
class TOut,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LowerPads,
|
||||
class UpperPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f = [&](auto n, auto k, auto ho, auto wo) {
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
|
||||
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
|
||||
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_nchw.mDesc.GetLengths()[3])
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in_nchw(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei_kcyx(k, c, y, x));
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out_nkhw(n, k, ho, wo) = v;
|
||||
out(n, k, ho, wo) = v;
|
||||
};
|
||||
|
||||
auto f_par = make_ParallelTensorFunctor(f,
|
||||
out_nkhw.mDesc.GetLengths()[0],
|
||||
out_nkhw.mDesc.GetLengths()[1],
|
||||
out_nkhw.mDesc.GetLengths()[2],
|
||||
out_nkhw.mDesc.GetLengths()[3]);
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
f_par(std::thread::hardware_concurrency());
|
||||
switch(layout)
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
break;
|
||||
default: throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
|
||||
template <class TIn, class TWei, class TOut, class InLeftPads, class InRightPads>
|
||||
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
InLeftPads,
|
||||
InRightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -76,8 +112,8 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
|
||||
@@ -271,19 +271,20 @@ struct Tensor
|
||||
std::vector<T> mData;
|
||||
};
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout)
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}" << std::endl;
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
|
||||
|
||||
template <class T>
|
||||
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
|
||||
@@ -44,7 +44,7 @@ struct GeneratorTensor_Checkboard
|
||||
template <class... Ts>
|
||||
double operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{Xs...}};
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
|
||||
@@ -14,19 +14,41 @@
|
||||
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
if(argc != 5)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const bool do_verification = atoi(argv[1]);
|
||||
const int init_method = atoi(argv[2]);
|
||||
const bool do_log = atoi(argv[3]);
|
||||
const int nrepeat = atoi(argv[4]);
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t Hi = 4;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t Hi = 540;
|
||||
constexpr index_t Wi = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -34,13 +56,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t Hi = 270;
|
||||
constexpr index_t Wi = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -48,27 +70,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
constexpr index_t WI = 1920;
|
||||
constexpr index_t Hi = 1080;
|
||||
constexpr index_t Wi = 1920;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -76,13 +84,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 1;
|
||||
constexpr index_t HI = 1024;
|
||||
constexpr index_t WI = 2048;
|
||||
constexpr index_t Hi = 1024;
|
||||
constexpr index_t Wi = 2048;
|
||||
constexpr index_t K = 4;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -90,13 +98,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 540;
|
||||
constexpr index_t WI = 960;
|
||||
constexpr index_t Hi = 540;
|
||||
constexpr index_t Wi = 960;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -104,13 +112,13 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 270;
|
||||
constexpr index_t WI = 480;
|
||||
constexpr index_t Hi = 270;
|
||||
constexpr index_t Wi = 480;
|
||||
constexpr index_t K = 16;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -118,14 +126,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 36x36, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 37;
|
||||
constexpr index_t WI = 37;
|
||||
constexpr index_t Hi = 37;
|
||||
constexpr index_t Wi = 37;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -133,14 +141,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -148,14 +156,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -163,14 +171,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -178,14 +186,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 160;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -193,14 +201,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 96;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -208,14 +216,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 71x71
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 71;
|
||||
constexpr index_t WI = 71;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 192;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -223,14 +231,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 7x1, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
@@ -238,14 +246,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
using InLeftPads = Sequence<3, 0>;
|
||||
using InRightPads = Sequence<3, 0>;
|
||||
#elif 1
|
||||
// 1x7, 17x17
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
@@ -253,14 +261,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
using InLeftPads = Sequence<0, 3>;
|
||||
using InRightPads = Sequence<0, 3>;
|
||||
#elif 0
|
||||
// 3x3, 299x299 stride=2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 3;
|
||||
constexpr index_t HI = 299;
|
||||
constexpr index_t WI = 299;
|
||||
constexpr index_t Hi = 299;
|
||||
constexpr index_t Wi = 299;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -268,14 +276,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t Hi = 147;
|
||||
constexpr index_t Wi = 147;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -283,14 +291,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 3x3, 149x149
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 32;
|
||||
constexpr index_t HI = 149;
|
||||
constexpr index_t WI = 149;
|
||||
constexpr index_t Hi = 149;
|
||||
constexpr index_t Wi = 149;
|
||||
constexpr index_t K = 32;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -298,14 +306,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 17x17, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 192;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -313,14 +321,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 35x35
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 384;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -328,14 +336,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 35x35, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 288;
|
||||
constexpr index_t HI = 35;
|
||||
constexpr index_t WI = 35;
|
||||
constexpr index_t Hi = 35;
|
||||
constexpr index_t Wi = 35;
|
||||
constexpr index_t K = 384;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -343,14 +351,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x3, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 384;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 448;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 3;
|
||||
@@ -358,14 +366,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 1>;
|
||||
using RightPads = Sequence<0, 1>;
|
||||
using InLeftPads = Sequence<0, 1>;
|
||||
using InRightPads = Sequence<0, 1>;
|
||||
#elif 0
|
||||
// 3x1, 8x8
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 448;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t Hi = 8;
|
||||
constexpr index_t Wi = 8;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 1;
|
||||
@@ -373,14 +381,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 0>;
|
||||
using RightPads = Sequence<1, 0>;
|
||||
using InLeftPads = Sequence<1, 0>;
|
||||
using InRightPads = Sequence<1, 0>;
|
||||
#elif 0
|
||||
// 3x3, 147x147
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 147;
|
||||
constexpr index_t WI = 147;
|
||||
constexpr index_t Hi = 147;
|
||||
constexpr index_t Wi = 147;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -388,14 +396,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x1, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 1;
|
||||
@@ -403,14 +411,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
using InLeftPads = Sequence<3, 0>;
|
||||
using InRightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
// 3x3, 73x73
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 73;
|
||||
constexpr index_t WI = 73;
|
||||
constexpr index_t Hi = 73;
|
||||
constexpr index_t Wi = 73;
|
||||
constexpr index_t K = 96;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -418,14 +426,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -433,14 +441,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -448,14 +456,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 14x14, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -463,14 +471,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 3x3, 28x28
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -478,14 +486,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 1
|
||||
// 3x3, 14x14
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 14;
|
||||
constexpr index_t WI = 14;
|
||||
constexpr index_t Hi = 14;
|
||||
constexpr index_t Wi = 14;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -493,14 +501,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -508,14 +516,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 7x7, 230x230 stride=2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 3;
|
||||
constexpr index_t HI = 230;
|
||||
constexpr index_t WI = 230;
|
||||
constexpr index_t Hi = 230;
|
||||
constexpr index_t Wi = 230;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 7;
|
||||
constexpr index_t X = 7;
|
||||
@@ -523,14 +531,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride = 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -538,14 +546,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 1x1, 28x28, stride 2
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t Hi = 28;
|
||||
constexpr index_t Wi = 28;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -553,14 +561,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
// 1x1, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 2048;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -568,14 +576,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 7x7
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 512;
|
||||
constexpr index_t HI = 7;
|
||||
constexpr index_t WI = 7;
|
||||
constexpr index_t Hi = 7;
|
||||
constexpr index_t Wi = 7;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -583,14 +591,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#elif 0
|
||||
// 1x1, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
@@ -598,14 +606,14 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using InLeftPads = Sequence<0, 0>;
|
||||
using InRightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t Hi = 56;
|
||||
constexpr index_t Wi = 56;
|
||||
constexpr index_t K = 64;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -613,82 +621,86 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<1, 1>;
|
||||
using InLeftPads = Sequence<1, 1>;
|
||||
using InRightPads = Sequence<1, 1>;
|
||||
#endif
|
||||
|
||||
auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
|
||||
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});
|
||||
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
|
||||
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
|
||||
|
||||
ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
print_array("LeftPads", to_multi_index(LeftPads{}));
|
||||
print_array("RightPads", to_multi_index(RightPads{}));
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1;
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = typename vector_type<float, in_vector_size>::type;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = typename vector_type<float, in_vector_size>::type;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = int8_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using in_data_t = typename vector_type<int8_t, in_vector_size>::type;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
|
||||
Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
|
||||
Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
|
||||
Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));
|
||||
Tensor<in_data_t> in_nchw(HostTensorDescriptor(std::initializer_list<index_t>{N, C, Hi, Wi}));
|
||||
Tensor<in_data_t> wei_kcyx(HostTensorDescriptor(std::initializer_list<index_t>{K, C, Y, X}));
|
||||
Tensor<out_data_t> out_nkhw_host(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
|
||||
Tensor<out_data_t> out_nkhw_device(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K, Ho, Wo}));
|
||||
|
||||
ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: ");
|
||||
ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: ");
|
||||
|
||||
print_array("InLeftPads", InLeftPads{});
|
||||
print_array("InRightPads", InRightPads{});
|
||||
print_array("ConvStrides", ConvStrides{});
|
||||
print_array("ConvDilations", ConvDilations{});
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(argc != 4)
|
||||
{
|
||||
printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool do_verification = atoi(argv[1]);
|
||||
bool do_log = atoi(argv[2]);
|
||||
index_t nrepeat = atoi(argv[3]);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
|
||||
#endif
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, Hi, Wi>{});
|
||||
constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
|
||||
constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
|
||||
#if 1
|
||||
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
@@ -697,8 +709,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
@@ -709,8 +721,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc,
|
||||
@@ -721,58 +733,9 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
InLeftPads{},
|
||||
InRightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
@@ -782,8 +745,8 @@ int main(int argc, char* argv[])
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{});
|
||||
InLeftPads{},
|
||||
InRightPads{});
|
||||
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
|
||||
410
driver/src/conv_driver_v2.cpp
Normal file
410
driver/src/conv_driver_v2.cpp
Normal file
@@ -0,0 +1,410 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
#define USE_DYNAMIC_MODE 1
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_NHWC 1
|
||||
#define USE_CONV_FWD_V4R5_NCHW 1
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW,
|
||||
V4R4NHWC,
|
||||
V4R5NCHW,
|
||||
V5R1NCHW
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_DYNAMIC_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
||||
const bool do_verification = atoi(argv[3]);
|
||||
const int init_method = atoi(argv[4]);
|
||||
const bool do_log = atoi(argv[5]);
|
||||
const int nrepeat = atoi(argv[6]);
|
||||
|
||||
const index_t N = atoi(argv[7]);
|
||||
const index_t K = atoi(argv[8]);
|
||||
const index_t C = atoi(argv[9]);
|
||||
const index_t Y = atoi(argv[10]);
|
||||
const index_t X = atoi(argv[11]);
|
||||
const index_t Hi = atoi(argv[12]);
|
||||
const index_t Wi = atoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = atoi(argv[14]);
|
||||
const index_t conv_stride_w = atoi(argv[15]);
|
||||
const index_t conv_dilation_h = atoi(argv[16]);
|
||||
const index_t conv_dilation_w = atoi(argv[17]);
|
||||
const index_t in_left_pad_h = atoi(argv[18]);
|
||||
const index_t in_left_pad_w = atoi(argv[19]);
|
||||
const index_t in_right_pad_h = atoi(argv[20]);
|
||||
const index_t in_right_pad_w = atoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
||||
const bool do_verification = atoi(argv[3]);
|
||||
const int init_method = atoi(argv[4]);
|
||||
const bool do_log = atoi(argv[5]);
|
||||
const int nrepeat = atoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t Hi = 17;
|
||||
constexpr index_t Wi = 17;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
const index_t conv_stride_h = 1;
|
||||
const index_t conv_stride_w = 1;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 0;
|
||||
const index_t in_left_pad_w = 3;
|
||||
const index_t in_right_pad_h = 0;
|
||||
const index_t in_right_pad_w = 3;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
switch(layout)
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
// NCHW
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
// NHWC
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
break;
|
||||
default: throw std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
#if USE_DYNAMIC_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_DYNAMIC_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
const auto nhwc_desc = f_make_for_device_nhwc();
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4NHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R5_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R5NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V5R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(out_host, out_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,18 +3,6 @@
|
||||
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
void HostTensorDescriptor::CalculateStrides()
|
||||
{
|
||||
mStrides.clear();
|
||||
@@ -45,3 +33,16 @@ std::size_t HostTensorDescriptor::GetElementSpace() const
|
||||
const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }
|
||||
|
||||
const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user