mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Composable kernel init integration v3 (#1097)
* Squashed 'src/composable_kernel/' content from commit f6edda611
  git-subtree-dir: src/composable_kernel
  git-subtree-split: f6edda6119
* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files
* Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c
  5781adf5c Update develop (#5) (#6)
  97e6d514f Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile
  7b1ec41e5 refactor
  49c33aaea refactor
  54b3e73d1 rename
  git-subtree-dir: src/composable_kernel
  git-subtree-split: 5781adf5cf
* fix
* refactor
* remove online compilation from CK
* refactor
* fix
* add ctest
* add c-style pointer cast
* vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast
* fix clang warning suppression
* tidy
* suppress cppcheck
* fix enum issue
* revert changes to hip build
* fix kernel filename
* update CK build script
* rename
* rename
* make inner product compatible on gfx900
* Update src/include/miopen/solver/ck_utility_common.hpp
  Co-authored-by: JD <Jehandad.Khan@amd.com>
* compiler parameter use stream
* use int instead of index_t in kernel wrapper
* DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element
* refactor
* refactor
* change cmakelist
* change ck common utility
* fix
Co-authored-by: JD <Jehandad.Khan@amd.com>
This commit is contained in:
2
host/CMakeLists.txt
Normal file
2
host/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Host-side sub-projects, added in the order listed.
foreach(host_component host_tensor driver_offline)
    add_subdirectory(${host_component})
endforeach()
|
||||
21
host/driver_offline/CMakeLists.txt
Normal file
21
host/driver_offline/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# Header search paths used by the offline drivers. BEFORE prepends these
# directories to the include path instead of appending them.
include_directories(BEFORE
    include
    ${PROJECT_SOURCE_DIR}/host/host_tensor/include
    ${PROJECT_SOURCE_DIR}/host/solver/include
    ${PROJECT_SOURCE_DIR}/composable_kernel/include
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
    ${PROJECT_SOURCE_DIR}/external/rocm/include
)

# Source list for each driver executable.
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)

# Forward-convolution driver.
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)

# Backward-convolution driver.
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
|
||||
@@ -0,0 +1,330 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_dlops_v1r2.hpp"
|
||||
|
||||
// Host-side benchmark driver for the v4r4 DLOPS forward-convolution implicit GEMM
// (tensors are named in_n_c_hi_wi / wei_k_c_y_x / out_n_k_ho_wo throughout).
//
// What this function does, in order:
//   1. Creates three DeviceMem buffers sized sizeof(T) * mDesc.GetElementSpace() and calls
//      ToDevice() on each with the host tensor data (input, weight and output).
//   2. Builds tensor descriptors from the *_lengths arguments with
//      make_naive_tensor_descriptor_packed() and feeds them, together with the strides,
//      dilations and pads, to transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad().
//   3. Calls driver_gemm_dlops_v1r2<...>() five times with the compile-time tuning
//      parameters defined below, printing the returned time and a derived throughput
//      figure after every call.
//   4. Calls FromDevice() to copy the output buffer back into out_n_k_ho_wo.
//
// nrepeat is forwarded unchanged as the last argument of driver_gemm_dlops_v1r2.
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
// Print the function name so benchmark logs show which driver ran.
std::cout << __func__ << std::endl;
|
||||
|
||||
// Compile-time indices used to pick elements out of the descs tuple below.
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
// One device buffer per tensor, sized in bytes from the host tensor's element space.
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
// Note: the output tensor's current host contents are uploaded as well.
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
// Tuning parameters: this is the only configuration block and it is always enabled.
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
|
||||
#endif
|
||||
|
||||
// The transform is called with the weight descriptor first, then input, then output;
// the result is indexed below in the same order (I0 = weight, I1 = input, I2 = output).
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
// Each *_step_hacks value is a pair of tuples of Sequences; their exact meaning is
// defined by the GEMM driver and is not visible in this file.
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||
const auto in_gemmk_gemmn_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// Run the driver five times; the loop index is not used inside the body, so every
// iteration issues the same call and prints its own timing line.
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk_gemmm_grid_desc),
|
||||
decltype(in_gemmk_gemmn_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
GemmABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
|
||||
0, // ABlockTransferSrcVectorDim
|
||||
GemmABlockTransferSrcScalarPerVector_K,
|
||||
GemmABlockTransferDstScalarPerVector_M1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
GemmBBlockTransferSrcScalarPerVector_N1,
|
||||
GemmBBlockTransferDstScalarPerVector_N1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
// flops / 1e9 / ave_time: with ave_time in milliseconds (as the label printed below
// states) this quotient is in units of 1e12 flop per second.
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
// Host-side benchmark driver for the v4r4 XDLOPS forward-convolution implicit GEMM
// (tensors are named in_n_c_hi_wi / wei_k_c_y_x / out_n_k_ho_wo throughout).
//
// What this function does, in order:
//   1. Creates three DeviceMem buffers sized sizeof(T) * mDesc.GetElementSpace() and calls
//      ToDevice() on each with the host tensor data (input, weight and output).
//   2. Builds tensor descriptors with make_naive_tensor_descriptor_packed() and passes them
//      to transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<...>()
//      (the "_1x1" variant sits in a disabled #else branch).
//   3. Calls launch_kernel_gemm_xdlops_v2<...>() five times (the "_v1" variant sits in a
//      disabled #if 0 branch), printing the returned time and a derived throughput figure
//      after every call.
//   4. Calls FromDevice() to copy the output buffer back into out_n_k_ho_wo.
//
// Tuning parameters are picked by editing the #if / #elif chain below; as written, the
// first "#elif 1" block ([M, N, K0, K1] = [256, 128, 4, 4]) is the one that is compiled.
// nrepeat is forwarded unchanged as the last argument of the launch call.
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
// Print the function name so benchmark logs show which driver ran.
std::cout << __func__ << std::endl;
|
||||
|
||||
// Compile-time indices; all nine are used to index the descs tuple further down.
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
// One device buffer per tensor, sized in bytes from the host tensor's element space.
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
// Note: the output tensor's current host contents are uploaded as well.
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
// Tuning configurations. Exactly one block of this chain is compiled: the first whose
// condition is non-zero. The three blocks guarded by 0 are kept as alternatives.
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
// Same values as the previous block except GemmBBlockTransferSrcScalarPerVector_GemmN = 1.
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
// Active configuration: first non-zero condition in the chain.
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
// NOTE(review): this block is never compiled while the "#elif 1" above it is enabled,
// because an #elif is only evaluated when no earlier branch was taken. To use it, set
// the previous condition to 0.
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
// The preprocessor picks the transform's name; the template and call arguments that
// follow the #endif are shared by both variants. With "#if 1" the "_pad" variant is used.
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
// Run the launcher five times; the loop index is not used inside the body, so every
// iteration issues the same call and prints its own timing line.
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
// Same name-selection trick as above: "#if 0" disables _v1, so _v2 is what gets called.
#if 0
|
||||
float ave_time = launch_kernel_gemm_xdlops_v1
|
||||
#else
|
||||
float ave_time = launch_kernel_gemm_xdlops_v2
|
||||
#endif
|
||||
<BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmKPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
// flops / 1e9 / ave_time: with ave_time in milliseconds (as the label printed below
// states) this quotient is in units of 1e12 flop per second.
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_dlops_v1r3.hpp"
|
||||
|
||||
// Host-side benchmark driver for a DLOPS forward-convolution implicit GEMM
// (tensors are named in_n_hi_wi_c / wei_k_y_x_c / out_n_ho_wo_k throughout).
//
// What this function does, in order:
//   1. Creates three DeviceMem buffers sized sizeof(T) * mDesc.GetElementSpace() and calls
//      ToDevice() on each with the host tensor data (input, weight and output).
//   2. Builds tensor descriptors with make_naive_tensor_descriptor_packed() and passes them,
//      plus Number<GemmK1>{}, to
//      transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad().
//   3. Calls driver_gemm_dlops_v1r3<...>() five times, printing the returned time and a
//      throughput figure computed from the tensor lengths after every call.
//   4. Calls FromDevice() to copy the output buffer back into out_n_ho_wo_k.
//
// Tuning parameters are picked by editing the #if / #elif chain below; as written, the
// block commented "for fp16" (GemmK1 = 2) is compiled no matter what TInWei is.
// nrepeat is forwarded unchanged as the last argument of driver_gemm_dlops_v1r3.
//
// NOTE(review): this function is named v4r4r2 but calls the v4r4r4 transform - confirm
// that the mismatch is intentional.
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
// Print the function name so benchmark logs show which driver ran.
std::cout << __func__ << std::endl;
|
||||
|
||||
// Compile-time indices used for the descs tuple and for the *_lengths containers.
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
// One device buffer per tensor, sized in bytes from the host tensor's element space.
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
// Note: the output tensor's current host contents are uploaded as well.
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
// Tuning configurations. Exactly one block of this chain is compiled: the first whose
// condition is non-zero. The three blocks differ only in GemmK1 and in the last entry
// of the A/B slice and vector-length Sequences, which carries the same value.
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 1;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
// Active configuration: first non-zero condition in the chain.
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
// NOTE(review): this block is never compiled while the "#elif 1" above it is enabled,
// because an #elif is only evaluated when no earlier branch was taken. To use it, set
// the previous condition to 0.
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#endif
|
||||
|
||||
// The transform is called with the input descriptor first, then weight, then output;
// the result is indexed below in the same order (I0 = input, I1 = weight, I2 = output).
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
// In the trailing comments "N+" / "N-" label entry N of the first / second inner tuple.
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
||||
|
||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
// Run the driver five times; the loop index is not used inside the body, so every
// iteration issues the same call and prints its own timing line.
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_dlops_v1r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM110Xs,
|
||||
GemmM11N11ThreadClusterN110Xs,
|
||||
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
|
||||
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
|
||||
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
// Throughput is derived directly from the length containers: element 0 / 1 / 2 / 3 of
// the output lengths is read as N / Ho / Wo / K, and element 1 / 2 / 3 of the weight
// lengths as Y / X / C.
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
// 2*N*K*Ho*Wo*C*Y*X / 1e9 / ave_time: with ave_time in milliseconds (as the label
// printed below states) this quotient is in units of 1e12 flop per second.
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1>,
|
||||
2,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t /* nrepeat */)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = out_n_k_ho_wo_lengths[I0];
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto C = wei_k_c_y_x_lengths[I1];
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_lengths[I2];
|
||||
const auto Wi = in_n_c_hi_wi_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
const auto C0 = C / Number<InWeiVectorSize>{};
|
||||
const auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
const auto K0 = K / Number<InWeiVectorSize>{};
|
||||
const auto K1 = Number<InWeiVectorSize>{};
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));
|
||||
|
||||
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
|
||||
wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#else
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 0
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
|
||||
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) =
|
||||
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 1;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#elif 1
|
||||
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 2;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
|
||||
in_desc_n_c_hi_wi,
|
||||
out_desc_n_k_ho_wo,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GN0>{},
|
||||
Number<GK1>{});
|
||||
|
||||
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto in_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto out_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
||||
|
||||
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_contraction_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
|
||||
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
|
||||
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
CThreadTransferDstScalarPerVector_BN1,
|
||||
decltype(wei_grid_step_hacks),
|
||||
decltype(in_grid_step_hacks),
|
||||
decltype(out_grid_step_hacks),
|
||||
decltype(wei_grid_move_slice_window_step_hacks),
|
||||
decltype(in_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||
wei_grid_step_hacks,
|
||||
in_grid_step_hacks,
|
||||
out_grid_step_hacks,
|
||||
wei_grid_move_slice_window_step_hacks,
|
||||
in_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
286
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
Normal file
286
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
Normal file
@@ -0,0 +1,286 @@
|
||||
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
ck::index_t GM1PerBlockGM11,
|
||||
ck::index_t GN1PerBlockGN11,
|
||||
ck::index_t GK0PerBlock,
|
||||
ck::index_t BM1PerThreadBM11,
|
||||
ck::index_t BN1PerThreadBN11,
|
||||
ck::index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs,
|
||||
typename BM10BN10ThreadClusterBN10Xs,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
__host__ float
|
||||
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
if(!GridwiseContraction::CheckValidity(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
||||
{
|
||||
throw std::runtime_error("wrong! "
|
||||
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
||||
"GM0_GM1_GN0_GN1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
|
||||
|
||||
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
|
||||
|
||||
// c_grid_block_cluster_blockid_to_gm10_gn10
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
|
||||
|
||||
const bool has_double_tail_k_block_loop =
|
||||
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
|
||||
|
||||
{
|
||||
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,349 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
#if 1
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,364 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
|
||||
<< std::endl;
|
||||
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
|
||||
<< std::endl;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, 0, OutRightPadH),
|
||||
make_pad_transform(Wo, 0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
413
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
Normal file
413
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
Normal file
@@ -0,0 +1,413 @@
|
||||
#ifndef DRIVER_GEMM_DLOPS_V1R2
|
||||
#define DRIVER_GEMM_DLOPS_V1R2
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AKMGridDesc,
|
||||
typename BKNGridDesc,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t M1PerThread,
|
||||
ck::index_t N1PerThread,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t M1N1ThreadClusterM10,
|
||||
ck::index_t M1N1ThreadClusterN10,
|
||||
ck::index_t M1N1ThreadClusterM11,
|
||||
ck::index_t M1N1ThreadClusterN11,
|
||||
typename ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_M1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_N1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AKMGridDesc& a_k_m_grid_desc,
|
||||
const BKNGridDesc& b_k_n_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AKMGridDesc,
|
||||
BKNGridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
|
||||
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
|
||||
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
|
||||
|
||||
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
|
||||
// c_blockid_to_m0_n0_block_cluster_adaptor
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||
|
||||
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
|
||||
|
||||
{
|
||||
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
|
||||
DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
|
||||
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
||||
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
||||
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
||||
|
||||
a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
|
||||
b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
||||
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
418
host/driver_offline/include/driver_gemm_dlops_v1r3.hpp
Normal file
418
host/driver_offline/include/driver_gemm_dlops_v1r3.hpp
Normal file
@@ -0,0 +1,418 @@
|
||||
#ifndef DRIVER_GEMM_DLOPS_V1R3
|
||||
#define DRIVER_GEMM_DLOPS_V1R3
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t M1PerThread,
|
||||
ck::index_t N1PerThread,
|
||||
ck::index_t KPerThread,
|
||||
typename M1N1ThreadClusterM1Xs,
|
||||
typename M1N1ThreadClusterN1Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM1Xs,
|
||||
M1N1ThreadClusterN1Xs,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_k0_m0_m1_k1_grid_desc =
|
||||
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n0_n1_k1_grid_desc =
|
||||
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);
|
||||
|
||||
using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
|
||||
using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);
|
||||
|
||||
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
|
||||
// c_blockid_to_m0_n0_block_cluster_adaptor
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||
|
||||
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
|
||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
|
||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
|
||||
DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
|
||||
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
||||
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
||||
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
||||
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
||||
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
191
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Normal file
191
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Normal file
@@ -0,0 +1,191 @@
|
||||
#ifndef DRIVER_GEMM_XDLOPS_V2R3
|
||||
#define DRIVER_GEMM_XDLOPS_V2R3
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t MPerWave,
|
||||
ck::index_t NPerWave,
|
||||
ck::index_t K1,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat>;
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
|
||||
|
||||
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0M1M2NGridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
|
||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||
|
||||
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
||||
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
||||
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
321
host/driver_offline/src/conv_bwd_driver_offline.cpp
Normal file
321
host/driver_offline/src/conv_bwd_driver_offline.cpp
Normal file
@@ -0,0 +1,321 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv_bwd_data.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
// Build-time switches for this driver.
// USE_MODE: non-zero selects the "dynamic mode" of main() (problem size parsed from argv);
// zero selects the "static mode" (problem size fixed at compile time).
#define USE_MODE 1
|
||||
// Non-zero compiles in the V4R1XDLNHWC branch of main().
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
|
||||
// Non-zero compiles in the V4R1R2XDLNHWC branch of main().
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||
|
||||
// Backward-data convolution algorithms. The numeric values are spelled out because the
// enumerators are produced by casting an integer.
enum ConvBackwardDataAlgo
{
    V4R1XDLNHWC   = 0,
    V4R1R2XDLNHWC = 1
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
const index_t conv_stride_h = 2;
|
||||
const index_t conv_stride_w = 2;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 1;
|
||||
const index_t in_left_pad_w = 1;
|
||||
const index_t in_right_pad_h = 1;
|
||||
const index_t in_right_pad_w = 1;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in_host(in_lengths_host);
|
||||
Tensor<in_data_t> in_device(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_BWD_V4R1_XDL_NHWC
|
||||
if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_BWD_V4R1R2_XDL_NHWC
|
||||
if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in_device,
|
||||
wei,
|
||||
out,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution_backward_data(in_host,
|
||||
wei,
|
||||
out,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(in_host, in_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out : ", out.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_host : ", in_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
474
host/driver_offline/src/conv_fwd_driver_offline.cpp
Normal file
474
host/driver_offline/src/conv_fwd_driver_offline.cpp
Normal file
@@ -0,0 +1,474 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
// Build-time switches for this driver.
// USE_MODE: non-zero selects the "dynamic mode" of main() (problem size parsed from argv);
// zero selects the "static mode" (problem size fixed at compile time).
#define USE_MODE 1
|
||||
// The USE_CONV_FWD_* switches below presumably compile in the matching ConvForwardAlgo
// branch of main() when non-zero -- TODO confirm against the use sites further down the file.
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
||||
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
|
||||
|
||||
// Forward convolution algorithms. The numeric values are spelled out because the
// enumerators are produced by casting an integer.
enum ConvForwardAlgo
{
    V4R4NCHW      = 0,
    V4R4R2NHWC    = 1,
    V6R1NCHW      = 2,
    V5R1NCHW      = 3,
    V4R4R2XDLNCHW = 4,
    V4R4R4XDLNHWC = 5
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
#if USE_MODE
|
||||
// dynamic mode
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
const index_t N = std::stoi(argv[7]);
|
||||
const index_t K = std::stoi(argv[8]);
|
||||
const index_t C = std::stoi(argv[9]);
|
||||
const index_t Y = std::stoi(argv[10]);
|
||||
const index_t X = std::stoi(argv[11]);
|
||||
const index_t Hi = std::stoi(argv[12]);
|
||||
const index_t Wi = std::stoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#else
|
||||
// static mode
|
||||
if(argc < 7)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||
const bool do_verification = std::stoi(argv[3]);
|
||||
const int init_method = std::stoi(argv[4]);
|
||||
const bool do_log = std::stoi(argv[5]);
|
||||
const int nrepeat = std::stoi(argv[6]);
|
||||
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 192;
|
||||
constexpr index_t Hi = 71;
|
||||
constexpr index_t Wi = 71;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
const index_t conv_stride_h = 2;
|
||||
const index_t conv_stride_w = 2;
|
||||
const index_t conv_dilation_h = 1;
|
||||
const index_t conv_dilation_w = 1;
|
||||
const index_t in_left_pad_h = 1;
|
||||
const index_t in_left_pad_w = 1;
|
||||
const index_t in_right_pad_h = 1;
|
||||
const index_t in_right_pad_w = 1;
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
Tensor<in_data_t> wei(wei_lengths_host);
|
||||
Tensor<out_data_t> out_host(out_lengths_host);
|
||||
Tensor<out_data_t> out_device(out_lengths_host);
|
||||
|
||||
std::cout << "layout: " << layout << std::endl;
|
||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
auto f_make_for_device_nchw = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
auto f_make_for_device_nhwc = [&]() {
|
||||
#if USE_MODE
|
||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
||||
#else
|
||||
const auto in_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{});
|
||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{});
|
||||
const auto out_lengths_dev =
|
||||
make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{});
|
||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
||||
const auto conv_dilations_dev =
|
||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
||||
const auto in_right_pads_dev =
|
||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
||||
#endif
|
||||
|
||||
return make_tuple(in_lengths_dev,
|
||||
wei_lengths_dev,
|
||||
out_lengths_dev,
|
||||
conv_strides_dev,
|
||||
conv_dilations_dev,
|
||||
in_left_pads_dev,
|
||||
in_right_pads_dev);
|
||||
};
|
||||
|
||||
#if USE_CONV_FWD_V4R4_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R2_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4R2NHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V6R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V6R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V5R1_NCHW
|
||||
if(algo == ConvForwardAlgo::V5R1NCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||
16,
|
||||
acc_data_t,
|
||||
out_data_t>(tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R2_XDL_NCHW
|
||||
if(algo == ConvForwardAlgo::V4R4R2XDLNCHW)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NCHW)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nchw();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if USE_CONV_FWD_V4R4R4_XDL_NHWC
|
||||
if(algo == ConvForwardAlgo::V4R4R4XDLNHWC)
|
||||
{
|
||||
if(layout != ConvTensorLayout::NHWC)
|
||||
{
|
||||
throw std::runtime_error("wrong! layout");
|
||||
}
|
||||
|
||||
const auto tmp = f_make_for_device_nhwc();
|
||||
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||
acc_data_t,
|
||||
out_data_t>(
|
||||
tmp[I0],
|
||||
tmp[I1],
|
||||
tmp[I2],
|
||||
tmp[I3],
|
||||
tmp[I4],
|
||||
tmp[I5],
|
||||
tmp[I6],
|
||||
in,
|
||||
wei,
|
||||
out_device,
|
||||
nrepeat);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_direct_convolution(in,
|
||||
wei,
|
||||
out_host,
|
||||
make_tuple(conv_stride_h, conv_stride_w),
|
||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
||||
layout);
|
||||
|
||||
check_error(out_host, out_device);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
21
host/host_tensor/CMakeLists.txt
Normal file
21
host/host_tensor/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# Build script for the host_tensor shared library.

# Put this component's own headers ahead of any previously added include paths.
include_directories(BEFORE include)

# Translation units that make up the library.
set(HOST_TENSOR_SOURCE
    src/host_tensor.cpp
    src/device.cpp
)

## the library target
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})

# HALF_INCLUDE_DIR is defined outside this script; it is exported to dependents
# as a SYSTEM include directory for the build tree only.
target_include_directories(host_tensor SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)

# hip::device is used to build the library itself, hip::host is only forwarded
# to targets that link against host_tensor.
target_link_libraries(host_tensor
    PRIVATE hip::device
    INTERFACE hip::host
)

target_compile_features(host_tensor PUBLIC)
set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)

install(TARGETS host_tensor LIBRARY DESTINATION lib)
|
||||
86
host/host_tensor/include/conv_common.hpp
Normal file
86
host/host_tensor/include/conv_common.hpp
Normal file
@@ -0,0 +1,86 @@
|
||||
#ifndef CONV_COMMON_HPP
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
// Identifies the dimension order of the tensors taking part in a convolution.
// Deliberately left as an unscoped enum: existing code both qualifies the
// enumerators (ConvTensorLayout::NCHW) and relies on the implicit conversion
// to an integer. The values are spelled out but are the same ones the
// compiler would assign implicitly.
enum ConvTensorLayout
{
    NCHW  = 0,
    NHWC  = 1,
    CHWN  = 2,
    NCHWc = 3,
    NHWCc = 4
};
|
||||
|
||||
template <typename... InDesc,
|
||||
typename... WeiDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
const ck::TensorDescriptor<InDesc...>& in_desc,
|
||||
const ck::TensorDescriptor<WeiDesc...>& wei_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations conv_dilations,
|
||||
const LeftPads& left_pads,
|
||||
const RightPads& right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
assert(in_desc.GetNumOfDimension() == 4);
|
||||
assert(wei_desc.GetNumOfDimension() == 4);
|
||||
assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1));
|
||||
|
||||
const auto N = in_desc.GetLength(I0);
|
||||
const auto Hi = in_desc.GetLength(I2);
|
||||
const auto Wi = in_desc.GetLength(I3);
|
||||
|
||||
const auto K = wei_desc.GetLength(I0);
|
||||
const auto Y = wei_desc.GetLength(I2);
|
||||
const auto X = wei_desc.GetLength(I3);
|
||||
|
||||
const auto LeftPadH = left_pads[I0];
|
||||
const auto LeftPadW = left_pads[I1];
|
||||
|
||||
const auto RightPadH = right_pads[I0];
|
||||
const auto RightPadW = right_pads[I1];
|
||||
|
||||
const auto YEff = (Y - I1) * conv_dilations[I0] + I1;
|
||||
const auto XEff = (X - I1) * conv_dilations[I1] + I1;
|
||||
|
||||
const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
|
||||
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
|
||||
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
|
||||
}
|
||||
|
||||
template <class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t
|
||||
calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const index_t N = out_desc.GetLength(I0);
|
||||
const index_t K = out_desc.GetLength(I1);
|
||||
const index_t Ho = out_desc.GetLength(I2);
|
||||
const index_t Wo = out_desc.GetLength(I3);
|
||||
|
||||
const index_t C = wei_desc.GetLength(I1);
|
||||
const index_t Y = wei_desc.GetLength(I2);
|
||||
const index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
}
|
||||
|
||||
#endif
|
||||
80
host/host_tensor/include/device.hpp
Normal file
80
host/host_tensor/include/device.hpp
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef DEVICE_HPP
|
||||
#define DEVICE_HPP
|
||||
|
||||
#include <memory>
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
// Handle to a block of device memory. Only the declarations live here; the
// member functions are defined in a separate translation unit.
// NOTE(review): the destructor presumably releases mpDeviceBuf, in which case
// copying a DeviceMem would be unsafe - confirm against the implementation.
struct DeviceMem
{
    // Construction / destruction. A size is mandatory, hence no default constructor.
    DeviceMem() = delete;
    DeviceMem(std::size_t mem_size);
    ~DeviceMem();

    // Raw pointer to the device buffer.
    void* GetDeviceBuffer();

    // Transfers between a host pointer `p` and the device buffer.
    // presumably mMemSize bytes are copied - TODO confirm
    void ToDevice(const void* p);
    void FromDevice(void* p);

    void* mpDeviceBuf;
    std::size_t mMemSize;
};
|
||||
|
||||
// Implementation details of the timer are hidden behind this opaque type,
// which is only declared here.
struct KernelTimerImpl;

// Stopwatch used to time kernel launches: call Start(), enqueue the work,
// call End(), then read GetElapsedTime().
// NOTE(review): the unit of the returned time is not visible here - confirm
// against the implementation.
struct KernelTimer
{
    KernelTimer();

    // Declared here and defined out of line: KernelTimerImpl is incomplete at
    // this point, so the unique_ptr member cannot be destroyed inline.
    ~KernelTimer();

    void Start();
    void End();

    float GetElapsedTime() const;

    std::unique_ptr<KernelTimerImpl> impl;
};
|
||||
|
||||
using device_stream_t = hipStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
hipStream_t stream_id = nullptr;
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(
|
||||
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up\n");
|
||||
|
||||
hipStream_t stream_id = nullptr;
|
||||
|
||||
// warm up
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
|
||||
timer.Start();
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
#endif
|
||||
9
host/host_tensor/include/device_tensor.hpp
Normal file
9
host/host_tensor/include/device_tensor.hpp
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "common_header.hpp"
|
||||
|
||||
// Prints a tensor descriptor to `os` (std::cout by default).
// Only the descriptor's type matters: a default-constructed TensorDesc is
// converted to a host descriptor, which is then written out.
template <typename TensorDesc>
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
{
    const auto host_desc = make_HostTensorDescriptor(TensorDesc{});

    ostream_HostTensorDescriptor(host_desc, os);
}
|
||||
324
host/host_tensor/include/host_conv.hpp
Normal file
324
host/host_tensor/include/host_conv.hpp
Normal file
@@ -0,0 +1,324 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads&,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, ho, wo, k) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads>
|
||||
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
InLeftPads,
|
||||
InRightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr std::size_t HoPerTile = 2;
|
||||
constexpr std::size_t WoPerTile = 2;
|
||||
|
||||
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
||||
|
||||
std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
|
||||
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
||||
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t Ho = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t Wo = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
|
||||
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
|
||||
|
||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||
|
||||
std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile;
|
||||
std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile;
|
||||
|
||||
Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||
Tensor<double> wei_transform({K, C, HiPerTile, WiPerTile});
|
||||
Tensor<double> out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile});
|
||||
Tensor<double> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
|
||||
|
||||
auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
int hi = HoPerTile * htile + j - h_pad_low;
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
int wi = WoPerTile * wtile + i - w_pad_low;
|
||||
|
||||
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_nchw.mDesc.GetLengths()[3])
|
||||
{
|
||||
in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
|
||||
}
|
||||
else
|
||||
{
|
||||
in_hold(n, c, htile, wtile, j, i) = TIn(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
|
||||
in_transform(n, c, htile, wtile, 0, 0) =
|
||||
in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 1) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 2) =
|
||||
-in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 0, 3) =
|
||||
in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 1, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 1, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 2, 0) =
|
||||
-in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 1) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 2) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
|
||||
in_transform(n, c, htile, wtile, 2, 3) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) +
|
||||
in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);
|
||||
|
||||
in_transform(n, c, htile, wtile, 3, 0) =
|
||||
in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 1) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 2) =
|
||||
-in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
|
||||
in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
|
||||
in_transform(n, c, htile, wtile, 3, 3) =
|
||||
in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) -
|
||||
in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_wei_transform = [&](auto k, auto c) {
|
||||
wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0));
|
||||
wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 2));
|
||||
wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 0, 2));
|
||||
wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2));
|
||||
|
||||
wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 1, 1) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 1, 2) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) -
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
|
||||
wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 2, 1) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) -
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 2, 2) =
|
||||
0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) +
|
||||
0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.25 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) -
|
||||
0.5 * double(wei_kcyx(k, c, 1, 2)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
|
||||
wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0));
|
||||
wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) -
|
||||
0.5 * double(wei_kcyx(k, c, 2, 1)) +
|
||||
0.5 * double(wei_kcyx(k, c, 2, 2));
|
||||
wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2));
|
||||
};
|
||||
|
||||
auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HiPerTile; ++j)
|
||||
{
|
||||
for(int i = 0; i < WiPerTile; ++i)
|
||||
{
|
||||
double v = 0;
|
||||
for(int c = 0; c < C; ++c)
|
||||
{
|
||||
v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
|
||||
}
|
||||
|
||||
out_transform(n, k, htile, wtile, j, i) = v;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
out_hold(n, k, htile, wtile, 0, 0) =
|
||||
out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) +
|
||||
out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) +
|
||||
out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) +
|
||||
out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2);
|
||||
out_hold(n, k, htile, wtile, 0, 1) =
|
||||
out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) -
|
||||
out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) -
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) +
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 2, 3);
|
||||
out_hold(n, k, htile, wtile, 1, 0) =
|
||||
out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) +
|
||||
out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) -
|
||||
out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
|
||||
out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) -
|
||||
out_transform(n, k, htile, wtile, 3, 2);
|
||||
out_hold(n, k, htile, wtile, 1, 1) =
|
||||
out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) -
|
||||
out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) +
|
||||
out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) -
|
||||
out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) +
|
||||
out_transform(n, k, htile, wtile, 3, 3);
|
||||
};
|
||||
|
||||
auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
|
||||
for(int j = 0; j < HoPerTile; ++j)
|
||||
{
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
|
||||
make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
|
||||
}
|
||||
135
host/host_tensor/include/host_conv_bwd_data.hpp
Normal file
135
host/host_tensor/include/host_conv_bwd_data.hpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void host_direct_convolution_backward_data(Tensor<TIn>& in,
|
||||
const Tensor<TWei>& wei,
|
||||
const Tensor<TOut>& out,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& /* in_right_pads */,
|
||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I2];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I3];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I2];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I3];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, k, ho, wo) * wei(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, c, hi, wi) = v;
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
std::size_t K = wei.mDesc.GetLengths()[I0];
|
||||
std::size_t Y = wei.mDesc.GetLengths()[I1];
|
||||
std::size_t X = wei.mDesc.GetLengths()[I2];
|
||||
|
||||
std::size_t Ho = out.mDesc.GetLengths()[I1];
|
||||
std::size_t Wo = out.mDesc.GetLengths()[I2];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int y = 0; y < Y; ++y)
|
||||
{
|
||||
int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0];
|
||||
|
||||
if(h_tmp % conv_strides[I0] == 0)
|
||||
{
|
||||
int ho = h_tmp / conv_strides[I0];
|
||||
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
for(int x = 0; x < X; ++x)
|
||||
{
|
||||
int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1];
|
||||
|
||||
if(w_tmp % conv_strides[I1] == 0)
|
||||
{
|
||||
int wo = w_tmp / conv_strides[I1];
|
||||
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += out(n, ho, wo, k) * wei(k, y, x, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
in(n, hi, wi, c) = v;
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
make_ParallelTensorFunctor(f_nhwc,
|
||||
in.mDesc.GetLengths()[0],
|
||||
in.mDesc.GetLengths()[1],
|
||||
in.mDesc.GetLengths()[2],
|
||||
in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
322
host/host_tensor/include/host_tensor.hpp
Normal file
322
host/host_tensor/include/host_tensor.hpp
Normal file
@@ -0,0 +1,322 @@
|
||||
#ifndef HOST_TENSOR_HPP
|
||||
#define HOST_TENSOR_HPP
|
||||
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
|
||||
// Streams every element of `range` to `os`, writing `delim` between consecutive
// elements (never before the first or after the last). Returns `os`.
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
{
    bool need_delim = false;

    for(auto&& elem : range)
    {
        if(need_delim)
            os << delim;

        os << elem;
        need_delim = true;
    }

    return os;
}
|
||||
|
||||
// Same as LogRange, but each element is converted with static_cast<T> before it is
// streamed. Returns `os`.
template <typename T, typename Range>
std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
{
    bool need_delim = false;

    for(auto&& elem : range)
    {
        if(need_delim)
            os << delim;

        os << static_cast<T>(elem);
        need_delim = true;
    }

    return os;
}
|
||||
|
||||
// Integer tags for the element types known to the host-side helpers.
enum DataType_t
{
    Half  = 0,
    Float = 1,
};

// Maps a C++ element type to its DataType_t tag; only specializations are defined.
template <typename T>
struct DataType;

// float maps to DataType_t::Float.
template <>
struct DataType<float> : std::integral_constant<DataType_t, DataType_t::Float>
{
};
|
||||
|
||||
// Helper: expands `args` element by element (std::get<Is>) into a call of `f`.
template <typename F, typename T, std::size_t... Is>
auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
    return f(std::get<Is>(args)...);
}

// Calls `f` with the elements of the tuple-like `args` as individual arguments and
// returns whatever `f` returns.
template <typename F, typename T>
auto call_f_unpack_args(F f, T args)
{
    using indices_t = std::make_index_sequence<std::tuple_size<T>::value>;

    return call_f_unpack_args_impl(f, args, indices_t{});
}
|
||||
|
||||
// Helper: constructs an F from the elements of `args`, expanded with std::get<Is>.
template <typename F, typename T, std::size_t... Is>
auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
    return F(std::get<Is>(args)...);
}

// Constructs and returns an F whose constructor arguments are the elements of the
// tuple-like `args`. The first parameter only carries the type F; its value is unused.
template <typename F, typename T>
auto construct_f_unpack_args(F, T args)
{
    using indices_t = std::make_index_sequence<std::tuple_size<T>::value>;

    return construct_f_unpack_args_impl<F>(args, indices_t{});
}
|
||||
|
||||
// Describes a host tensor by a list of per-dimension lengths and a matching list of
// per-dimension strides (both stored as std::size_t).
struct HostTensorDescriptor
{
    HostTensorDescriptor() = delete;

    // Lengths only; defined out of class.
    template <typename X>
    HostTensorDescriptor(std::vector<X> lens);

    // Explicit lengths and strides; defined out of class.
    template <typename X, typename Y>
    HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides);

    // Recomputes mStrides from mLens; defined out of class.
    void CalculateStrides();

    // Lengths from any range exposing begin()/end(); strides are then calculated.
    template <typename Range>
    HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
    {
        CalculateStrides();
    }

    // Lengths and strides from any two ranges exposing begin()/end(); the strides are
    // stored as given.
    template <typename Range1, typename Range2>
    HostTensorDescriptor(const Range1& lens, const Range2& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    std::size_t GetNumOfDimension() const;
    std::size_t GetElementSize() const;
    std::size_t GetElementSpace() const;

    const std::vector<std::size_t>& GetLengths() const;
    const std::vector<std::size_t>& GetStrides() const;

    // Linear offset of a multi-index: sum over dimensions of index * stride.
    // The number of indices is only checked by assert against the dimension count.
    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
        assert(sizeof...(Is) == GetNumOfDimension());

        const std::initializer_list<std::size_t> indices{static_cast<std::size_t>(is)...};

        return std::inner_product(
            indices.begin(), indices.end(), mStrides.begin(), std::size_t{0});
    }

    private:
    std::vector<std::size_t> mLens;
    std::vector<std::size_t> mStrides;
};
|
||||
|
||||
// std::thread that joins itself on destruction when it is still joinable.
struct joinable_thread : std::thread
{
    // Forwards all arguments to the std::thread constructor.
    template <typename... Args>
    joinable_thread(Args&&... args) : std::thread(std::forward<Args>(args)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(!joinable())
            return;

        join();
    }
};
|
||||
|
||||
template <typename F, typename... Xs>
|
||||
struct ParallelTensorFunctor
|
||||
{
|
||||
F mF;
|
||||
static constexpr std::size_t NDIM = sizeof...(Xs);
|
||||
std::array<std::size_t, NDIM> mLens;
|
||||
std::array<std::size_t, NDIM> mStrides;
|
||||
std::size_t mN1d;
|
||||
|
||||
ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
|
||||
{
|
||||
mStrides.back() = 1;
|
||||
std::partial_sum(mLens.rbegin(),
|
||||
mLens.rend() - 1,
|
||||
mStrides.rbegin() + 1,
|
||||
std::multiplies<std::size_t>());
|
||||
mN1d = mStrides[0] * mLens[0];
|
||||
}
|
||||
|
||||
std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
|
||||
{
|
||||
std::array<std::size_t, NDIM> indices;
|
||||
|
||||
for(int idim = 0; idim < NDIM; ++idim)
|
||||
{
|
||||
indices[idim] = i / mStrides[idim];
|
||||
i -= indices[idim] * mStrides[idim];
|
||||
}
|
||||
|
||||
return indices;
|
||||
}
|
||||
|
||||
void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const
|
||||
{
|
||||
std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d);
|
||||
|
||||
auto f = [=] {
|
||||
for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
|
||||
{
|
||||
call_f_unpack_args(mF, GetNdIndices(iw));
|
||||
}
|
||||
};
|
||||
threads[it] = joinable_thread(f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename F, typename... Xs>
|
||||
auto make_ParallelTensorFunctor(F f, Xs... xs)
|
||||
{
|
||||
return ParallelTensorFunctor<F, Xs...>(f, xs...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Tensor
|
||||
{
|
||||
template <typename X>
|
||||
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
|
||||
{
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
Tensor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mDesc(lens, strides), mData(mDesc.GetElementSpace())
|
||||
{
|
||||
}
|
||||
|
||||
Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
|
||||
|
||||
template <typename G>
|
||||
void GenerateTensorValue(G g, std::size_t num_thread = 1)
|
||||
{
|
||||
switch(mDesc.GetNumOfDimension())
|
||||
{
|
||||
case 1: {
|
||||
auto f = [&](auto i) { (*this)(i) = g(i); };
|
||||
make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
|
||||
make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread);
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
|
||||
make_ParallelTensorFunctor(
|
||||
f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
auto f = [&](auto i0, auto i1, auto i2, auto i3) {
|
||||
(*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
|
||||
};
|
||||
make_ParallelTensorFunctor(f,
|
||||
mDesc.GetLengths()[0],
|
||||
mDesc.GetLengths()[1],
|
||||
mDesc.GetLengths()[2],
|
||||
mDesc.GetLengths()[3])(num_thread);
|
||||
break;
|
||||
}
|
||||
default: throw std::runtime_error("unspported dimension");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
T& operator()(Is... is)
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
template <typename... Is>
|
||||
const T& operator()(Is... is) const
|
||||
{
|
||||
return mData[mDesc.GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
typename std::vector<T>::iterator begin() { return mData.begin(); }
|
||||
|
||||
typename std::vector<T>::iterator end() { return mData.end(); }
|
||||
|
||||
typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
|
||||
|
||||
typename std::vector<T>::const_iterator end() const { return mData.end(); }
|
||||
|
||||
HostTensorDescriptor mDesc;
|
||||
std::vector<T> mData;
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens) : mLens(lens)
|
||||
{
|
||||
this->CalculateStrides();
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
HostTensorDescriptor::HostTensorDescriptor(std::vector<X> lens, std::vector<Y> strides)
|
||||
: mLens(lens), mStrides(strides)
|
||||
{
|
||||
}
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout);
|
||||
|
||||
template <typename T>
|
||||
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
error += std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = ref.mData[i];
|
||||
result_value = result.mData[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
}
|
||||
|
||||
#endif
|
||||
60
host/host_tensor/include/host_tensor_generator.hpp
Normal file
60
host/host_tensor/include/host_tensor_generator.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#ifndef HOST_TENSOR_GENERATOR_HPP
|
||||
#define HOST_TENSOR_GENERATOR_HPP
|
||||
|
||||
#include <cmath>
|
||||
#include "config.hpp"
|
||||
|
||||
// Generator that ignores its index arguments and always yields `value` (as float).
struct GeneratorTensor_1
{
    int value = 1;

    template <typename... Is>
    float operator()(Is...)
    {
        return static_cast<float>(value);
    }
};
|
||||
|
||||
// Generator that ignores its index arguments and yields a pseudo-random integer
// (from std::rand) in [min_value, max_value), returned as float.
struct GeneratorTensor_2
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    float operator()(Is...)
    {
        const int range = max_value - min_value;

        // An empty range would make the modulo below divide by zero; the only value
        // consistent with the bounds is min_value itself.
        if(range == 0)
            return min_value;

        return (std::rand() % range) + min_value;
    }
};
|
||||
|
||||
// Generator that ignores its index arguments and yields a pseudo-random value
// interpolated between min_value and max_value using std::rand() / RAND_MAX.
template <typename T>
struct GeneratorTensor_3
{
    T min_value = 0;
    T max_value = 1;

    template <typename... Is>
    float operator()(Is...)
    {
        // fraction in [0, 1]
        const float fraction = float(std::rand()) / float(RAND_MAX);

        return min_value + fraction * (max_value - min_value);
    }
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <typename... Ts>
|
||||
float operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<ck::index_t, sizeof...(Ts)> dims = {{static_cast<ck::index_t>(Xs)...}};
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
[](bool init, ck::index_t x) -> int { return init != (x % 2); })
|
||||
? 1
|
||||
: -1;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
67
host/host_tensor/src/device.cpp
Normal file
67
host/host_tensor/src/device.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
#include "device.hpp"
|
||||
|
||||
// Records the requested size and asks hipMalloc for that many bytes, storing the
// pointer in mpDeviceBuf. The hipMalloc status is only handed to hipGetErrorString and
// the resulting string is discarded, so a failure is not reported to the caller.
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
    const auto status = hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize);
    hipGetErrorString(status);
}
|
||||
|
||||
// Returns the raw device pointer held by this object.
void* DeviceMem::GetDeviceBuffer()
{
    return mpDeviceBuf;
}
|
||||
|
||||
void DeviceMem::ToDevice(const void* p)
|
||||
{
|
||||
hipGetErrorString(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p)
|
||||
{
|
||||
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
// Releases the device buffer with hipFree; the status is not reported.
DeviceMem::~DeviceMem()
{
    const auto status = hipFree(mpDeviceBuf);
    hipGetErrorString(status);
}
|
||||
|
||||
struct KernelTimerImpl
|
||||
{
|
||||
KernelTimerImpl()
|
||||
{
|
||||
hipGetErrorString(hipEventCreate(&mStart));
|
||||
hipGetErrorString(hipEventCreate(&mEnd));
|
||||
}
|
||||
|
||||
~KernelTimerImpl()
|
||||
{
|
||||
hipGetErrorString(hipEventDestroy(mStart));
|
||||
hipGetErrorString(hipEventDestroy(mEnd));
|
||||
}
|
||||
|
||||
void Start()
|
||||
{
|
||||
hipGetErrorString(hipDeviceSynchronize());
|
||||
hipGetErrorString(hipEventRecord(mStart, nullptr));
|
||||
}
|
||||
|
||||
void End()
|
||||
{
|
||||
hipGetErrorString(hipEventRecord(mEnd, nullptr));
|
||||
hipGetErrorString(hipEventSynchronize(mEnd));
|
||||
}
|
||||
|
||||
float GetElapsedTime() const
|
||||
{
|
||||
float time;
|
||||
hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd));
|
||||
return time;
|
||||
}
|
||||
|
||||
hipEvent_t mStart, mEnd;
|
||||
};
|
||||
|
||||
KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}
|
||||
|
||||
KernelTimer::~KernelTimer() {}
|
||||
|
||||
void KernelTimer::Start() { impl->Start(); }
|
||||
|
||||
void KernelTimer::End() { impl->End(); }
|
||||
|
||||
float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); }
|
||||
48
host/host_tensor/src/host_tensor.cpp
Normal file
48
host/host_tensor/src/host_tensor.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
#include <boost/range/adaptor/transformed.hpp>
|
||||
#include <cassert>
|
||||
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
void HostTensorDescriptor::CalculateStrides()
|
||||
{
|
||||
mStrides.clear();
|
||||
mStrides.resize(mLens.size(), 0);
|
||||
if(mStrides.empty())
|
||||
return;
|
||||
|
||||
mStrides.back() = 1;
|
||||
std::partial_sum(
|
||||
mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
// Number of dimensions, i.e. the number of stored lengths.
std::size_t HostTensorDescriptor::GetNumOfDimension() const
{
    return mLens.size();
}
|
||||
|
||||
std::size_t HostTensorDescriptor::GetElementSize() const
|
||||
{
|
||||
assert(mLens.size() == mStrides.size());
|
||||
return std::accumulate(
|
||||
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
std::size_t HostTensorDescriptor::GetElementSpace() const
|
||||
{
|
||||
auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; });
|
||||
return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1;
|
||||
}
|
||||
|
||||
// Read-only access to the per-dimension lengths.
const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const
{
    return mLens;
}
|
||||
|
||||
// Read-only access to the per-dimension strides.
const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const
{
    return mStrides;
}
|
||||
|
||||
void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os)
|
||||
{
|
||||
os << "dim " << desc.GetNumOfDimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.GetLengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.GetStrides(), ", ");
|
||||
os << "}" << std::endl;
|
||||
}
|
||||
689
host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp
Normal file
689
host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp
Normal file
@@ -0,0 +1,689 @@
|
||||
#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
auto GetCompileParameterString() const
|
||||
{
|
||||
auto param = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
param <<
|
||||
" -DCK_PARAM_ABDataTypeEnum=" <<
|
||||
ABDataTypeEnum <<
|
||||
" -DCK_PARAM_AccDataTypeEnum=" <<
|
||||
AccDataTypeEnum <<
|
||||
" -DCK_PARAM_CDataTypeEnum=" <<
|
||||
CDataTypeEnum <<
|
||||
" -DCK_PARAM_BlockSize=" <<
|
||||
BlockSize <<
|
||||
" -DCK_PARAM_GN0=" <<
|
||||
GN0 <<
|
||||
" -DCK_PARAM_GK1=" <<
|
||||
GK1 <<
|
||||
" -DCK_PARAM_GM1PerBlockGM11="
|
||||
<< GM1PerBlockGM11 <<
|
||||
" -DCK_PARAM_GN1PerBlockGN11=" <<
|
||||
GN1PerBlockGN11 <<
|
||||
" -DCK_PARAM_GK0PerBlock=" <<
|
||||
GK0PerBlock <<
|
||||
" -DCK_PARAM_BM1PerThreadBM11=" <<
|
||||
BM1PerThreadBM11 <<
|
||||
" -DCK_PARAM_BN1PerThreadBN11=" <<
|
||||
BN1PerThreadBN11 <<
|
||||
" -DCK_PARAM_BK0PerThread=" <<
|
||||
BK0PerThread <<
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" <<
|
||||
BM10BN10ThreadClusterBM10Xs[0] << "," <<
|
||||
BM10BN10ThreadClusterBM10Xs[1] <<
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" <<
|
||||
BM10BN10ThreadClusterBN10Xs[0] << "," <<
|
||||
BM10BN10ThreadClusterBN10Xs[1] <<
|
||||
" -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_CThreadTransferDstScalarPerVector=" <<
|
||||
CThreadTransferDstScalarPerVector <<
|
||||
" -DCK_PARAM_HasMainKBlockLoop=" <<
|
||||
static_cast<int>(HasMainKBlockLoop) <<
|
||||
" -DCK_PARAM_HasDoubleTailKBlockLoop=" <<
|
||||
static_cast<int>(HasDoubleTailKBlockLoop);
|
||||
// clang-format on
|
||||
|
||||
return param.str();
|
||||
}
|
||||
|
||||
ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
|
||||
int BlockSize = -1;
|
||||
|
||||
int GN0 = -1;
|
||||
int GK1 = -1;
|
||||
|
||||
int GM1PerBlockGM11 = -1;
|
||||
int GN1PerBlockGN11 = -1;
|
||||
int GK0PerBlock = -1;
|
||||
|
||||
int BM1PerThreadBM11 = -1;
|
||||
int BN1PerThreadBN11 = -1;
|
||||
int BK0PerThread = -1;
|
||||
|
||||
std::array<int, 2> BM10BN10ThreadClusterBM10Xs = {-1, -1};
|
||||
std::array<int, 2> BM10BN10ThreadClusterBN10Xs = {-1, -1};
|
||||
|
||||
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
|
||||
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
|
||||
int CThreadTransferDstScalarPerVector = -1;
|
||||
|
||||
bool HasMainKBlockLoop = false;
|
||||
bool HasDoubleTailKBlockLoop = false;
|
||||
};
|
||||
|
||||
struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
ck::DataTypeEnum_t ABDataTypeEnum;
|
||||
ck::DataTypeEnum_t CDataTypeEnum;
|
||||
|
||||
int BlockSize;
|
||||
|
||||
int GN0;
|
||||
int GK1;
|
||||
|
||||
int GM1PerBlockGM11;
|
||||
int GN1PerBlockGN11;
|
||||
int GK0PerBlock;
|
||||
|
||||
int BM1PerThreadBM11;
|
||||
int BN1PerThreadBN11;
|
||||
int BK0PerThread;
|
||||
|
||||
std::array<int, 2> BM10BN10ThreadClusterBM10Xs;
|
||||
std::array<int, 2> BM10BN10ThreadClusterBN10Xs;
|
||||
|
||||
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
};
|
||||
|
||||
inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()
|
||||
{
|
||||
constexpr auto f32 = ck::DataTypeEnum_t::Float;
|
||||
constexpr auto f16 = ck::DataTypeEnum_t::Half;
|
||||
constexpr auto i8 = ck::DataTypeEnum_t::Int8;
|
||||
|
||||
return std::vector<TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw>{
|
||||
// clang-format off
|
||||
// fp32
|
||||
{f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
// fp16
|
||||
{f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
// i8
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{ i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{ i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{ i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{ i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}
|
||||
// clang-format on
|
||||
};
|
||||
}
|
||||
|
||||
// TODO make this common interface and write specs for it
|
||||
struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
static auto
|
||||
CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable)
|
||||
{
|
||||
const int C = conv_problem_desc.C;
|
||||
const int Y = conv_problem_desc.Y;
|
||||
const int X = conv_problem_desc.X;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum &&
|
||||
conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum &&
|
||||
conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum;
|
||||
const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum;
|
||||
|
||||
DataTypeEnum_t AccDataTypeEnum;
|
||||
|
||||
if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half)
|
||||
{
|
||||
AccDataTypeEnum = DataTypeEnum_t::Float;
|
||||
}
|
||||
else if(ABDataTypeEnum == DataTypeEnum_t::Int8)
|
||||
{
|
||||
AccDataTypeEnum = DataTypeEnum_t::Int32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
}
|
||||
|
||||
const int BlockSize = tunable.BlockSize;
|
||||
|
||||
const int GN0 = tunable.GN0;
|
||||
const int GK1 = tunable.GK1;
|
||||
|
||||
const int GM11 = tunable.GM1PerBlockGM11;
|
||||
const int GN11 = tunable.GN1PerBlockGN11;
|
||||
const int GK0PerBlock = tunable.GK0PerBlock;
|
||||
|
||||
const int BM11 = tunable.BM1PerThreadBM11;
|
||||
const int BN11 = tunable.BN1PerThreadBN11;
|
||||
const int BK0PerThread = tunable.BK0PerThread;
|
||||
|
||||
const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs;
|
||||
const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs;
|
||||
|
||||
const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
// C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
|
||||
const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo);
|
||||
|
||||
const int C0 = GK1;
|
||||
|
||||
if(!(C % C0 == 0))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const int C1 = C / C0;
|
||||
|
||||
const int GK0 = C1 * Y * X;
|
||||
|
||||
if(!(GK0 % GK0PerBlock == 0))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1);
|
||||
|
||||
const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0);
|
||||
|
||||
return std::make_tuple(
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{
|
||||
ABDataTypeEnum,
|
||||
AccDataTypeEnum,
|
||||
CDataTypeEnum,
|
||||
BlockSize,
|
||||
GN0,
|
||||
GK1,
|
||||
GM11,
|
||||
GN11,
|
||||
GK0PerBlock,
|
||||
BM11,
|
||||
BN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
HasMainKBlockLoop,
|
||||
HasDoubleTailKBlockLoop},
|
||||
true);
|
||||
}
|
||||
|
||||
static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc)
|
||||
{
|
||||
for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw())
|
||||
{
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{};
|
||||
bool found = false;
|
||||
|
||||
std::tie(compile_param, found) =
|
||||
CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable);
|
||||
|
||||
if(found && IsValidCompileParameter(conv_problem_desc, compile_param))
|
||||
return std::make_tuple(compile_param, true);
|
||||
}
|
||||
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
}
|
||||
|
||||
static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc)
|
||||
{
|
||||
bool found = false;
|
||||
|
||||
std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc);
|
||||
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool
|
||||
IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
const int N = conv_problem_desc.N;
|
||||
const int K = conv_problem_desc.K;
|
||||
const int C = conv_problem_desc.C;
|
||||
const int Y = conv_problem_desc.Y;
|
||||
const int X = conv_problem_desc.X;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
const int GK1 = compile_param.GK1;
|
||||
const int GN0 = compile_param.GN0;
|
||||
const int GM11 = compile_param.GM1PerBlockGM11;
|
||||
const int GN11 = compile_param.GN1PerBlockGN11;
|
||||
|
||||
const int BM11 = compile_param.BM1PerThreadBM11;
|
||||
const int BN11 = compile_param.BN1PerThreadBN11;
|
||||
|
||||
const int C0 = GK1;
|
||||
const int N0 = GN0;
|
||||
|
||||
if(!(C % C0 == 0))
|
||||
return false;
|
||||
|
||||
const int C1 = C / C0;
|
||||
|
||||
if(!(N % N0 == 0))
|
||||
return false;
|
||||
|
||||
const int N1 = N / N0;
|
||||
|
||||
const int GM0 = 1;
|
||||
const int GM1 = K;
|
||||
const int GN1 = N1 * Ho * Wo;
|
||||
const int GK0 = C1 * Y * X;
|
||||
|
||||
// check data type
|
||||
{
|
||||
if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum &&
|
||||
conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum))
|
||||
return false;
|
||||
|
||||
if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float ||
|
||||
compile_param.ABDataTypeEnum == DataTypeEnum_t::Half)
|
||||
{
|
||||
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float))
|
||||
return false;
|
||||
}
|
||||
else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8)
|
||||
{
|
||||
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check gridwise contraction
|
||||
{
|
||||
if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
const bool has_main_k_block_loop =
|
||||
((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1);
|
||||
|
||||
const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0);
|
||||
|
||||
if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop &&
|
||||
has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check A blockwise copy
|
||||
{
|
||||
const auto block_slice_lengths =
|
||||
std::array<int, 5>{compile_param.GK0PerBlock, GM0, 1, GM11, GK1};
|
||||
const auto& cluster_lengths =
|
||||
compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& thread_slice_lengths =
|
||||
compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& src_vector_lengths =
|
||||
compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& dst_vector_lengths =
|
||||
compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
// check number of working thread
|
||||
const int num_work_thread = std::accumulate(
|
||||
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize >= num_work_thread))
|
||||
return false;
|
||||
|
||||
// check block slice lengths vs thread slice lengths vs cluster lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check thread slice lengths vs vector lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0))
|
||||
return false;
|
||||
|
||||
if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Src vectorization, GK0 is global mem vector dim
|
||||
if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 &&
|
||||
src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1))
|
||||
return false;
|
||||
|
||||
// check Dst vectorization, {GM11, GK1} are LDS vector dims
|
||||
if(dst_vector_lengths[4] == GK1)
|
||||
{ // vectorize on {GM11, GK1}
|
||||
if(!(GM11 % dst_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{ // vectorize on {GK1} only
|
||||
if(!(GK1 % dst_vector_lengths[4] == 0))
|
||||
return false;
|
||||
|
||||
if(!(dst_vector_lengths[3] == 1))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check B blockwise copy
|
||||
{
|
||||
const auto block_slice_lengths =
|
||||
std::array<int, 5>{compile_param.GK0PerBlock, GN0, 1, GN11, GK1};
|
||||
const auto& cluster_lengths =
|
||||
compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& thread_slice_lengths =
|
||||
compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& src_vector_lengths =
|
||||
compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& dst_vector_lengths =
|
||||
compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
// check number of working thread
|
||||
const int num_work_thread = std::accumulate(
|
||||
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize >= num_work_thread))
|
||||
return false;
|
||||
|
||||
// check block slice lengths vs thread slice lengths vs cluster lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check thread slice lengths vs vector lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 &&
|
||||
thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Src vectorization: {GN11} is global mem vector dim
|
||||
if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 &&
|
||||
src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1))
|
||||
return false;
|
||||
|
||||
// check Src tensor layout related vectorization
|
||||
if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 &&
|
||||
conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 &&
|
||||
conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 &&
|
||||
conv_problem_desc.InRightPadW == 0)
|
||||
{
|
||||
if(!((Ho * Wo) % src_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 &&
|
||||
conv_problem_desc.InRightPadW == 0)
|
||||
{
|
||||
if(!(Wo % src_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(src_vector_lengths[3] == 1))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Dst vectorization: {GN11, GK1} are LDS vector dims
|
||||
if(dst_vector_lengths[4] == GK1)
|
||||
{ // vectorize on {GN11, GK1}
|
||||
if(!(GN11 % dst_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{ // vectorize on {GK1} only
|
||||
if(!(dst_vector_lengths[3] == 1))
|
||||
return false;
|
||||
|
||||
if(!(GK1 % dst_vector_lengths[4] == 0))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check blockwise GEMM
|
||||
{
|
||||
const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(),
|
||||
compile_param.BM10BN10ThreadClusterBM10Xs.end(),
|
||||
1,
|
||||
std::multiplies<int>{});
|
||||
|
||||
const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(),
|
||||
compile_param.BM10BN10ThreadClusterBN10Xs.end(),
|
||||
1,
|
||||
std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize == BM10 * BN10))
|
||||
return false;
|
||||
|
||||
const int BM = GM0 * GM11;
|
||||
const int BN = GN0 * GN11;
|
||||
|
||||
const int BM1 = BM10 * BM11;
|
||||
const int BN1 = BN10 * BN11;
|
||||
|
||||
if(!(BM % BM1 == 0 && BN % BN1 == 0))
|
||||
return false;
|
||||
|
||||
const int BM0 = BM / BM1;
|
||||
const int BN0 = BN / BN1;
|
||||
|
||||
// blockwise GEMM currently only support BM0 == 2 && BN0 == 2
|
||||
if(!(BM0 == 2 && BN0 == 2))
|
||||
return false;
|
||||
|
||||
if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check C threadwise copy
|
||||
{
|
||||
// {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
|
||||
const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector;
|
||||
|
||||
// check slice length vs Dst vector length:
|
||||
if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0))
|
||||
return false;
|
||||
|
||||
// check Dst memory layout related vectorization:
|
||||
if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
static int GetBlockSize(const ConvolutionProblemDescriptor&,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
return compile_param.BlockSize;
|
||||
}
|
||||
|
||||
static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
const int N = conv_problem_desc.N;
|
||||
const int K = conv_problem_desc.K;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
const int N0 = compile_param.GN0;
|
||||
const int N1 = N / N0;
|
||||
|
||||
const int GM1 = K;
|
||||
const int GN1 = N1 * Ho * Wo;
|
||||
|
||||
const int GM11 = compile_param.GM1PerBlockGM11;
|
||||
const int GN11 = compile_param.GN1PerBlockGN11;
|
||||
|
||||
const int GM10 = GM1 / GM11;
|
||||
const int GN10 = GN1 / GN11;
|
||||
|
||||
return GM10 * GN10;
|
||||
}
|
||||
|
||||
static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&)
|
||||
{
|
||||
// workspace is used for save transformed tensor descritpors created by prepare kernel
|
||||
return 4096L;
|
||||
}
|
||||
|
||||
// Returns a fixed value of 4096; the problem descriptor is ignored.
static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&)
{
    return 4096L;
}
|
||||
|
||||
static auto GetTunableList()
|
||||
{
|
||||
return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,51 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
// Plain aggregate holding the tunable values for the "fwd v4r4 dlops nchw/kcyx/nkhw"
// convolution kernel. It has no constructors, so it is filled by aggregate
// initialization in member order -- do not reorder members.
// NOTE(review): the meaning of each field is only implied by its name here; confirm
// against the kernel that consumes these values.
struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
{
    int BlockSize;

    // per-block sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-thread sizes
    int M1PerThread;
    int N1PerThread;
    int KPerThread;

    // thread cluster
    int M1N1ThreadClusterM10;
    int M1N1ThreadClusterN10;
    int M1N1ThreadClusterM11;
    int M1N1ThreadClusterN11;

    // A block transfer
    std::array<int, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_M1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B block transfer
    std::array<int, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_N1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C thread transfer
    std::array<int, 6> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable values for the dlops nchw/kcyx/nkhw kernel. The initializers are
// positional, so their order must match the member order of
// tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw.
static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
    default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = {
        256,                // BlockSize
        128,                // MPerBlock
        128,                // NPerBlock
        8,                  // KPerBlock
        4,                  // M1PerThread
        4,                  // N1PerThread
        1,                  // KPerThread
        8,                  // M1N1ThreadClusterM10
        8,                  // M1N1ThreadClusterN10
        2,                  // M1N1ThreadClusterM11
        2,                  // M1N1ThreadClusterN11
        {4, 1, 1},          // ABlockTransferThreadSliceLengths_K_M0_M1
        {2, 1, 128},        // ABlockTransferThreadClusterLengths_K_M0_M1
        {2, 1, 0},          // ABlockTransferThreadClusterArrangeOrder
        {2, 1, 0},          // ABlockTransferSrcAccessOrder
        0,                  // ABlockTransferSrcVectorDim
        4,                  // ABlockTransferSrcScalarPerVector
        1,                  // ABlockTransferDstScalarPerVector_M1
        false,              // AThreadTransferSrcResetCoordinateAfterRun
        {4, 1, 1},          // BBlockTransferThreadSliceLengths_K_N0_N1
        {2, 1, 128},        // BBlockTransferThreadClusterLengths_K_N0_N1
        {0, 1, 2},          // BBlockTransferThreadClusterArrangeOrder
        {0, 1, 2},          // BBlockTransferSrcAccessOrder
        2,                  // BBlockTransferSrcVectorDim
        1,                  // BBlockTransferSrcScalarPerVector
        1,                  // BBlockTransferDstScalarPerVector_N1
        false,              // BThreadTransferSrcResetCoordinateAfterRun
        {3, 4, 5, 0, 1, 2}, // CThreadTransferSrcDstAccessOrder
        5,                  // CThreadTransferSrcDstVectorDim
        1                   // CThreadTransferDstScalarPerVector
};
|
||||
#endif
|
||||
@@ -0,0 +1,73 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
// Plain aggregate holding the tunable values for the "fwd v4r4 xdlops nchw/kcyx/nkhw"
// convolution kernel. It has no constructors, so it is filled by aggregate
// initialization in member order -- do not reorder members.
// NOTE(review): the meaning of each field is only implied by its name here; confirm
// against the kernel that consumes these values.
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
    int BlockSize;

    // per-block sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-wave sizes
    int MPerWave;
    int NPerWave;
    int K1;

    int MRepeat;
    int NRepeat;

    // A block transfer
    std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B block transfer
    std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C thread transfer
    std::array<int, 8> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable values for the xdlops nchw/kcyx/nkhw kernel. The initializers are
// positional, so their order must match the member order of
// tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw.
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
        256,                      // BlockSize
        128,                      // MPerBlock,
        128,                      // NPerBlock,
        4,                        // KPerBlock,
        32,                       // MPerWave,
        32,                       // NPerWave,
        4,                        // K1,
        2,                        // MRepeat,
        2,                        // NRepeat,
        {1, 2, 4},                // ABlockTransferThreadSliceLengths_K0_M_K1,
        {4, 64, 1},               // ABlockTransferThreadClusterLengths_K0_M_K1,
        {1, 0, 2},                // ABlockTransferThreadClusterArrangeOrder,
        {1, 0, 2},                // ABlockTransferSrcAccessOrder,
        2,                        // ABlockTransferSrcVectorDim
        1,                        // ABlockTransferSrcScalarPerVector,
        4,                        // ABlockTransferDstScalarPerVector_K1,
        false,                    // AThreadTransferSrcResetCoordinateAfterRun,
        {1, 2, 4},                // BBlockTransferThreadSliceLengths_K0_N_K1,
        {4, 64, 1},               // BBlockTransferThreadClusterLengths_K0_N_K1,
        {0, 2, 1},                // BBlockTransferThreadClusterArrangeOrder,
        {1, 0, 2},                // BBlockTransferSrcAccessOrder,
        1,                        // BBlockTransferSrcVectorDim
        1,                        // BBlockTransferSrcScalarPerVector
        4,                        // BBlockTransferDstScalarPerVector_K1
        false,                    // BThreadTransferSrcResetCoordinateAfterRun
        {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7,                        // CThreadTransferSrcDstVectorDim,
        1                         // CThreadTransferDstScalarPerVector
};
|
||||
#endif
|
||||
@@ -0,0 +1,73 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
// Plain aggregate holding the tunable values for the "fwd v4r4 xdlops nhwc/kyxc/nhwk"
// convolution kernel. It has no constructors, so it is filled by aggregate
// initialization in member order -- do not reorder members.
// NOTE(review): the meaning of each field is only implied by its name here; confirm
// against the kernel that consumes these values.
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
    int BlockSize;

    // per-block sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-wave sizes
    int MPerWave;
    int NPerWave;
    int K1;

    int MRepeat;
    int NRepeat;

    // A block transfer
    std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B block transfer
    std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C thread transfer
    std::array<int, 8> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable values for the xdlops nhwc/kyxc/nhwk kernel. The initializers are
// positional, so their order must match the member order of
// tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
        256,                      // BlockSize
        128,                      // MPerBlock,
        128,                      // NPerBlock,
        4,                        // KPerBlock,
        32,                       // MPerWave,
        32,                       // NPerWave,
        4,                        // K1,
        2,                        // MRepeat,
        2,                        // NRepeat,
        {1, 2, 4},                // ABlockTransferThreadSliceLengths_K0_M_K1,
        {4, 64, 1},               // ABlockTransferThreadClusterLengths_K0_M_K1,
        {1, 0, 2},                // ABlockTransferThreadClusterArrangeOrder,
        {1, 0, 2},                // ABlockTransferSrcAccessOrder,
        2,                        // ABlockTransferSrcVectorDim
        4,                        // ABlockTransferSrcScalarPerVector,
        4,                        // ABlockTransferDstScalarPerVector_K1,
        false,                    // AThreadTransferSrcResetCoordinateAfterRun,
        {1, 2, 4},                // BBlockTransferThreadSliceLengths_K0_N_K1,
        {4, 64, 1},               // BBlockTransferThreadClusterLengths_K0_N_K1,
        {1, 0, 2},                // BBlockTransferThreadClusterArrangeOrder,
        {1, 0, 2},                // BBlockTransferSrcAccessOrder,
        2,                        // BBlockTransferSrcVectorDim
        4,                        // BBlockTransferSrcScalarPerVector
        4,                        // BBlockTransferDstScalarPerVector_K1
        false,                    // BThreadTransferSrcResetCoordinateAfterRun
        {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7,                        // CThreadTransferSrcDstVectorDim,
        1                         // CThreadTransferDstScalarPerVector
};
|
||||
#endif
|
||||
81
host/solver/include/convolution_problem_descriptor.hpp
Normal file
81
host/solver/include/convolution_problem_descriptor.hpp
Normal file
@@ -0,0 +1,81 @@
|
||||
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
#define CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
struct ConvolutionProblemDescriptor
|
||||
{
|
||||
ConvolutionProblemDescriptor() = default;
|
||||
|
||||
ConvolutionProblemDescriptor(int N_,
|
||||
int K_,
|
||||
int C_,
|
||||
int Y_,
|
||||
int X_,
|
||||
int Hi_,
|
||||
int Wi_,
|
||||
int Ho_,
|
||||
int Wo_,
|
||||
int ConvStrideH_,
|
||||
int ConvStrideW_,
|
||||
int ConvDilationH_,
|
||||
int ConvDilationW_,
|
||||
int InLeftPadH_,
|
||||
int InLeftPadW_,
|
||||
int InRightPadH_,
|
||||
int InRightPadW_,
|
||||
ck::DataTypeEnum_t InDataTypeEnum_,
|
||||
ck::DataTypeEnum_t WeiDataTypeEnum_,
|
||||
ck::DataTypeEnum_t OutDataTypeEnum_)
|
||||
: N{N_},
|
||||
K{K_},
|
||||
C{C_},
|
||||
Y{Y_},
|
||||
X{X_},
|
||||
Hi{Hi_},
|
||||
Wi{Wi_},
|
||||
Ho{Ho_},
|
||||
Wo{Wo_},
|
||||
ConvStrideH{ConvStrideH_},
|
||||
ConvStrideW{ConvStrideW_},
|
||||
ConvDilationH{ConvDilationH_},
|
||||
ConvDilationW{ConvDilationW_},
|
||||
InLeftPadH{InLeftPadH_},
|
||||
InLeftPadW{InLeftPadW_},
|
||||
InRightPadH{InRightPadH_},
|
||||
InRightPadW{InRightPadW_},
|
||||
InDataTypeEnum{InDataTypeEnum_},
|
||||
WeiDataTypeEnum{WeiDataTypeEnum_},
|
||||
OutDataTypeEnum{OutDataTypeEnum_}
|
||||
{
|
||||
}
|
||||
|
||||
int N;
|
||||
int K;
|
||||
int C;
|
||||
int Y;
|
||||
int X;
|
||||
int Hi;
|
||||
int Wi;
|
||||
int Ho;
|
||||
int Wo;
|
||||
int ConvStrideH;
|
||||
int ConvStrideW;
|
||||
int ConvDilationH;
|
||||
int ConvDilationW;
|
||||
int InLeftPadH;
|
||||
int InLeftPadW;
|
||||
int InRightPadH;
|
||||
int InRightPadW;
|
||||
|
||||
ck::DataTypeEnum_t InDataTypeEnum;
|
||||
ck::DataTypeEnum_t WeiDataTypeEnum;
|
||||
ck::DataTypeEnum_t OutDataTypeEnum;
|
||||
|
||||
std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
|
||||
};
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
46
host/solver/include/solver_common.hpp
Normal file
46
host/solver/include/solver_common.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef CK_SOLVER_COMMON_HPP
|
||||
#define CK_SOLVER_COMMON_HPP
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
// Greatest common divisor (aka highest common factor) of x and y.
//
// The sign of either argument is ignored, so the result is never negative for
// representable values. gcd(0, y) == |y|, gcd(x, 0) == |x| and gcd(0, 0) == 0.
// The magnitudes are taken in long long, so negating the most negative int cannot
// overflow; only when the mathematical result itself exceeds the int range (both
// arguments being the most negative int or zero) is the returned value unspecified
// by this function.
inline int gcd(int x, int y)
{
    long long a = x;
    long long b = y;

    if(a < 0)
        a = -a;

    if(b < 0)
        b = -b;

    // iterative Euclidean algorithm
    while(b != 0)
    {
        const long long r = a % b;
        a                 = b;
        b                 = r;
    }

    return static_cast<int>(a);
}
|
||||
|
||||
// Greatest common divisor of three or more values.
//
// Participates in overload resolution only when at least two values follow `x`, i.e.
// for calls with three or more arguments. It folds from the right:
// gcd(a, b, c) == gcd(a, gcd(b, c)), bottoming out in a two-argument gcd overload.
template <typename X,
          typename... Ys,
          typename std::enable_if<(sizeof...(Ys) >= 2), bool>::type = false>
auto gcd(X x, Ys... ys)
{
    return gcd(x, gcd(ys...));
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user