mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Reorganize files, Part 1 (#119)
* delete obselete files
* move files
* build
* update cmake
* update cmake
* fix build
* reorg examples
* update cmake for example and test
[ROCm/composable_kernel commit: 5d37d7bff4]
This commit is contained in:
@@ -45,7 +45,6 @@ message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
|
||||
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
|
||||
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
|
||||
|
||||
# set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
link_libraries(${OpenMP_gomp_LIBRARY})
|
||||
link_libraries(${OpenMP_pthread_LIBRARY})
|
||||
|
||||
@@ -71,17 +70,17 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH )
|
||||
endif()
|
||||
message(STATUS "Build with HIP ${HIP_VERSION}")
|
||||
|
||||
## half
|
||||
#find_path(HALF_INCLUDE_DIR half.hpp)
|
||||
set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/half/include")
|
||||
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||
|
||||
|
||||
rocm_create_package(
|
||||
NAME CK-${CK_BACKEND}
|
||||
DESCRIPTION "High Performance Composable Kernels for AMD GPUs"
|
||||
DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
|
||||
LDCONFIG
|
||||
)
|
||||
|
||||
## half
|
||||
set(HALF_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/external/include/half")
|
||||
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||
|
||||
## tidy
|
||||
include(EnableCompilerWarnings)
|
||||
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
|
||||
@@ -184,7 +183,6 @@ enable_clang_tidy(
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
|
||||
${CK_TIDY_CHECKS}
|
||||
${CK_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
@@ -214,69 +212,36 @@ enable_cppcheck(
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
host/host_tensor/src
|
||||
host/driver_offline/src
|
||||
composable_kernel/src/kernel_wrapper
|
||||
library/src
|
||||
INCLUDE
|
||||
host/host_tensor/include
|
||||
host/device/include
|
||||
host/solver/include
|
||||
host/driver_offline/include
|
||||
composable_kernel/include/*
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/library/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
|
||||
|
||||
file(GLOB_RECURSE COMPOSABLE_KERNEL_HEADERS "composable_kernel/include/*/*.hpp")
|
||||
file(GLOB_RECURSE DEVICE_OPS_HEADERS "device_operation/include/*.hpp")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/include/ck/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/hip_version.hpp")
|
||||
|
||||
file(GLOB_RECURSE DEVICE_OPS_SOURCE "device_operation/*.cpp")
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
)
|
||||
|
||||
set(CK_HEADERS ${COMPOSABLE_KERNEL_HEADERS} ${DEVICE_OPS_HEADERS})
|
||||
set(CK_SOURCE ${DEVICE_OPS_SOURCE})
|
||||
add_library(composable_kernel ${CK_SOURCE})
|
||||
|
||||
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/composable_kernel/include>
|
||||
)
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/device_operation/include>
|
||||
)
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/host/include>
|
||||
)
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/host/host_tensor/include>
|
||||
)
|
||||
# The following should eventually be removed
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/composable_kernel/include/utility>
|
||||
)
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation>
|
||||
)
|
||||
target_include_directories(composable_kernel PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description>
|
||||
)
|
||||
# clang_tidy_check(composable_kernel)
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
target_compile_options(composable_kernel PRIVATE -Werror)
|
||||
target_compile_options(composable_kernel PRIVATE -Weverything)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/hip_version.hpp")
|
||||
|
||||
add_subdirectory(host)
|
||||
add_subdirectory(device_operation)
|
||||
add_subdirectory(library)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(profiler)
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
|
||||
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
|
||||
|
||||
template <typename GridwiseOp, typename... Xs>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
run_gridwise_operation(Xs... xs)
|
||||
{
|
||||
GridwiseOp{}.Run(xs...);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,369 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
|
||||
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
|
||||
constexpr index_t KPerThread = CK_PARAM_KPerThread;
|
||||
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
|
||||
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
|
||||
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
|
||||
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K_M0_M1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
|
||||
using ABlockTransferThreadClusterLengths_K_M0_M1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K_N0_N1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
|
||||
using BBlockTransferThreadClusterLengths_K_N0_N1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k_m0_m1_grid_desc,
|
||||
void* p_b_k_n0_n1_grid_desc,
|
||||
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW));
|
||||
|
||||
const auto a_k_m_grid_desc = descs[I0];
|
||||
const auto b_k_n_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||
AKMGridDesc,
|
||||
BKNGridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
|
||||
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
|
||||
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
};
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
||||
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1));
|
||||
|
||||
constexpr auto a_k_m_grid_desc = descs[I0];
|
||||
constexpr auto b_k_n_grid_desc = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||
AKMGridDesc,
|
||||
BKNGridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
||||
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
constexpr auto b_k_n0_n1_grid_desc_tmp =
|
||||
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
|
||||
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k_m0_m1_grid_desc =
|
||||
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
|
||||
const auto b_k_n0_n1_grid_desc =
|
||||
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
cblockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
@@ -1,357 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
cblockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
@@ -1,356 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
|
||||
int n,
|
||||
int hi,
|
||||
int wi,
|
||||
int c,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(cblockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_cblockid_to_m0_n0_block_cluster_adaptor) = cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||
constexpr auto wei_k_y_x_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto cblockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(cblockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_cblockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
cblockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
@@ -1,400 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
|
||||
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
|
||||
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
|
||||
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
|
||||
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
|
||||
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
|
||||
|
||||
extern "C" __global__ void
|
||||
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
|
||||
int C_,
|
||||
int Hi_,
|
||||
int Wi_,
|
||||
int K_,
|
||||
int Y_,
|
||||
int X_,
|
||||
int ConvStrideH_,
|
||||
int ConvStrideW_,
|
||||
int ConvDilationH_,
|
||||
int ConvDilationW_,
|
||||
int InLeftPadH_,
|
||||
int InLeftPadW_,
|
||||
int InRightPadH_,
|
||||
int InRightPadW_,
|
||||
void* p_desc_tuple)
|
||||
{
|
||||
index_t N = static_cast<index_t>(N_);
|
||||
index_t C = static_cast<index_t>(C_);
|
||||
index_t Hi = static_cast<index_t>(Hi_);
|
||||
index_t Wi = static_cast<index_t>(Wi_);
|
||||
index_t K = static_cast<index_t>(K_);
|
||||
index_t Y = static_cast<index_t>(Y_);
|
||||
index_t X = static_cast<index_t>(X_);
|
||||
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
|
||||
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
|
||||
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
|
||||
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
|
||||
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
|
||||
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
|
||||
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
|
||||
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t Ho =
|
||||
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
|
||||
const index_t Wo =
|
||||
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(ConvStrideH, ConvStrideW),
|
||||
make_tuple(ConvDilationH, ConvDilationW),
|
||||
make_tuple(InLeftPadH, InLeftPadW),
|
||||
make_tuple(InRightPadH, InRightPadW),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
auto desc_tuple =
|
||||
make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1),
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1),
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
|
||||
*static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_desc_tuple)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
|
||||
using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
|
||||
BGridDesc_GK0_GN0_GN10_GN11_GK1{},
|
||||
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
|
||||
CGridBlockCluster_BlockId_To_GM10_GN10{}));
|
||||
|
||||
const auto desc_tuple =
|
||||
*reinterpret_cast<const DescTuple*>(cast_pointer_to_generic_address_space(p_desc_tuple));
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
@@ -1,204 +0,0 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/device/include
|
||||
${PROJECT_SOURCE_DIR}/device_operation/include
|
||||
${PROJECT_SOURCE_DIR}/profiler/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
# device_gemm_instance
|
||||
set(DEVICE_GEMM_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
)
|
||||
|
||||
# device_gemm_bias_2d_instance
|
||||
set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
)
|
||||
|
||||
# device_gemm_bias_relu_instance
|
||||
set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
)
|
||||
|
||||
# device_gemm_bias_relu_add_instance
|
||||
set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp;
|
||||
)
|
||||
|
||||
set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_fwd_instance
|
||||
set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv1d_fwd_instance
|
||||
set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_fwd_bias_relu_instance
|
||||
set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_fwd_bias_relu_add_instance
|
||||
set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_fwd_bias_relu_atomic_add_instance
|
||||
set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
|
||||
# device_conv2d_bwd_data_instance
|
||||
set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
)
|
||||
|
||||
# device_reduce_instance
|
||||
set(DEVICE_REDUCE_INSTANCE_SOURCE
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f16_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f32_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f64_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f64_f64_f64.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f16_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f32_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f64_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f64_f64_f64.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp;
|
||||
${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp;
|
||||
)
|
||||
|
||||
add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE})
|
||||
add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE})
|
||||
add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE})
|
||||
add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE})
|
||||
|
||||
target_include_directories(device_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
target_include_directories(device_reduce_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||
|
||||
target_compile_features(device_gemm_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_2d_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_relu_instance PUBLIC)
|
||||
target_compile_features(device_gemm_bias_relu_add_instance PUBLIC)
|
||||
target_compile_features(device_batched_gemm_instance PUBLIC)
|
||||
target_compile_features(device_conv1d_fwd_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC)
|
||||
target_compile_features(device_conv2d_bwd_data_instance PUBLIC)
|
||||
target_compile_features(device_reduce_instance PUBLIC)
|
||||
|
||||
set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib)
|
||||
install(TARGETS device_reduce_instance LIBRARY DESTINATION lib)
|
||||
3
example/01_gemm/CMakeLists.txt
Normal file
3
example/01_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
1
example/02_gemm_alpha_beta/CMakeLists.txt
Normal file
1
example/02_gemm_alpha_beta/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_gemm_xdl_alpha_beta gemm_xdl_alpha_beta.cpp)
|
||||
1
example/03_gemm_bias_relu/CMakeLists.txt
Normal file
1
example/03_gemm_bias_relu/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_gemm_xdl_bias_relu gemm_xdl_bias_relu.cpp)
|
||||
1
example/04_gemm_bias_relu_add/CMakeLists.txt
Normal file
1
example/04_gemm_bias_relu_add/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp)
|
||||
2
example/05_conv2d_fwd/CMakeLists.txt
Normal file
2
example/05_conv2d_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_example_executable(example_conv2d_fwd_xdl_fp16 conv2d_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_conv2d_fwd_xdl_int8 conv2d_fwd_xdl_int8.cpp)
|
||||
1
example/06_conv2d_fwd_bias_relu/CMakeLists.txt
Normal file
1
example/06_conv2d_fwd_bias_relu/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
|
||||
1
example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt
Normal file
1
example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
|
||||
1
example/08_conv3d_fwd/CMakeLists.txt
Normal file
1
example/08_conv3d_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_conv3d_fwd_xdl conv3d_fwd_xdl.cpp)
|
||||
1
example/09_convnd_fwd/CMakeLists.txt
Normal file
1
example/09_convnd_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
|
||||
@@ -2,7 +2,6 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_utils.hpp"
|
||||
#include "device.hpp"
|
||||
1
example/10_conv2d_bwd_data/CMakeLists.txt
Normal file
1
example/10_conv2d_bwd_data/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
|
||||
1
example/11_conv2d_bwd_wgt/CMakeLists.txt
Normal file
1
example/11_conv2d_bwd_wgt/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_conv2d_bwd_wgt_xdl conv2d_bwd_wgt_xdl.cpp)
|
||||
1
example/12_reduce/CMakeLists.txt
Normal file
1
example/12_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "device_reduce_blockwise.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "host_generic_reduction.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
|
||||
1
example/13_pool2d_fwd/CMakeLists.txt
Normal file
1
example/13_pool2d_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_pool2d_fwd pool2d_fwd.cpp)
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "device_operation/include/device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
#include "device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
@@ -1,61 +0,0 @@
|
||||
# Instructions for ```conv_xdl_bias_relu_add``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```conv_xdl_bias_relu_add```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j conv_xdl_bias_relu_add
|
||||
```
|
||||
|
||||
## Run ```conv_xdl_bias_relu_add```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
|
||||
./example/conv_xdl_bias_relu_add 0 1 5
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
|
||||
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
bias_k: dim 1, lengths {256}, strides {1}
|
||||
resi_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
arg.a_grid_desc_k0_m_k1_{216, 165888, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{216, 256, 8}
|
||||
arg.c_grid_desc_m_n_{ 165888, 256}
|
||||
arg.c0_grid_desc_m_n_{ 165888, 256}
|
||||
arg.c1_grid_desc_m_n_{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 1.71779 ms, 85.4396 TFlops, 194.2 GB/s
|
||||
```
|
||||
@@ -1,314 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "convolution_utility.hpp"
|
||||
|
||||
// Element types for this example: fp16 input/weight/output, fp32 accumulation.
using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;

// Shorthand for the compile-time integer sequences used as tuning parameters.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// Channels-last tensor layouts for input, weight and output.
using InLayout  = ck::tensor_layout::convolution::NHWC;
using WeiLayout = ck::tensor_layout::convolution::KYXC;
using OutLayout = ck::tensor_layout::convolution::NHWK;

// Elementwise ops fused into the kernel: input/weight pass through unchanged,
// the output stage fuses bias-add + ReLU (AddRelu).
using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;

// The kernel writes its output with AtomicAdd, i.e. it ACCUMULATES onto
// whatever data the output buffer already holds (see main(): the buffer is
// seeded with initialized values before the run).
static constexpr auto MemoryAtomicAdd = ck::InMemoryDataOperationEnum_t::AtomicAdd;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default;

// Concrete device-op instantiation; the column table below documents each of
// the template tuning parameters on the final line.
// clang-format off
using DeviceConvFwdInstance = ck::tensor_operation::device::
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
// clang-format off
// | InData| WeiData| OutData| AccData| In| Wei| Out| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| GlobalMemory| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
// | | | | | Operation| Operation| Operation| DataOperation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, MemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1,32>, 2>;
// clang-format on
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
void host_reference_calculation(const Tensor<TIn>& in_n_c_hi_wi,
|
||||
const Tensor<TWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
const Tensor<TOut>& bias_k,
|
||||
const std::vector<ck::index_t>& conv_strides,
|
||||
const std::vector<ck::index_t>& conv_dilations,
|
||||
const std::vector<ck::index_t>& in_left_pads,
|
||||
const std::vector<ck::index_t>& /* in_right_pads */,
|
||||
const InElementOp& in_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const OutElementOp& out_element_op)
|
||||
{
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0];
|
||||
for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1];
|
||||
if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_n_c_hi_wi.mDesc.GetLengths()[3])
|
||||
{
|
||||
float v_in;
|
||||
float v_wei;
|
||||
|
||||
in_element_op(v_in, static_cast<const float>(in_n_c_hi_wi(n, c, hi, wi)));
|
||||
wei_element_op(v_wei, static_cast<const float>(wei_k_c_y_x(k, c, y, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float v_out;
|
||||
|
||||
out_element_op(v_out, v_acc, static_cast<float>(bias_k(k)));
|
||||
|
||||
out_n_k_ho_wo(n, k, ho, wo) += v_out;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out_n_k_ho_wo.mDesc.GetLengths()[0],
|
||||
out_n_k_ho_wo.mDesc.GetLengths()[1],
|
||||
out_n_k_ho_wo.mDesc.GetLengths()[2],
|
||||
out_n_k_ho_wo.mDesc.GetLengths()[3])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
|
||||
// Conv shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t K = 256;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t conv_stride_h = 2;
|
||||
ck::index_t conv_stride_w = 2;
|
||||
ck::index_t conv_dilation_h = 1;
|
||||
ck::index_t conv_dilation_w = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 19)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
C = std::stoi(argv[6]);
|
||||
Y = std::stoi(argv[7]);
|
||||
X = std::stoi(argv[8]);
|
||||
Hi = std::stoi(argv[9]);
|
||||
Wi = std::stoi(argv[10]);
|
||||
conv_stride_h = std::stoi(argv[11]);
|
||||
conv_stride_w = std::stoi(argv[12]);
|
||||
conv_dilation_h = std::stoi(argv[13]);
|
||||
conv_dilation_w = std::stoi(argv[14]);
|
||||
in_left_pad_h = std::stoi(argv[15]);
|
||||
in_left_pad_w = std::stoi(argv[16]);
|
||||
in_right_pad_h = std::stoi(argv[17]);
|
||||
in_right_pad_w = std::stoi(argv[18]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const std::vector<ck::index_t> conv_filter_strides{conv_stride_h, conv_stride_w};
|
||||
const std::vector<ck::index_t> conv_filter_dilations{conv_dilation_h, conv_dilation_w};
|
||||
const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
|
||||
const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
|
||||
const auto output_spatial_lengths =
|
||||
ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths({Hi, Wi},
|
||||
{Y, X},
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
const ck::index_t Ho = output_spatial_lengths[0];
|
||||
const ck::index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
// tensor layout
|
||||
auto f_host_tensor_descriptor = [](std::size_t N_,
|
||||
std::size_t C_,
|
||||
std::size_t H,
|
||||
std::size_t W,
|
||||
auto layout) {
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value ||
|
||||
ck::is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
ck::is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NHWC>::value ||
|
||||
ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::KYXC>::value ||
|
||||
ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
|
||||
Tensor<WeiDataType> wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{}));
|
||||
Tensor<OutDataType> out_n_k_ho_wo_host_result(
|
||||
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
|
||||
Tensor<OutDataType> out_n_k_ho_wo_device_result(
|
||||
f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{}));
|
||||
|
||||
// bias: assume contiguous 1d vector
|
||||
Tensor<OutDataType> bias_k(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(K)})));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl;
|
||||
std::cout << "bias_k: " << bias_k.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
out_n_k_ho_wo_host_result.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
bias_k.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
out_n_k_ho_wo_host_result.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
|
||||
bias_k.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) *
|
||||
out_n_k_ho_wo_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo_host_result.mData.data());
|
||||
bias_device_buf.ToDevice(bias_k.mData.data());
|
||||
|
||||
auto conv = DeviceConvFwdInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument =
|
||||
conv.MakeArgument(static_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
std::vector<ck::index_t>{Hi, Wi},
|
||||
std::vector<ck::index_t>{Y, X},
|
||||
std::vector<ck::index_t>{Ho, Wo},
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device operator with the specified compilation parameters does "
|
||||
"not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
|
||||
std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) +
|
||||
sizeof(WeiDataType) * (K * C * Y * X) +
|
||||
sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
host_reference_calculation(in_n_c_hi_wi,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo_host_result,
|
||||
bias_k,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data());
|
||||
|
||||
check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result);
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
# Instructions for ```conv2d_fwd_xdl_int8``` Example
|
||||
|
||||
## Docker script
|
||||
```bash
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
## Build ```conv2d_fwd_xdl_int8```
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
```bash
|
||||
# Need to specify target ID, example below is gfx908
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
```
|
||||
|
||||
```bash
|
||||
make -j conv2d_fwd_xdl_int8
|
||||
```
|
||||
|
||||
## Run ```conv2d_fwd_xdl_int8```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
|
||||
./example/conv2d_fwd_xdl_int8 0 1 5
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
|
||||
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
arg.a_grid_desc_k0_m_k1_{216, 165888, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{216, 256, 8}
|
||||
arg.c_grid_desc_m_n_{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 1.43206 ms, 102.486 TFlops, 232.947 GB/s
|
||||
```
|
||||
@@ -1,69 +1,40 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/host/device/include
|
||||
${PROJECT_SOURCE_DIR}/device_operation/include
|
||||
${PROJECT_SOURCE_DIR}/reference_operation/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
${PROJECT_SOURCE_DIR}/device_operation_reference/include
|
||||
${PROJECT_SOURCE_DIR}/include/ck
|
||||
${PROJECT_SOURCE_DIR}/include/ck/utility
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor
|
||||
${PROJECT_SOURCE_DIR}/include/ck/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/device
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/grid
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/block
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/warp
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
|
||||
${PROJECT_SOURCE_DIR}/external/include/half
|
||||
)
|
||||
|
||||
set(GEMM_XDL_SOURCE 1_gemm_xdl/gemm_xdl.cpp)
|
||||
set(GEMM_XDL_INT8_SOURCE 1_gemm_xdl/gemm_xdl_int8.cpp)
|
||||
set(GEMM_XDL_BF16_SOURCE 1_gemm_xdl/gemm_xdl_bf16.cpp)
|
||||
set(GEMM_XDL_BIAS_RELU_SOURCE 2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp)
|
||||
set(GEMM_XDL_BIAS_RELU_ADD_SOURCE 3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp)
|
||||
set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp)
|
||||
set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp)
|
||||
set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp)
|
||||
set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp)
|
||||
set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp)
|
||||
set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp)
|
||||
set(CONV2D_WRW_XDL_SOURCE 13_conv2d_backward_weight_xdl/main.cpp)
|
||||
set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp)
|
||||
set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp)
|
||||
set(CONV2D_BWD_DATA_XDL_SOURCE 12_conv2d_bwd_data_xdl/conv2d_bwd_data_xdl.cpp)
|
||||
set(POOL2D_FWD_SOURCE 12_pool2d_fwd/pool2d_fwd.cpp)
|
||||
set(REDUCE_BLOCKWISE_SOURCE 13_reduce_blockwise/reduce_blockwise.cpp)
|
||||
add_custom_target(examples)
|
||||
|
||||
add_executable(gemm_xdl ${GEMM_XDL_SOURCE})
|
||||
add_executable(gemm_xdl_int8 ${GEMM_XDL_INT8_SOURCE})
|
||||
add_executable(gemm_xdl_bf16 ${GEMM_XDL_BF16_SOURCE})
|
||||
add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE})
|
||||
add_executable(gemm_xdl_bias_relu_add ${GEMM_XDL_BIAS_RELU_ADD_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl ${CONV2D_FWD_XDL_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl_bias_relu ${CONV2D_FWD_XDL_BIAS_RELU_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE})
|
||||
add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE})
|
||||
add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE})
|
||||
add_executable(conv2d_wrw_xdl ${CONV2D_WRW_XDL_SOURCE})
|
||||
add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE})
|
||||
add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE})
|
||||
add_executable(conv2d_bwd_data_xdl ${CONV2D_BWD_DATA_XDL_SOURCE})
|
||||
add_executable(pool2d_fwd ${POOL2D_FWD_SOURCE})
|
||||
add_executable(reduce_blockwise ${REDUCE_BLOCKWISE_SOURCE})
|
||||
|
||||
target_link_libraries(gemm_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_int8 PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_bf16 PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_bias_relu_add PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl_bias_relu PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor)
|
||||
target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_wrw_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(conv2d_bwd_data_xdl PRIVATE host_tensor)
|
||||
target_link_libraries(pool2d_fwd PRIVATE host_tensor)
|
||||
target_link_libraries(reduce_blockwise PRIVATE host_tensor)
|
||||
# add_example_executable(<EXAMPLE_NAME> <source>...)
#
# Registers one example binary: builds EXAMPLE_NAME from the given sources,
# links it against the project's host_tensor library, and attaches it to the
# aggregate `examples` target so building `examples` builds every example.
function(add_example_executable EXAMPLE_NAME)
    message("adding example ${EXAMPLE_NAME}")
    add_executable(${EXAMPLE_NAME} ${ARGN})
    target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
    add_dependencies(examples ${EXAMPLE_NAME})
endfunction() # bare form; repeating the arguments on endfunction() is legacy noise
|
||||
|
||||
add_subdirectory(01_gemm)
|
||||
add_subdirectory(02_gemm_alpha_beta)
|
||||
add_subdirectory(03_gemm_bias_relu)
|
||||
add_subdirectory(04_gemm_bias_relu_add)
|
||||
add_subdirectory(05_conv2d_fwd)
|
||||
add_subdirectory(06_conv2d_fwd_bias_relu)
|
||||
add_subdirectory(07_conv2d_fwd_bias_relu_add)
|
||||
add_subdirectory(08_conv3d_fwd)
|
||||
add_subdirectory(09_convnd_fwd)
|
||||
add_subdirectory(10_conv2d_bwd_data)
|
||||
add_subdirectory(11_conv2d_bwd_wgt)
|
||||
add_subdirectory(12_reduce)
|
||||
add_subdirectory(13_pool2d_fwd)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
add_subdirectory(host_tensor)
|
||||
@@ -1,689 +0,0 @@
|
||||
#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
auto GetCompileParameterString() const
|
||||
{
|
||||
auto param = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
param <<
|
||||
" -DCK_PARAM_ABDataTypeEnum=" <<
|
||||
ABDataTypeEnum <<
|
||||
" -DCK_PARAM_AccDataTypeEnum=" <<
|
||||
AccDataTypeEnum <<
|
||||
" -DCK_PARAM_CDataTypeEnum=" <<
|
||||
CDataTypeEnum <<
|
||||
" -DCK_PARAM_BlockSize=" <<
|
||||
BlockSize <<
|
||||
" -DCK_PARAM_GN0=" <<
|
||||
GN0 <<
|
||||
" -DCK_PARAM_GK1=" <<
|
||||
GK1 <<
|
||||
" -DCK_PARAM_GM1PerBlockGM11="
|
||||
<< GM1PerBlockGM11 <<
|
||||
" -DCK_PARAM_GN1PerBlockGN11=" <<
|
||||
GN1PerBlockGN11 <<
|
||||
" -DCK_PARAM_GK0PerBlock=" <<
|
||||
GK0PerBlock <<
|
||||
" -DCK_PARAM_BM1PerThreadBM11=" <<
|
||||
BM1PerThreadBM11 <<
|
||||
" -DCK_PARAM_BN1PerThreadBN11=" <<
|
||||
BN1PerThreadBN11 <<
|
||||
" -DCK_PARAM_BK0PerThread=" <<
|
||||
BK0PerThread <<
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" <<
|
||||
BM10BN10ThreadClusterBM10Xs[0] << "," <<
|
||||
BM10BN10ThreadClusterBM10Xs[1] <<
|
||||
" -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" <<
|
||||
BM10BN10ThreadClusterBN10Xs[0] << "," <<
|
||||
BM10BN10ThreadClusterBN10Xs[1] <<
|
||||
" -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," <<
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," <<
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] <<
|
||||
" -DCK_PARAM_CThreadTransferDstScalarPerVector=" <<
|
||||
CThreadTransferDstScalarPerVector <<
|
||||
" -DCK_PARAM_HasMainKBlockLoop=" <<
|
||||
static_cast<int>(HasMainKBlockLoop) <<
|
||||
" -DCK_PARAM_HasDoubleTailKBlockLoop=" <<
|
||||
static_cast<int>(HasDoubleTailKBlockLoop);
|
||||
// clang-format on
|
||||
|
||||
return param.str();
|
||||
}
|
||||
|
||||
ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown;
|
||||
|
||||
int BlockSize = -1;
|
||||
|
||||
int GN0 = -1;
|
||||
int GK1 = -1;
|
||||
|
||||
int GM1PerBlockGM11 = -1;
|
||||
int GN1PerBlockGN11 = -1;
|
||||
int GK0PerBlock = -1;
|
||||
|
||||
int BM1PerThreadBM11 = -1;
|
||||
int BN1PerThreadBN11 = -1;
|
||||
int BK0PerThread = -1;
|
||||
|
||||
std::array<int, 2> BM10BN10ThreadClusterBM10Xs = {-1, -1};
|
||||
std::array<int, 2> BM10BN10ThreadClusterBN10Xs = {-1, -1};
|
||||
|
||||
std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
|
||||
std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
|
||||
-1, -1, -1, -1, -1};
|
||||
|
||||
int CThreadTransferDstScalarPerVector = -1;
|
||||
|
||||
bool HasMainKBlockLoop = false;
|
||||
bool HasDoubleTailKBlockLoop = false;
|
||||
};
|
||||
|
||||
// One tuning-parameter record for the v6r1 DLOPS forward-conv kernel.
//
// NOTE: instances are created by aggregate (brace) initialization in
// generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(), so the
// member order below is part of the interface -- do not reorder.
// Field meanings mirror the identically-named fields of
// CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw above (which adds
// sentinel defaults and the Acc type).
struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw
{
    // element types for the A/B operands and the C (output) tensor
    ck::DataTypeEnum_t ABDataTypeEnum;
    ck::DataTypeEnum_t CDataTypeEnum;

    // presumably the workgroup size -- confirm against the kernel source
    int BlockSize;

    int GN0;
    int GK1;

    // per-block tile extents (per the naming; verify against the kernel)
    int GM1PerBlockGM11;
    int GN1PerBlockGN11;
    int GK0PerBlock;

    // per-thread tile extents (per the naming; verify against the kernel)
    int BM1PerThreadBM11;
    int BN1PerThreadBN11;
    int BK0PerThread;

    std::array<int, 2> BM10BN10ThreadClusterBM10Xs;
    std::array<int, 2> BM10BN10ThreadClusterBN10Xs;

    // A-operand block-transfer descriptors over (GK0, GM0, GM10, GM11, GK1)
    std::array<int, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
    std::array<int, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
    std::array<int, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
    std::array<int, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;

    // B-operand block-transfer descriptors over (GK0, GN0, GN10, GN11, GK1)
    std::array<int, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
    std::array<int, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
    std::array<int, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
    std::array<int, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
};
||||
|
||||
inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()
|
||||
{
|
||||
constexpr auto f32 = ck::DataTypeEnum_t::Float;
|
||||
constexpr auto f16 = ck::DataTypeEnum_t::Half;
|
||||
constexpr auto i8 = ck::DataTypeEnum_t::Int8;
|
||||
|
||||
return std::vector<TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw>{
|
||||
// clang-format off
|
||||
// fp32
|
||||
{f32, f32, 256, 1, 1, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
// fp16
|
||||
{f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
// i8
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}},
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}},
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}},
|
||||
|
||||
{ i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{ i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
{ i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{ i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}},
|
||||
|
||||
{ i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}
|
||||
// clang-format on
|
||||
};
|
||||
}
|
||||
|
||||
// TODO make this common interface and write specs for it
|
||||
struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
static auto
|
||||
CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable)
|
||||
{
|
||||
const int C = conv_problem_desc.C;
|
||||
const int Y = conv_problem_desc.Y;
|
||||
const int X = conv_problem_desc.X;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum &&
|
||||
conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum &&
|
||||
conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum;
|
||||
const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum;
|
||||
|
||||
DataTypeEnum_t AccDataTypeEnum;
|
||||
|
||||
if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half)
|
||||
{
|
||||
AccDataTypeEnum = DataTypeEnum_t::Float;
|
||||
}
|
||||
else if(ABDataTypeEnum == DataTypeEnum_t::Int8)
|
||||
{
|
||||
AccDataTypeEnum = DataTypeEnum_t::Int32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
}
|
||||
|
||||
const int BlockSize = tunable.BlockSize;
|
||||
|
||||
const int GN0 = tunable.GN0;
|
||||
const int GK1 = tunable.GK1;
|
||||
|
||||
const int GM11 = tunable.GM1PerBlockGM11;
|
||||
const int GN11 = tunable.GN1PerBlockGN11;
|
||||
const int GK0PerBlock = tunable.GK0PerBlock;
|
||||
|
||||
const int BM11 = tunable.BM1PerThreadBM11;
|
||||
const int BN11 = tunable.BN1PerThreadBN11;
|
||||
const int BK0PerThread = tunable.BK0PerThread;
|
||||
|
||||
const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs;
|
||||
const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs;
|
||||
|
||||
const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
// C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
|
||||
const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo);
|
||||
|
||||
const int C0 = GK1;
|
||||
|
||||
if(!(C % C0 == 0))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const int C1 = C / C0;
|
||||
|
||||
const int GK0 = C1 * Y * X;
|
||||
|
||||
if(!(GK0 % GK0PerBlock == 0))
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
|
||||
const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1);
|
||||
|
||||
const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0);
|
||||
|
||||
return std::make_tuple(
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{
|
||||
ABDataTypeEnum,
|
||||
AccDataTypeEnum,
|
||||
CDataTypeEnum,
|
||||
BlockSize,
|
||||
GN0,
|
||||
GK1,
|
||||
GM11,
|
||||
GN11,
|
||||
GK0PerBlock,
|
||||
BM11,
|
||||
BN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
HasMainKBlockLoop,
|
||||
HasDoubleTailKBlockLoop},
|
||||
true);
|
||||
}
|
||||
|
||||
static auto GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc)
|
||||
{
|
||||
for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw())
|
||||
{
|
||||
CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{};
|
||||
bool found = false;
|
||||
|
||||
std::tie(compile_param, found) =
|
||||
CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable);
|
||||
|
||||
if(found && IsValidCompileParameter(conv_problem_desc, compile_param))
|
||||
return std::make_tuple(compile_param, true);
|
||||
}
|
||||
|
||||
return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false);
|
||||
}
|
||||
|
||||
static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc)
|
||||
{
|
||||
bool found = false;
|
||||
|
||||
std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc);
|
||||
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool
|
||||
IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
const int N = conv_problem_desc.N;
|
||||
const int K = conv_problem_desc.K;
|
||||
const int C = conv_problem_desc.C;
|
||||
const int Y = conv_problem_desc.Y;
|
||||
const int X = conv_problem_desc.X;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
const int GK1 = compile_param.GK1;
|
||||
const int GN0 = compile_param.GN0;
|
||||
const int GM11 = compile_param.GM1PerBlockGM11;
|
||||
const int GN11 = compile_param.GN1PerBlockGN11;
|
||||
|
||||
const int BM11 = compile_param.BM1PerThreadBM11;
|
||||
const int BN11 = compile_param.BN1PerThreadBN11;
|
||||
|
||||
const int C0 = GK1;
|
||||
const int N0 = GN0;
|
||||
|
||||
if(!(C % C0 == 0))
|
||||
return false;
|
||||
|
||||
const int C1 = C / C0;
|
||||
|
||||
if(!(N % N0 == 0))
|
||||
return false;
|
||||
|
||||
const int N1 = N / N0;
|
||||
|
||||
const int GM0 = 1;
|
||||
const int GM1 = K;
|
||||
const int GN1 = N1 * Ho * Wo;
|
||||
const int GK0 = C1 * Y * X;
|
||||
|
||||
// check data type
|
||||
{
|
||||
if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum &&
|
||||
conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum))
|
||||
return false;
|
||||
|
||||
if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float ||
|
||||
compile_param.ABDataTypeEnum == DataTypeEnum_t::Half)
|
||||
{
|
||||
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float))
|
||||
return false;
|
||||
}
|
||||
else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8)
|
||||
{
|
||||
if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check gridwise contraction
|
||||
{
|
||||
if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0))
|
||||
return false;
|
||||
|
||||
const bool has_main_k_block_loop =
|
||||
((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1);
|
||||
|
||||
const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0);
|
||||
|
||||
if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop &&
|
||||
has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check A blockwise copy
|
||||
{
|
||||
const auto block_slice_lengths =
|
||||
std::array<int, 5>{compile_param.GK0PerBlock, GM0, 1, GM11, GK1};
|
||||
const auto& cluster_lengths =
|
||||
compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& thread_slice_lengths =
|
||||
compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& src_vector_lengths =
|
||||
compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
const auto& dst_vector_lengths =
|
||||
compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1;
|
||||
|
||||
// check number of working thread
|
||||
const int num_work_thread = std::accumulate(
|
||||
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize >= num_work_thread))
|
||||
return false;
|
||||
|
||||
// check block slice lengths vs thread slice lengths vs cluster lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check thread slice lengths vs vector lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0))
|
||||
return false;
|
||||
|
||||
if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Src vectorization, GK0 is global mem vector dim
|
||||
if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 &&
|
||||
src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1))
|
||||
return false;
|
||||
|
||||
// check Dst vectorization, {GM11, GK1} are LDS vector dims
|
||||
if(dst_vector_lengths[4] == GK1)
|
||||
{ // vectorize on {GM11, GK1}
|
||||
if(!(GM11 % dst_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{ // vectorize on {GK1} only
|
||||
if(!(GK1 % dst_vector_lengths[4] == 0))
|
||||
return false;
|
||||
|
||||
if(!(dst_vector_lengths[3] == 1))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check B blockwise copy
|
||||
{
|
||||
const auto block_slice_lengths =
|
||||
std::array<int, 5>{compile_param.GK0PerBlock, GN0, 1, GN11, GK1};
|
||||
const auto& cluster_lengths =
|
||||
compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& thread_slice_lengths =
|
||||
compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& src_vector_lengths =
|
||||
compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
const auto& dst_vector_lengths =
|
||||
compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1;
|
||||
|
||||
// check number of working thread
|
||||
const int num_work_thread = std::accumulate(
|
||||
cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize >= num_work_thread))
|
||||
return false;
|
||||
|
||||
// check block slice lengths vs thread slice lengths vs cluster lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i]))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check thread slice lengths vs vector lengths
|
||||
for(int i = 0; i < 5; ++i)
|
||||
{
|
||||
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 &&
|
||||
thread_slice_lengths[i] % dst_vector_lengths[i] == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Src vectorization: {GN11} is global mem vector dim
|
||||
if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 &&
|
||||
src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1))
|
||||
return false;
|
||||
|
||||
// check Src tensor layout related vectorization
|
||||
if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 &&
|
||||
conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 &&
|
||||
conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 &&
|
||||
conv_problem_desc.InRightPadW == 0)
|
||||
{
|
||||
if(!((Ho * Wo) % src_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 &&
|
||||
conv_problem_desc.InRightPadW == 0)
|
||||
{
|
||||
if(!(Wo % src_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(!(src_vector_lengths[3] == 1))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Dst vectorization: {GN11, GK1} are LDS vector dims
|
||||
if(dst_vector_lengths[4] == GK1)
|
||||
{ // vectorize on {GN11, GK1}
|
||||
if(!(GN11 % dst_vector_lengths[3] == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{ // vectorize on {GK1} only
|
||||
if(!(dst_vector_lengths[3] == 1))
|
||||
return false;
|
||||
|
||||
if(!(GK1 % dst_vector_lengths[4] == 0))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check blockwise GEMM
|
||||
{
|
||||
const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(),
|
||||
compile_param.BM10BN10ThreadClusterBM10Xs.end(),
|
||||
1,
|
||||
std::multiplies<int>{});
|
||||
|
||||
const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(),
|
||||
compile_param.BM10BN10ThreadClusterBN10Xs.end(),
|
||||
1,
|
||||
std::multiplies<int>{});
|
||||
|
||||
if(!(compile_param.BlockSize == BM10 * BN10))
|
||||
return false;
|
||||
|
||||
const int BM = GM0 * GM11;
|
||||
const int BN = GN0 * GN11;
|
||||
|
||||
const int BM1 = BM10 * BM11;
|
||||
const int BN1 = BN10 * BN11;
|
||||
|
||||
if(!(BM % BM1 == 0 && BN % BN1 == 0))
|
||||
return false;
|
||||
|
||||
const int BM0 = BM / BM1;
|
||||
const int BN0 = BN / BN1;
|
||||
|
||||
// blockwise GEMM currently only support BM0 == 2 && BN0 == 2
|
||||
if(!(BM0 == 2 && BN0 == 2))
|
||||
return false;
|
||||
|
||||
if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
// check C threadwise copy
|
||||
{
|
||||
// {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim
|
||||
const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector;
|
||||
|
||||
// check slice length vs Dst vector length:
|
||||
if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0))
|
||||
return false;
|
||||
|
||||
// check Dst memory layout related vectorization:
|
||||
if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
static int GetBlockSize(const ConvolutionProblemDescriptor&,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
return compile_param.BlockSize;
|
||||
}
|
||||
|
||||
static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param)
|
||||
{
|
||||
const int N = conv_problem_desc.N;
|
||||
const int K = conv_problem_desc.K;
|
||||
const int Ho = conv_problem_desc.Ho;
|
||||
const int Wo = conv_problem_desc.Wo;
|
||||
|
||||
const int N0 = compile_param.GN0;
|
||||
const int N1 = N / N0;
|
||||
|
||||
const int GM1 = K;
|
||||
const int GN1 = N1 * Ho * Wo;
|
||||
|
||||
const int GM11 = compile_param.GM1PerBlockGM11;
|
||||
const int GN11 = compile_param.GN1PerBlockGN11;
|
||||
|
||||
const int GM10 = GM1 / GM11;
|
||||
const int GN10 = GN1 / GN11;
|
||||
|
||||
return GM10 * GN10;
|
||||
}
|
||||
|
||||
static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&,
|
||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&)
|
||||
{
|
||||
// workspace is used for save transformed tensor descritpors created by prepare kernel
|
||||
return 4096L;
|
||||
}
|
||||
|
||||
static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; }
|
||||
|
||||
static auto GetTunableList()
|
||||
{
|
||||
return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,51 +0,0 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
// Tunable parameters for the forward convolution v4r4 DLOPS kernel
// (NCHW input, KCYX weight, NKHW output).
// Field order matters: aggregate initializers (see the default instance
// below) list values in exactly this declaration order.
struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
{
    int BlockSize; // presumably threads per workgroup — confirm against kernel launch

    // per-workgroup GEMM tile sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-thread sub-tile sizes
    int M1PerThread;
    int N1PerThread;
    int KPerThread;

    // thread-cluster decomposition of the block tile
    int M1N1ThreadClusterM10;
    int M1N1ThreadClusterN10;
    int M1N1ThreadClusterM11;
    int M1N1ThreadClusterN11;

    // A-matrix blockwise copy configuration (dimension order: K, M0, M1)
    std::array<int, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_M1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B-matrix blockwise copy configuration (dimension order: K, N0, N1)
    std::array<int, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_N1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C-matrix threadwise copy configuration
    std::array<int, 6> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable configuration; values are listed in the struct's
// field declaration order (see tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw).
static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw
    default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = {
        256,                // BlockSize
        128,                // MPerBlock
        128,                // NPerBlock
        8,                  // KPerBlock
        4,                  // M1PerThread
        4,                  // N1PerThread
        1,                  // KPerThread
        8,                  // M1N1ThreadClusterM10
        8,                  // M1N1ThreadClusterN10
        2,                  // M1N1ThreadClusterM11
        2,                  // M1N1ThreadClusterN11
        {4, 1, 1},          // ABlockTransferThreadSliceLengths_K_M0_M1
        {2, 1, 128},        // ABlockTransferThreadClusterLengths_K_M0_M1
        {2, 1, 0},          // ABlockTransferThreadClusterArrangeOrder
        {2, 1, 0},          // ABlockTransferSrcAccessOrder
        0,                  // ABlockTransferSrcVectorDim
        4,                  // ABlockTransferSrcScalarPerVector
        1,                  // ABlockTransferDstScalarPerVector_M1
        false,              // AThreadTransferSrcResetCoordinateAfterRun
        {4, 1, 1},          // BBlockTransferThreadSliceLengths_K_N0_N1
        {2, 1, 128},        // BBlockTransferThreadClusterLengths_K_N0_N1
        {0, 1, 2},          // BBlockTransferThreadClusterArrangeOrder
        {0, 1, 2},          // BBlockTransferSrcAccessOrder
        2,                  // BBlockTransferSrcVectorDim
        1,                  // BBlockTransferSrcScalarPerVector
        1,                  // BBlockTransferDstScalarPerVector_N1
        false,              // BThreadTransferSrcResetCoordinateAfterRun
        {3, 4, 5, 0, 1, 2}, // CThreadTransferSrcDstAccessOrder
        5,                  // CThreadTransferSrcDstVectorDim
        1};                 // CThreadTransferDstScalarPerVector
|
||||
#endif
|
||||
@@ -1,73 +0,0 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
// Tunable parameters for the forward convolution v4r4 XDLOPS kernel
// (NCHW input, KCYX weight, NKHW output).
// Field order matters: aggregate initializers (see the default instance
// below) list values in exactly this declaration order.
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
    int BlockSize; // presumably threads per workgroup — confirm against kernel launch

    // per-workgroup GEMM tile sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-XDL-instruction tile sizes and K-dimension split factor
    int MPerXDL;
    int NPerXDL;
    int K1;

    // tile repeat counts per wave
    int MRepeat;
    int NRepeat;

    // A-matrix blockwise copy configuration (dimension order: K0, M, K1)
    std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B-matrix blockwise copy configuration (dimension order: K0, N, K1)
    std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C-matrix threadwise copy configuration
    std::array<int, 8> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable configuration; values are listed in the struct's
// field declaration order (see tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw).
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
        256,        // BlockSize
        128,        // MPerBlock
        128,        // NPerBlock
        4,          // KPerBlock
        32,         // MPerXDL
        32,         // NPerXDL
        4,          // K1
        2,          // MRepeat
        2,          // NRepeat
        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1
        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1
        {1, 0, 2},  // ABlockTransferThreadClusterArrangeOrder
        {1, 0, 2},  // ABlockTransferSrcAccessOrder
        2,          // ABlockTransferSrcVectorDim
        1,          // ABlockTransferSrcScalarPerVector
        4,          // ABlockTransferDstScalarPerVector_K1
        false,      // AThreadTransferSrcResetCoordinateAfterRun
        {1, 2, 4},  // BBlockTransferThreadSliceLengths_K0_N_K1
        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1
        {0, 2, 1},  // BBlockTransferThreadClusterArrangeOrder
        {1, 0, 2},  // BBlockTransferSrcAccessOrder
        1,          // BBlockTransferSrcVectorDim
        1,          // BBlockTransferSrcScalarPerVector
        4,          // BBlockTransferDstScalarPerVector_K1
        false,      // BThreadTransferSrcResetCoordinateAfterRun
        {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7,          // CThreadTransferSrcDstVectorDim
        1           // CThreadTransferDstScalarPerVector
};
|
||||
#endif
|
||||
@@ -1,73 +0,0 @@
|
||||
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
// Tunable parameters for the forward convolution v4r4 XDLOPS kernel
// (NHWC input, KYXC weight, NHWK output).
// Field order matters: aggregate initializers (see the default instance
// below) list values in exactly this declaration order.
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
    int BlockSize; // presumably threads per workgroup — confirm against kernel launch

    // per-workgroup GEMM tile sizes
    int MPerBlock;
    int NPerBlock;
    int KPerBlock;

    // per-wave tile sizes and K-dimension split factor
    int MPerWave;
    int NPerWave;
    int K1;

    // tile repeat counts per wave
    int MRepeat;
    int NRepeat;

    // A-matrix blockwise copy configuration (dimension order: K0, M, K1)
    std::array<int, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
    std::array<int, 3> ABlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> ABlockTransferSrcAccessOrder;
    int ABlockTransferSrcVectorDim;
    int ABlockTransferSrcScalarPerVector;
    int ABlockTransferDstScalarPerVector_K1;
    bool AThreadTransferSrcResetCoordinateAfterRun;

    // B-matrix blockwise copy configuration (dimension order: K0, N, K1)
    std::array<int, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
    std::array<int, 3> BBlockTransferThreadClusterArrangeOrder;
    std::array<int, 3> BBlockTransferSrcAccessOrder;
    int BBlockTransferSrcVectorDim;
    int BBlockTransferSrcScalarPerVector;
    int BBlockTransferDstScalarPerVector_K1;
    bool BThreadTransferSrcResetCoordinateAfterRun;

    // C-matrix threadwise copy configuration
    std::array<int, 8> CThreadTransferSrcDstAccessOrder;
    int CThreadTransferSrcDstVectorDim;
    int CThreadTransferDstScalarPerVector;
};
|
||||
|
||||
// Default tunable configuration; values are listed in the struct's
// field declaration order (see tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk).
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
    default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
        256,        // BlockSize
        128,        // MPerBlock
        128,        // NPerBlock
        4,          // KPerBlock
        32,         // MPerWave
        32,         // NPerWave
        4,          // K1
        2,          // MRepeat
        2,          // NRepeat
        {1, 2, 4},  // ABlockTransferThreadSliceLengths_K0_M_K1
        {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1
        {1, 0, 2},  // ABlockTransferThreadClusterArrangeOrder
        {1, 0, 2},  // ABlockTransferSrcAccessOrder
        2,          // ABlockTransferSrcVectorDim
        4,          // ABlockTransferSrcScalarPerVector
        4,          // ABlockTransferDstScalarPerVector_K1
        false,      // AThreadTransferSrcResetCoordinateAfterRun
        {1, 2, 4},  // BBlockTransferThreadSliceLengths_K0_N_K1
        {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1
        {1, 0, 2},  // BBlockTransferThreadClusterArrangeOrder
        {1, 0, 2},  // BBlockTransferSrcAccessOrder
        2,          // BBlockTransferSrcVectorDim
        4,          // BBlockTransferSrcScalarPerVector
        4,          // BBlockTransferDstScalarPerVector_K1
        false,      // BThreadTransferSrcResetCoordinateAfterRun
        {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
        7,          // CThreadTransferSrcDstVectorDim
        1           // CThreadTransferDstScalarPerVector
};
|
||||
#endif
|
||||
@@ -1,81 +0,0 @@
|
||||
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
#define CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
// Plain-data description of a 2D convolution problem: tensor sizes,
// strides/dilations/paddings, and the element types of the input,
// weight, and output tensors.
struct ConvolutionProblemDescriptor
{
    ConvolutionProblemDescriptor() = default;

    ConvolutionProblemDescriptor(int N_,
                                 int K_,
                                 int C_,
                                 int Y_,
                                 int X_,
                                 int Hi_,
                                 int Wi_,
                                 int Ho_,
                                 int Wo_,
                                 int ConvStrideH_,
                                 int ConvStrideW_,
                                 int ConvDilationH_,
                                 int ConvDilationW_,
                                 int InLeftPadH_,
                                 int InLeftPadW_,
                                 int InRightPadH_,
                                 int InRightPadW_,
                                 ck::DataTypeEnum_t InDataTypeEnum_,
                                 ck::DataTypeEnum_t WeiDataTypeEnum_,
                                 ck::DataTypeEnum_t OutDataTypeEnum_)
        : N{N_},
          K{K_},
          C{C_},
          Y{Y_},
          X{X_},
          Hi{Hi_},
          Wi{Wi_},
          Ho{Ho_},
          Wo{Wo_},
          ConvStrideH{ConvStrideH_},
          ConvStrideW{ConvStrideW_},
          ConvDilationH{ConvDilationH_},
          ConvDilationW{ConvDilationW_},
          InLeftPadH{InLeftPadH_},
          InLeftPadW{InLeftPadW_},
          InRightPadH{InRightPadH_},
          InRightPadW{InRightPadW_},
          InDataTypeEnum{InDataTypeEnum_},
          WeiDataTypeEnum{WeiDataTypeEnum_},
          OutDataTypeEnum{OutDataTypeEnum_}
    {
    }

    // Tensor sizes (conventional conv naming — N batch, K output channels,
    // C input channels, Y/X filter height/width, Hi/Wi input spatial,
    // Ho/Wo output spatial; presumed from usage, confirm against kernels).
    int N;
    int K;
    int C;
    int Y;
    int X;
    int Hi;
    int Wi;
    int Ho;
    int Wo;
    // convolution strides and dilations per spatial dimension
    int ConvStrideH;
    int ConvStrideW;
    int ConvDilationH;
    int ConvDilationW;
    // input padding on each side, per spatial dimension
    int InLeftPadH;
    int InLeftPadW;
    int InRightPadH;
    int InRightPadW;

    // element types of the input, weight, and output tensors
    ck::DataTypeEnum_t InDataTypeEnum;
    ck::DataTypeEnum_t WeiDataTypeEnum;
    ck::DataTypeEnum_t OutDataTypeEnum;

    // Total floating-point operations: one multiply + one add (factor 2)
    // per MAC, N*K*C*Y*X*Ho*Wo MACs in a direct convolution.
    std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
};
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,46 +0,0 @@
|
||||
#ifndef CK_SOLVER_COMMON_HPP
|
||||
#define CK_SOLVER_COMMON_HPP
|
||||
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
// greatest common divisor, aka highest common factor
//
// Iterative Euclidean algorithm over the absolute values of both operands,
// so negative inputs behave like their magnitudes.
// Edge cases: gcd(x, 0) == |x|, gcd(0, y) == |y|, gcd(0, 0) == 0.
inline int gcd(int x, int y)
{
    int a = (x < 0) ? -x : x;
    int b = (y < 0) ? -y : y;

    while(b != 0)
    {
        const int r = a % b;

        a = b;
        b = r;
    }

    return a;
}
|
||||
|
||||
// Variadic overload: folds gcd over three or more arguments,
// e.g. gcd(a, b, c) == gcd(a, gcd(b, c)).
// SFINAE-restricted to >= 3 total arguments so the two-argument
// calls resolve to the int overload above.
template <typename T,
          typename... Rest,
          typename std::enable_if<sizeof...(Rest) >= 2, bool>::type = false>
auto gcd(T first, Rest... rest)
{
    return gcd(first, gcd(rest...));
}
|
||||
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user